Spaces:
Runtime error
Runtime error
| from collections import OrderedDict | |
| from importlib import import_module | |
| import os | |
| import random | |
| import re | |
| import warnings | |
| from typing import Union, Any | |
| import numpy as np | |
| import torch | |
| from torch import distributed as dist | |
| import torch.nn as nn | |
| from torch.nn.parallel import DataParallel, DistributedDataParallel | |
| from .dist_util import get_dist_info | |
# Parallel wrapper classes; a wrapped model is reached through ``.module``.
MODULE_WRAPPERS = [DataParallel, DistributedDataParallel]

# One-letter model-size abbreviations mapped to the full size name that is
# used to build the ``model_<size>`` attribute name of a config module.
MODEL_ABBR_MAP = dict(s='small', b='base', l='large', h='huge')
def infer_dataset_by_path(model_path: str) -> Union[str, Any]:
    """Infer the dataset name from a checkpoint file name.

    Only the base name of ``model_path`` is inspected. It must contain
    ``-<dataset>.<ext>`` where ``<dataset>`` consists of letters, digits and
    underscores and ``<ext>`` is one of ``pth``, ``onnx`` or ``engine``,
    e.g. ``vitpose-b-coco.pth`` -> ``coco``.

    Args:
        model_path (str): Path to the model file.

    Returns:
        str: The dataset part of the file name.

    Raises:
        ValueError: If the file name does not follow the expected pattern.
    """
    file_name = os.path.basename(model_path)
    # Use an alternation group for the extension. A character class such as
    # ``[pth, onnx, engine]`` only matches ONE character from that set and
    # would therefore accept unrelated extensions like ``.txt`` or ``.ini``.
    pattern = r'-([a-zA-Z0-9_]+)\.(?:pth|onnx|engine)\b'
    match = re.search(pattern, file_name)
    if match is None:
        raise ValueError('Could not infer the dataset from ckpt name, specify it')
    return match.group(1)
def dyn_model_import(dataset: str, model: str):
    """Fetch a model definition from the config module of a dataset.

    The module ``configs.ViTPose_<dataset>`` is imported and its attribute
    ``model_<size>`` is returned, where ``<size>`` is the full size name that
    ``MODEL_ABBR_MAP`` associates with the abbreviation ``model``.

    Args:
        dataset (str): Dataset name used to select the config module.
        model (str): Model-size abbreviation, a key of ``MODEL_ABBR_MAP``.

    Returns:
        The ``model_<size>`` attribute of the imported config module.
    """
    # Import first so that an unknown dataset fails before the size lookup.
    config_module = import_module(f'configs.ViTPose_{dataset}')
    size_name = MODEL_ABBR_MAP[model]
    return getattr(config_module, f'model_{size_name}')
def init_random_seed(seed=None, device='cuda'):
    """Pick the random seed to use, sharing it across all processes.

    A seed that is passed in explicitly is returned unchanged. Otherwise a
    seed is drawn at random; when more than one process is running, the
    value drawn on rank 0 is broadcast so that every rank ends up with the
    same seed (see https://github.com/open-mmlab/mmdetection/issues/6339).

    Args:
        seed (int, Optional): The seed. Default to None.
        device (str): Device on which the broadcast tensor is created.
            Default to 'cuda'.

    Returns:
        int: Seed to be used.
    """
    if seed is None:
        rank, world_size = get_dist_info()
        seed = np.random.randint(2**31)
        if world_size != 1:
            # Only rank 0 contributes its value; the others send a placeholder
            # that is overwritten by the broadcast.
            payload = seed if rank == 0 else 0
            shared = torch.tensor(payload, dtype=torch.int32, device=device)
            dist.broadcast(shared, src=0)
            seed = shared.item()
    return seed
def set_random_seed(seed: int,
                    deterministic: bool = False,
                    use_rank_shift: bool = False) -> None:
    """Seed the Python, NumPy and PyTorch random number generators.

    The seed is also written to the ``PYTHONHASHSEED`` environment variable.

    Args:
        seed (int): Seed to be used.
        deterministic (bool): Whether to set
            ``torch.backends.cudnn.deterministic`` to True and
            ``torch.backends.cudnn.benchmark`` to False. Default: False.
        use_rank_shift (bool): Whether to add the process rank to the seed
            so that each rank uses a different seed. Default: False.
    """
    if use_rank_shift:
        rank, _world_size = get_dist_info()
        seed = seed + rank

    # Same seeding order as always: Python, NumPy, then PyTorch CPU / CUDA.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)

    if deterministic:
        cudnn = torch.backends.cudnn
        cudnn.deterministic = True
        cudnn.benchmark = False
