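# Training entry point: trains FusedFeatureModel on precomputed CLIP / SegFormer / MiDaS / DPT
# features with a Proxy Anchor loss and reports NCM top-1 / top-3 accuracy on the test set.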
import torch
import torch.nn as nn
import random
from training.data.load_data import get_dataloders, save_preprocessed_data_to_output
from training.model.loss import ProxyAnchorLoss
from training.model.model import FusedFeatureModel
from training.model.feature_extractor import FeatureExtractor
from training.test_helpers import evaluate_ncm_accuracy_top3
from training.train_helpers import train_or_load_model
from training.helpers.args import parse_and_get_args
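# Fix RNG seeds so runs are repeatable.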
SEED_VALUE = 42
torch.manual_seed(SEED_VALUE)
random.seed(SEED_VALUE)
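# Dataset locations (FULL_TRAIN switches between the full and sampled training set).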
FULL_TRAIN = True
RENDERS_FOLDER_NAME = "BlenderRenders7"
TEST_FOLDER_NAME = "test_data3"
RENDERS_FOLDER = f"/mnt/c/Users/KMult/Desktop/Praca_inzynierska/models/{RENDERS_FOLDER_NAME}"
TEST_FOLDER = f"/mnt/c/Users/KMult/Desktop/Praca_inzynierska/models/{TEST_FOLDER_NAME}"
VALIDATION_FOLDER = "/mnt/c/Users/KMult/Desktop/Praca_inzynierska/models/validation_data"
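# Training schedule, logging interval, and early-stopping settings.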
BATCH_PRINT_INTERVAL = 50
NUM_EPOCHS = 50
VALIDATION_CHECK_INTERVAL = 50
VALIDATION_PATIENCE = 60
VALIDATION_MIN_DELTA = 0.001
sample_size = 8192
FILTER_HARD_DATA = False
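# Optionally resume from a previously saved checkpoint (see the commented-out example below).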
LOAD_NETWORK = False
LOAD_SPECIFIC_MODEL_NAME = None
# LOAD_NETWORK = True
# LOAD_SPECIFIC_MODEL_NAME = "BESTID_2_fused_feature_model.pth_full_clip1_segformer0_midas0_dpt0_gate0_batch64_traintypehardmining_bigfusionhead2_lr2e-07_margin0.8_alpha64.0_datasetsize114272_rendersBlenderRenders7_testdatatest_data3.model"
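# CLI arguments select the feature backbones (clip / segformer / midas / dpt), batch size,
# training type, and loss hyperparameters.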
args = parse_and_get_args(LOAD_SPECIFIC_MODEL_NAME)
print(args)
BATCH_SIZE = args.batch
TRAIN_TYPE = args.train_type.lower()
MODELS_USED = {
'clip': args.clip,
'segformer': args.segformer,
'midas': args.midas,
'dpt': args.dpt,
}
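# Paths of the cached (preprocessed) feature vectors.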
VECTORS_SAVE_FOLDER = "training/preprocessed_vectors/mega_clip_segformer_midas_dpt4.pth"
TEST_VECTORS_SAVE_FOLDER = "training/preprocessed_vectors/test_mega_clip_segformer_midas_dpt4.pth"
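# Backbone feature extractor; uncomment the call below to regenerate the cached vectors
# from the render and test folders.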
feature_extractor = FeatureExtractor(use_models=MODELS_USED)
# save_preprocessed_data_to_output(root_folder=RENDERS_FOLDER, root_test_folder=TEST_FOLDER, output_file_path=VECTORS_SAVE_FOLDER, test_output_file_path=TEST_VECTORS_SAVE_FOLDER, feature_extractor=feature_extractor, processor_flags=MODELS_USED)
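# The fusion model additionally receives the gate flag, which the feature extractor does not use.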
MODELS_USED_FOR_TRAINING = MODELS_USED.copy()
MODELS_USED_FOR_TRAINING['gate'] = args.gate
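# Build train/validation/test dataloaders from the cached vectors (no separate validation
# folder is passed here; VALIDATION_FOLDER is currently unused).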
train_dataloader, validation_dataloader, test_dataloader, id_to_tag = get_dataloders(
    RENDERS_FOLDER,
    TEST_FOLDER,
    preprocessed_file_path=VECTORS_SAVE_FOLDER,
    test_preprocessed_file_path=TEST_VECTORS_SAVE_FOLDER,
    validation_root_dir=None,
    is_full_train=FULL_TRAIN,
    batch_size=BATCH_SIZE,
    sample_size=sample_size,
    processor_flags=MODELS_USED,
    filter_hard_data=FILTER_HARD_DATA,
)
dataset_size = len(train_dataloader.dataset)
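# Encode the full experiment configuration in the checkpoint file name.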
MODEL_NAME = (
"fused_feature_model.pth"
+ ("_full" if FULL_TRAIN else "_sampled")
+ f"_clip{args.clip}"
+ f"_segformer{args.segformer}"
+ f"_midas{args.midas}"
+ f"_dpt{args.dpt}"
+ f"_gate{args.gate}"
+ f"_batch{args.batch}"
+ f"_traintype{args.train_type}"
+ f"_bigfusionhead{args.big_fusion_head}"
+ f"_lr{args.lr}"
+ f"_margin{args.margin}"
+ f"_alpha{args.alpha}"
+ f"_datasetsize{dataset_size}"
+ f"_renders{RENDERS_FOLDER_NAME}"
+ f"_testdata{TEST_FOLDER_NAME}"
+ ".model"
)
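# Model and optimizer; the larger fusion head (big_fusion_head >= 2) doubles the embedding size.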
EMBEDDING_DIM = 1024 if args.big_fusion_head >= 2 else 512
NUM_BUILDINGS = len(id_to_tag)
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = FusedFeatureModel(
    feature_dims=feature_extractor.feature_dims,
    embedding_dim=EMBEDDING_DIM,
    use_gate=bool(args.gate),
    big_fusion_head=args.big_fusion_head,
    use_models=MODELS_USED_FOR_TRAINING,
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
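# Metric-learning criterion: Proxy Anchor loss over NUM_BUILDINGS classes
# (TripletMarginLoss left below as a commented-out alternative).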
# criterion = nn.TripletMarginLoss(margin=0.2).to(device)
criterion = ProxyAnchorLoss(
num_classes=NUM_BUILDINGS,
embedding_dim=EMBEDDING_DIM,
margin=args.margin,
alpha=args.alpha
).to(device)
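# Train from scratch or load a saved checkpoint, returning the model, the class prototypes,
# and the minimum validation loss reached.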
model, prototype_tensor, class_ids, total_batches_processed, min_validation_loss = train_or_load_model(
model,
load_network=LOAD_NETWORK,
specific_model_name=LOAD_SPECIFIC_MODEL_NAME,
optimizer=optimizer,
criterion=criterion,
train_dataloader=train_dataloader,
device=device,
num_epochs=NUM_EPOCHS,
batch_print_interval=BATCH_PRINT_INTERVAL,
val_dataloader=validation_dataloader,
validation_interval=VALIDATION_CHECK_INTERVAL,
model_name=MODEL_NAME,
patience=VALIDATION_PATIENCE,
min_delta=VALIDATION_MIN_DELTA,
train_type=TRAIN_TYPE,
)
print("\nModel with min validation loss: ", min_validation_loss, " loaded.")
final_accuracy, accuracy_top3 = evaluate_ncm_accuracy_top3(model, test_dataloader, prototype_tensor, class_ids, id_to_tag, device)
print(f"\nFinal Test Accuracy (NCM): {final_accuracy:.4f}, Top-3 Accuracy: {accuracy_top3:.4f}")