Spaces:
Runtime error
Runtime error
| import torch.nn as nn | |
| from .backbone.vit import ViT | |
| from .head.topdown_heatmap_simple_head import TopdownHeatmapSimpleHead | |
| __all__ = ['ViTPose'] | |
| class ViTPose(nn.Module): | |
| def __init__(self, cfg: dict) -> None: | |
| super(ViTPose, self).__init__() | |
| backbone_cfg = {k: v for k, v in cfg['backbone'].items() if k != 'type'} | |
| head_cfg = {k: v for k, v in cfg['keypoint_head'].items() if k != 'type'} | |
| self.backbone = ViT(**backbone_cfg) | |
| self.keypoint_head = TopdownHeatmapSimpleHead(**head_cfg) | |
| def forward_features(self, x): | |
| return self.backbone(x) | |
| def forward(self, x): | |
| return self.keypoint_head(self.backbone(x)) | |