Spaces:
Paused
Paused
import os
import random

import torch
import torch.nn as nn

from training.data.load_data import get_dataloders, save_preprocessed_data_to_output
from training.model.loss import ProxyAnchorLoss
from training.model.model import FusedFeatureModel
from training.model.feature_extractor import FeatureExtractor
from training.test_helpers import evaluate_ncm_accuracy_top3
from training.train_helpers import train_or_load_model
from training.helpers.args import parse_and_get_args
# --- Reproducibility --------------------------------------------------------
SEED_VALUE = 42
torch.manual_seed(SEED_VALUE)  # per the torch docs, seeds the RNG on all devices
random.seed(SEED_VALUE)

# --- Data locations ---------------------------------------------------------
# Forwarded to get_dataloders as is_full_train and encoded in the model name as
# "_full"/"_sampled"; presumably False trains on a `sample_size` subset.
FULL_TRAIN = True
RENDERS_FOLDER_NAME = "BlenderRenders7"
TEST_FOLDER_NAME = "test_data3"
# Directory containing the render, test and validation folders. The default is
# the original author's /mnt/c/... path; set PRACA_MODELS_ROOT to point the
# script at a different machine without editing the source.
MODELS_ROOT = os.environ.get(
    "PRACA_MODELS_ROOT", "/mnt/c/Users/KMult/Desktop/Praca_inzynierska/models"
)
RENDERS_FOLDER = f"{MODELS_ROOT}/{RENDERS_FOLDER_NAME}"
TEST_FOLDER = f"{MODELS_ROOT}/{TEST_FOLDER_NAME}"
VALIDATION_FOLDER = f"{MODELS_ROOT}/validation_data"

# --- Training schedule ------------------------------------------------------
BATCH_PRINT_INTERVAL = 50
NUM_EPOCHS = 50
# Forwarded to train_or_load_model as validation_interval / patience /
# min_delta (presumably early-stopping settings).
VALIDATION_CHECK_INTERVAL = 50
VALIDATION_PATIENCE = 60
VALIDATION_MIN_DELTA = 0.001
# Forwarded to get_dataloders; presumably only relevant when FULL_TRAIN is False.
sample_size = 8192
FILTER_HARD_DATA = False

# --- Checkpoint loading -----------------------------------------------------
# To load a previously saved model (load_network=True) instead of training,
# swap the active pair below for the commented-out pair.
LOAD_NETWORK = False
LOAD_SPECIFIC_MODEL_NAME = None
# LOAD_NETWORK = True
# LOAD_SPECIFIC_MODEL_NAME = "BESTID_2_fused_feature_model.pth_full_clip1_segformer0_midas0_dpt0_gate0_batch64_traintypehardmining_bigfusionhead2_lr2e-07_margin0.8_alpha64.0_datasetsize114272_rendersBlenderRenders7_testdatatest_data3.model"
# Parse command-line settings. LOAD_SPECIFIC_MODEL_NAME is forwarded to the
# parser (NOTE(review): presumably so settings can be recovered from a saved
# model's name — confirm in training.helpers.args).
args = parse_and_get_args(LOAD_SPECIFIC_MODEL_NAME)
print(args)
BATCH_SIZE = args.batch
TRAIN_TYPE = args.train_type.lower()
# Per-backbone on/off flags for the feature extractor.
MODELS_USED = {
    'clip': args.clip,
    'segformer': args.segformer,
    'midas': args.midas,
    'dpt': args.dpt,
}
# Cached feature-vector files for the train and test splits. Despite the
# *_FOLDER names, these are file paths, not directories.
VECTORS_SAVE_FOLDER = "training/preprocessed_vectors/mega_clip_segformer_midas_dpt4.pth"
TEST_VECTORS_SAVE_FOLDER = "training/preprocessed_vectors/test_mega_clip_segformer_midas_dpt4.pth"
feature_extractor = FeatureExtractor(use_models=MODELS_USED)
# Uncomment to (re)generate the cached vector files above from the raw folders.
# save_preprocessed_data_to_output(root_folder=RENDERS_FOLDER, root_test_folder=TEST_FOLDER, output_file_path=VECTORS_SAVE_FOLDER, test_output_file_path=TEST_VECTORS_SAVE_FOLDER, feature_extractor=feature_extractor, processor_flags=MODELS_USED)
# The fusion model additionally needs to know whether the gate is enabled.
MODELS_USED_FOR_TRAINING = MODELS_USED.copy()
MODELS_USED_FOR_TRAINING['gate'] = args.gate
# NOTE(review): VALIDATION_FOLDER is defined in the config but
# validation_root_dir=None is passed here — confirm this is intentional.
train_dataloader, validation_dataloader, test_dataloader, id_to_tag = get_dataloders(
    RENDERS_FOLDER,
    TEST_FOLDER,
    preprocessed_file_path=VECTORS_SAVE_FOLDER,
    test_preprocessed_file_path=TEST_VECTORS_SAVE_FOLDER,
    validation_root_dir=None,
    is_full_train=FULL_TRAIN,
    batch_size=BATCH_SIZE,
    sample_size=sample_size,
    processor_flags=MODELS_USED,
    filter_hard_data=FILTER_HARD_DATA,
)
# Number of training samples; it becomes part of the model name below.
dataset_size = len(train_dataloader.dataset)

