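# Spaces: Sleeping
# (The line above appears to be page-header residue from the hosting site where
# this script was captured; it is not part of the program.)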
import datetime
import os
import random

import numpy as np
import torch
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor, model_checkpoint
from pytorch_lightning.loggers import TensorBoardLogger

from config import Config
from dataset.dataloader import RubikDataModule
from model.DeepcubeA_module import DeepcubeA
# Select PyTorch's "medium" float32 matmul precision setting for the whole
# process (a speed/precision trade-off; see the torch documentation for what
# each level permits on a given backend).
torch.set_float32_matmul_precision("medium")
def _resolve_accelerator(devices_arg):
    """Translate the ``devices`` option into Lightning ``accelerator``/``devices`` values.

    Args:
        devices_arg: ``"cpu"``, ``"auto"``, or a comma-separated list of GPU
            indices such as ``"0"`` or ``"0,1"``. Matching is case-insensitive
            and ignores surrounding whitespace.

    Returns:
        An ``(accelerator, devices)`` tuple: ``("cpu", 1)`` for CPU,
        ``("gpu" | "cpu", "auto")`` for auto-detection, or
        ``("gpu", [id, ...])`` when explicit GPU indices are given.

    Raises:
        ValueError: If an explicit entry is not an integer, or no index is given.
    """
    spec = devices_arg.strip().lower()
    if spec == "cpu":
        # CPU runs use a single process.
        return "cpu", 1
    if spec == "auto":
        return ("gpu" if torch.cuda.is_available() else "cpu"), "auto"

    # The user named explicit GPU id(s). Splitting on "," also covers the
    # single-id case; empty pieces (e.g. a trailing comma) are skipped.
    gpu_ids = [int(piece) for piece in spec.split(",") if piece.strip()]
    if not gpu_ids:
        raise ValueError(f"no GPU index found in devices option: {devices_arg!r}")
    return "gpu", gpu_ids


def _load_weights(network, checkpoint_path):
    """Load the state dict stored at ``checkpoint_path`` into ``network``.

    Args:
        network: Object exposing ``load_state_dict``.
        checkpoint_path: Path to a file previously written with ``torch.save``.
    """
    # map_location="cpu" lets checkpoints saved from a GPU be read on hosts
    # without CUDA; load_state_dict then copies the values into the network's
    # own parameters, wherever they live.
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    network.load_state_dict(state_dict)


def main():
    """Resume DeepcubeA training and fit the model once per K value.

    Parses the configuration, seeds the RNGs, creates a timestamped log
    directory tree, restores the weights of ``model.model_theta_e`` and
    ``model.model_theta`` from fixed checkpoint files, then runs one
    ``Trainer.fit`` per K from ``initial_K`` up to the configured ``args.K``
    (inclusive), reusing the same model instance throughout.
    """
    # Parse the configuration.
    config = Config()
    args = config.parse_args()

    # Seed all random number generators.
    seed_everything(args.seed, workers=True)

    # Nest every output directory of this run under a timestamped log dir.
    args.log_dir = os.path.join(args.log_dir, datetime.datetime.now().strftime("%Y%m%d_%H%M"))
    args.checkpoint_dir = os.path.join(args.log_dir, args.checkpoint_dir)
    args.converged_checkpoint_dir = os.path.join(args.log_dir, args.converged_checkpoint_dir)

    # Work out accelerator & devices from the command-line option.
    accelerator, devices = _resolve_accelerator(args.devices)

    # Create the required directories.
    # NOTE(review): nothing in this function writes to checkpoint_dir or
    # converged_checkpoint_dir, and neither is passed to the Trainer --
    # presumably the model or other code uses them; confirm.
    os.makedirs(args.log_dir, exist_ok=True)
    os.makedirs(args.checkpoint_dir, exist_ok=True)
    os.makedirs(args.converged_checkpoint_dir, exist_ok=True)

    # Build the model once; it is reused for every K below.
    model = DeepcubeA(args)

    # First K to train and the largest K (taken from the config before
    # args.K is overwritten inside the loop).
    initial_K = 16
    max_K = args.K

    # Restore weights from an earlier run. The paths are hard-coded to that
    # run's output directory and must exist relative to the working directory.
    model_e_checkpoint = "logs/20250818_1819/converged_checkpoints/final_model_K_14.pth"
    _load_weights(model.model_theta_e, model_e_checkpoint)
    # Named so it does not shadow the imported `model_checkpoint` module.
    model_theta_checkpoint = "logs/20250818_1819/converged_checkpoints/final_model_K_15.pth"
    _load_weights(model.model_theta, model_theta_checkpoint)

    # If max_K < initial_K the range is empty and no training happens.
    for K in range(initial_K, max_K + 1):
        print(f'\n--- 开始训练 K={K} ---')

        # Propagate the current K to both the model and the config that the
        # data module is built from.
        model.K = K
        args.K = K

        # Fresh data module for this K.
        data_module = RubikDataModule(args)

        # ModelCheckpoint, LearningRateMonitor and a per-K TensorBoardLogger
        # were tried by the original author and left disabled as not useful
        # for now; only early stopping on validation loss is active.
        early_stopping_callback = EarlyStopping(
            monitor='val_loss',
            patience=5,
            mode='min',
        )

        # Fresh trainer for this K. Per the original author's note, validation
        # runs once per epoch by default, i.e. every 5000 steps.
        trainer = Trainer(
            max_epochs=args.max_epochs,
            accelerator=accelerator,
            precision="16-mixed",  # enable mixed precision
            devices=devices,
            logger=False,
            callbacks=[early_stopping_callback],
            deterministic=True,
            enable_progress_bar=True,
            enable_checkpointing=True,
        )
        print(trainer.log_every_n_steps)

        # Train the model.
        trainer.fit(model, datamodule=data_module)
        print(f'--- 完成训练 K={K} ---\n')
# Run the training entry point only when this file is executed as a script,
# not when it is imported.
if __name__ == '__main__':
    main()