def is_module_wrapper(module: nn.Module) -> bool:
    """Check whether ``module`` is an instance of a registered wrapper class.

    ``MODULE_WRAPPERS`` may be a plain iterable of classes (as defined in
    this file) or a registry-like object exposing ``module_dict`` and,
    optionally, ``children``; in the latter case child registries are
    searched recursively.

    Args:
        module (nn.Module): The module to test.

    Returns:
        bool: True if ``module`` is an instance of any wrapper class,
        otherwise False.
    """

    def collect_wrapper_classes(registry):
        # Registry-like object: gather its own classes plus its children's.
        if hasattr(registry, 'module_dict'):
            classes = list(registry.module_dict.values())
            for child in getattr(registry, 'children', {}).values():
                classes.extend(collect_wrapper_classes(child))
            return classes
        # Plain list/tuple of classes. Reading ``.module_dict`` on it, as
        # the previous implementation did, raises AttributeError.
        return list(registry)

    return isinstance(module, tuple(collect_wrapper_classes(MODULE_WRAPPERS)))
def load_state_dict(module, state_dict, strict=False, logger=None):
    """Copy the weights in ``state_dict`` into ``module``.

    Adapted from :meth:`torch.nn.Module.load_state_dict`. ``strict``
    defaults to ``False`` here, and key mismatches are reported even when
    it is ``False``.

    Args:
        module (Module): Module that receives the state_dict.
        state_dict (OrderedDict): Weights.
        strict (bool): If True, a mismatch between the keys of
            ``state_dict`` and those expected by ``module`` raises an
            error instead of being reported. Default: ``False``.
        logger (:obj:`logging.Logger`, optional): Logger that receives the
            mismatch message. If not specified, ``print`` is used.

    Raises:
        RuntimeError: On rank 0, if ``strict`` is True and there is any
            mismatch message.
    """
    unexpected_keys = []
    all_missing_keys = []
    err_msg = []

    # Work on a shallow copy and carry the ``_metadata`` attribute over to it.
    metadata = getattr(state_dict, '_metadata', None)
    state_dict = state_dict.copy()
    if metadata is not None:
        state_dict._metadata = metadata

    # ``_load_from_state_dict`` is used to enable checkpoint version control.
    def _load_recursive(current, prefix=''):
        # A wrapper may sit anywhere in the tree, e.g.
        # nn.Module(nn.Module(DDP)), so unwrap at every level.
        if is_module_wrapper(current):
            current = current.module
        if metadata is None:
            local_metadata = {}
        else:
            local_metadata = metadata.get(prefix[:-1], {})
        current._load_from_state_dict(state_dict, prefix, local_metadata,
                                      True, all_missing_keys,
                                      unexpected_keys, err_msg)
        for child_name, child in current._modules.items():
            if child is None:
                continue
            _load_recursive(child, prefix + child_name + '.')

    _load_recursive(module)
    _load_recursive = None  # break the closure's reference cycle to itself

    # "num_batches_tracked" entries of BN layers are not reported as missing.
    missing_keys = [
        name for name in all_missing_keys
        if 'num_batches_tracked' not in name
    ]

    if unexpected_keys:
        unexpected = ", ".join(unexpected_keys)
        err_msg.append(f'unexpected key in source state_dict: {unexpected}\n')
    if missing_keys:
        missing = ", ".join(missing_keys)
        err_msg.append(f'missing keys in source state_dict: {missing}\n')

    rank, _ = get_dist_info()
    # Only rank 0 reports, and only when there is something to report.
    if not err_msg or rank != 0:
        return

    err_msg.insert(0, 'The model and loaded state dict do not match exactly\n')
    report = '\n'.join(err_msg)
    if strict:
        raise RuntimeError(report)
    if logger is not None:
        logger.warning(report)
    else:
        print(report)
