# Deepcube/train.py
import datetime
import os

import torch
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger

from config import Config
from dataset.dataloader import RubikDataModule
from model.DeepcubeA_module import DeepcubeA
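
# Allow lower-precision (TF32) matmul kernels on supporting GPUs for extra speed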
torch.set_float32_matmul_precision('medium')


def main():
    # Parse the configuration
    config = Config()
    args = config.parse_args()

    # Seed Python, NumPy, and PyTorch RNGs for reproducibility
    seed_everything(args.seed, workers=True)

    # Place all outputs for this run under a timestamped log directory
    args.log_dir = os.path.join(args.log_dir, datetime.datetime.now().strftime("%Y%m%d_%H%M"))
    args.checkpoint_dir = os.path.join(args.log_dir, args.checkpoint_dir)
    args.converged_checkpoint_dir = os.path.join(args.log_dir, args.converged_checkpoint_dir)
    # Resolve accelerator & devices from the devices argument
    if args.devices.lower() == "cpu":
        accelerator = "cpu"
        devices = 1  # CPU runs default to a single process
    elif args.devices.lower() == "auto":
        accelerator = "gpu" if torch.cuda.is_available() else "cpu"
        devices = "auto"
    else:
        # The user specified explicit GPU id(s), e.g. "0" or "0,1"
        accelerator = "gpu"
        if "," in args.devices:
            devices = [int(x) for x in args.devices.split(",")]
        else:
            devices = [int(args.devices)]
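    # Example invocations (assuming Config exposes this as a --devices string flag):
    #   python train.py --devices cpu    -> CPU, single process
    #   python train.py --devices auto   -> GPU if available, otherwise CPU
    #   python train.py --devices 0,1    -> GPUs 0 and 1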
    # Create the necessary directories
    os.makedirs(args.log_dir, exist_ok=True)
    os.makedirs(args.checkpoint_dir, exist_ok=True)
    os.makedirs(args.converged_checkpoint_dir, exist_ok=True)
    # Initialize the model once; it is reused across all K stages
    model = DeepcubeA(args)

    # Curriculum over scramble depth: start at initial_K and grow to the max K
    initial_K = 16
    max_K = args.K  # can be adjusted as needed

    # Warm-start both networks from previously converged checkpoints
    model_e_checkpoint = "logs/20250818_1819/converged_checkpoints/final_model_K_14.pth"
    model.model_theta_e.load_state_dict(torch.load(model_e_checkpoint))
    model_checkpoint = "logs/20250818_1819/converged_checkpoints/final_model_K_15.pth"
    model.model_theta.load_state_dict(torch.load(model_checkpoint))
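    # NOTE: torch.load restores tensors to the devices they were saved from; when
    # resuming on a different machine, passing map_location="cpu" is a safe option, e.g.:
    #   model.model_theta.load_state_dict(torch.load(model_checkpoint, map_location="cpu"))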
    # Curriculum loop: train at each scramble depth K in turn
    for K in range(initial_K, max_K + 1):
        print(f'\n--- Starting training for K={K} ---')

        # Update the model's K value
        model.K = K
        # Create the new dataset configuration
        args.K = K  # set the current K value
        # Initialize a new data module
        data_module = RubikDataModule(args)
        # # Checkpoint callback; skipped for now since it did not seem useful
        # checkpoint_callback = ModelCheckpoint(
        #     dirpath=args.checkpoint_dir,
        #     filename=f'K_{K}_' + '{epoch}-{val_loss:.2f}',
        #     save_top_k=3,
        #     monitor='val_loss',
        #     mode='min'
        # )
        # Stop a stage early once validation loss stops improving
        early_stopping_callback = EarlyStopping(
            monitor='val_loss',
            patience=5,
            mode='min',
        )
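        # NOTE: this assumes the DeepcubeA module logs a 'val_loss' metric during
        # validation; with EarlyStopping's default strict=True, Lightning raises
        # an error if that metric is never logged.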
        # lr_monitor = LearningRateMonitor(logging_interval='epoch')
        # # TensorBoard logger (each K uses a separate log directory)
        # logger = TensorBoardLogger(
        #     save_dir=args.log_dir,
        #     name=f'train_logs_K_{K}'
        # )
        # Initialize a new trainer; by default it validates once per epoch,
        # i.e. every 5000 steps
        trainer = Trainer(
            max_epochs=args.max_epochs,
            accelerator=accelerator,
            precision="16-mixed",  # enable mixed-precision training
            devices=devices,
            logger=False,
            callbacks=[early_stopping_callback],
            deterministic=True,
            enable_progress_bar=True,
            enable_checkpointing=True
        )
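        # NOTE: reusing the same model object across stages carries the learned
        # weights forward, but each fresh Trainer calls configure_optimizers
        # again, so optimizer and LR-scheduler state restart at every K.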
        # Debug: show the logging interval (Lightning's default is 50 steps)
        print(trainer.log_every_n_steps)

        # Train the model at the current scramble depth
        trainer.fit(model, datamodule=data_module)
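        # NOTE: args.converged_checkpoint_dir is created above but never written to
        # in this loop; presumably DeepcubeA saves final_model_K_{K}.pth internally.
        # If it does not, a minimal sketch for persisting the converged weights is:
        #   torch.save(model.model_theta.state_dict(),
        #              os.path.join(args.converged_checkpoint_dir, f"final_model_K_{K}.pth"))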
        print(f'--- Finished training K={K} ---\n')


if __name__ == '__main__':
    main()