# Model file name passed to train_or_load_model. It encodes the run's
# configuration as a sequence of "_<label><value>" fields.
_model_name_fields = (
    ("clip", args.clip),
    ("segformer", args.segformer),
    ("midas", args.midas),
    ("dpt", args.dpt),
    ("gate", args.gate),
    ("batch", args.batch),
    ("traintype", args.train_type),
    ("bigfusionhead", args.big_fusion_head),
    ("lr", args.lr),
    ("margin", args.margin),
    ("alpha", args.alpha),
    ("datasetsize", dataset_size),
    ("renders", RENDERS_FOLDER_NAME),
    ("testdata", TEST_FOLDER_NAME),
)
_train_scope = "_full" if FULL_TRAIN else "_sampled"
MODEL_NAME = "fused_feature_model.pth" + _train_scope + "".join(
    f"_{label}{value}" for label, value in _model_name_fields
) + ".model"
# Big fusion-head variants (>= 2) use a 1024-d embedding, the others 512-d.
EMBEDDING_DIM = 1024 if args.big_fusion_head >= 2 else 512
NUM_BUILDINGS = len(id_to_tag)
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = FusedFeatureModel(
    feature_dims=feature_extractor.feature_dims,
    embedding_dim=EMBEDDING_DIM,
    use_gate=bool(args.gate),
    big_fusion_head=args.big_fusion_head,
    use_models=MODELS_USED_FOR_TRAINING,
).to(device)

# Alternative loss kept for reference:
# criterion = nn.TripletMarginLoss(margin=0.2).to(device)
# The criterion is built before the optimizer, because ProxyAnchorLoss is a
# module whose class proxies are trainable parameters. Previously the
# optimizer only saw model.parameters(), so the proxies never moved from their
# initial values. It is still built after the model, so the seeded RNG is
# consumed in the same order as before.
criterion = ProxyAnchorLoss(
    num_classes=NUM_BUILDINGS,
    embedding_dim=EMBEDDING_DIM,
    margin=args.margin,
    alpha=args.alpha,
).to(device)

# Learning-rate scale for the proxy parameter group. The reference Proxy Anchor
# implementation trains proxies with a much larger LR (x100); 1.0 keeps
# args.lr so that runs stay comparable to the encoded model name.
PROXY_LR_MULTIPLIER = 1.0
proxy_params = [p for p in criterion.parameters() if p.requires_grad]
optimizer_param_groups = [{"params": model.parameters()}]
if proxy_params:
    # NOTE(review): if train_or_load_model also adds criterion.parameters() to
    # the optimizer, remove this group — torch rejects duplicate parameters.
    optimizer_param_groups.append(
        {"params": proxy_params, "lr": args.lr * PROXY_LR_MULTIPLIER}
    )
optimizer = torch.optim.Adam(optimizer_param_groups, lr=args.lr)
# Train the model, or restore a saved one when LOAD_NETWORK is set. The call
# returns the model, the prototypes and class ids used by the test evaluation,
# the number of batches processed, and the minimum validation loss.
_training_kwargs = dict(
    # checkpoint restore
    load_network=LOAD_NETWORK,
    specific_model_name=LOAD_SPECIFIC_MODEL_NAME,
    model_name=MODEL_NAME,
    # optimisation
    optimizer=optimizer,
    criterion=criterion,
    train_type=TRAIN_TYPE,
    num_epochs=NUM_EPOCHS,
    device=device,
    # data
    train_dataloader=train_dataloader,
    val_dataloader=validation_dataloader,
    # progress / validation cadence
    batch_print_interval=BATCH_PRINT_INTERVAL,
    validation_interval=VALIDATION_CHECK_INTERVAL,
    patience=VALIDATION_PATIENCE,
    min_delta=VALIDATION_MIN_DELTA,
)
(
    model,
    prototype_tensor,
    class_ids,
    total_batches_processed,
    min_validation_loss,
) = train_or_load_model(model, **_training_kwargs)
# Report the selected model, then score it on the test split with the NCM
# evaluator (top-1 and top-3 accuracy against the returned prototypes).
print("\nModel with min validation loss: ", min_validation_loss, " loaded.")
_test_scores = evaluate_ncm_accuracy_top3(
    model, test_dataloader, prototype_tensor, class_ids, id_to_tag, device
)
final_accuracy, accuracy_top3 = _test_scores
print(
    "\nFinal Test Accuracy (NCM): "
    f"{final_accuracy:.4f}, Top-3 Accuracy: {accuracy_top3:.4f}"
)