def load_checkpoint(model,
                    filename,
                    map_location='cpu',
                    strict=False,
                    logger=None):
    """Load a checkpoint file into ``model``.

    ``filename`` is handed to :func:`torch.load` as is. The weights are read
    from the ``'state_dict'`` entry of the checkpoint when present, otherwise
    the checkpoint itself is treated as the weights. A leading
    ``module.backbone.``, ``module.`` or ``backbone.`` is removed from each
    key before loading.

    Args:
        model (Module): Module to load checkpoint.
        filename (str): Checkpoint location passed to :func:`torch.load`.
        map_location (str): Same as :func:`torch.load`.
        strict (bool): Forwarded to ``load_state_dict``.
        logger (:mod:`logging.Logger` or None): The logger for error message.

    Returns:
        dict or OrderedDict: The loaded checkpoint.

    Raises:
        RuntimeError: If the loaded object is not a dict.
    """
    # NOTE(review): torch.load may unpickle arbitrary objects depending on
    # the torch version and its defaults; only load trusted checkpoints.
    checkpoint = torch.load(filename, map_location=map_location)

    # OrderedDict is a subclass of dict, so this accepts both.
    if not isinstance(checkpoint, dict):
        raise RuntimeError(
            f'No state_dict found in checkpoint file {filename}')

    raw_weights = checkpoint.get('state_dict') \
        if 'state_dict' in checkpoint else checkpoint

    # Checked in this order; only the first matching prefix is removed.
    known_prefixes = ('module.backbone.', 'module.', 'backbone.')
    state_dict = OrderedDict()
    for key, value in raw_weights.items():
        for prefix in known_prefixes:
            if key.startswith(prefix):
                key = key[len(prefix):]
                break
        state_dict[key] = value

    load_state_dict(model, state_dict, strict, logger)
    return checkpoint
def resize(input,
           size=None,
           scale_factor=None,
           mode='nearest',
           align_corners=None,
           warning=True):
    """Warn when an upsampling ``size`` is poorly aligned with the input.

    A warning is issued only when ``warning`` is true, ``size`` is given,
    ``align_corners`` is truthy, the target is larger than the input along
    at least one of the two dimensions, all four extents are greater than 1
    and neither ``output - 1`` is a multiple of the matching ``input - 1``.

    NOTE(review): this function only emits the warning. It performs no
    resizing and returns ``None``; ``scale_factor`` and ``mode`` are unused.
    Confirm whether a call to an interpolation routine is missing.

    Args:
        input: Object with a ``shape`` whose first two entries are read as
            height and width.
        size (Sequence[int], optional): Target ``(height, width)``.
        scale_factor: Unused.
        mode (str): Unused.
        align_corners (bool, optional): The check runs only when truthy.
        warning (bool): Set to False to skip the check entirely.
    """
    if not (warning and size is not None and align_corners):
        return

    # NOTE(review): shape[0]/shape[1] are taken as height/width. For a
    # batched (N, C, H, W) tensor the spatial dims would be shape[2:].
    # TODO confirm the layout callers pass in.
    input_h, input_w = int(input.shape[0]), int(input.shape[1])
    output_h, output_w = int(size[0]), int(size[1])

    # Compare each output extent with the matching INPUT extent. The width
    # check previously compared ``output_w`` against ``output_h``.
    is_upsampling = output_h > input_h or output_w > input_w
    if not is_upsampling:
        return

    all_extents_gt_one = (output_h > 1 and output_w > 1 and input_h > 1
                          and input_w > 1)
    if (all_extents_gt_one and (output_h - 1) % (input_h - 1)
            and (output_w - 1) % (input_w - 1)):
        warnings.warn(
            f'When align_corners={align_corners}, '
            'the output would more aligned if '
            f'input size {(input_h, input_w)} is `x+1` and '
            f'out size {(output_h, output_w)} is `nx+1`')
def constant_init(module: nn.Module, val: float, bias: float = 0) -> None:
    """Fill a module's weight with ``val`` and its bias with ``bias``.

    A ``weight`` or ``bias`` attribute that is absent or ``None`` is skipped.

    Args:
        module (nn.Module): Module to initialise in place.
        val (float): Constant written to ``module.weight``.
        bias (float): Constant written to ``module.bias``. Default: 0.
    """
    for attr_name, fill_value in (('weight', val), ('bias', bias)):
        param = getattr(module, attr_name, None)
        if param is not None:
            nn.init.constant_(param, fill_value)
def normal_init(module: nn.Module,
                mean: float = 0,
                std: float = 1,
                bias: float = 0) -> None:
    """Draw a module's weight from a normal distribution and fill its bias.

    A ``weight`` or ``bias`` attribute that is absent or ``None`` is skipped.

    Args:
        module (nn.Module): Module to initialise in place.
        mean (float): Mean of the normal distribution. Default: 0.
        std (float): Standard deviation of the distribution. Default: 1.
        bias (float): Constant written to ``module.bias``. Default: 0.
    """
    weight_param = getattr(module, 'weight', None)
    if weight_param is not None:
        nn.init.normal_(weight_param, mean, std)

    bias_param = getattr(module, 'bias', None)
    if bias_param is not None:
        nn.init.constant_(bias_param, bias)