import torch


def move_to_device(data, device, *, non_blocking=False):
    """Recursively move every tensor in a nested container to ``device``.

    Tensors are moved with ``Tensor.to``. Dicts, lists, plain tuples and
    namedtuples are traversed recursively and rebuilt around the moved
    values. Any other object, including tuple subclasses that are not
    namedtuples (e.g. ``torch.Size``), is returned unchanged.

    Args:
        data: A tensor, or an arbitrarily nested dict/list/tuple containing
            tensors and other objects.
        device: Target passed straight to ``Tensor.to`` (e.g. ``"cuda"`` or
            ``torch.device("cpu")``).
        non_blocking: Forwarded to ``Tensor.to``. Defaults to ``False``,
            which matches a plain ``tensor.to(device)`` call.

    Returns:
        A structure mirroring ``data`` with its tensors moved. Dicts are
        rebuilt as plain ``dict`` and lists as plain ``list``. Namedtuples
        keep their type.
    """
    if isinstance(data, torch.Tensor):
        return data.to(device, non_blocking=non_blocking)
    if isinstance(data, dict):
        return {
            k: move_to_device(v, device, non_blocking=non_blocking)
            for k, v in data.items()
        }
    if isinstance(data, list):
        return [move_to_device(v, device, non_blocking=non_blocking) for v in data]
    if isinstance(data, tuple):
        # Batches are often tuples like (inputs, targets); previously these
        # were returned untouched, leaving their tensors on the old device.
        if hasattr(data, "_fields"):
            # namedtuple: its constructor takes the fields positionally.
            return type(data)(
                *(move_to_device(v, device, non_blocking=non_blocking) for v in data)
            )
        if type(data) is tuple:
            return tuple(
                move_to_device(v, device, non_blocking=non_blocking) for v in data
            )
        # Other tuple subclasses may have incompatible constructors; keep as-is.
        return data
    return data