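"""Training entry point for the FusedFeatureModel.

Loads cached clip/segformer/midas/dpt feature vectors (or can regenerate them),
trains the FusedFeatureModel with a Proxy Anchor loss, and reports
nearest-class-mean (NCM) top-1 and top-3 accuracy on the test set.
"""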
import torch
import torch.nn as nn
import random
from training.data.load_data import get_dataloders, save_preprocessed_data_to_output
from training.model.loss import ProxyAnchorLoss
from training.model.model import FusedFeatureModel
from training.model.feature_extractor import FeatureExtractor
from training.test_helpers import evaluate_ncm_accuracy_top3
from training.train_helpers import train_or_load_model
from training.helpers.args import parse_and_get_args


SEED_VALUE = 42

# Seed the Python and PyTorch RNGs so runs are repeatable.
torch.manual_seed(SEED_VALUE)
random.seed(SEED_VALUE)
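# If CUDA or NumPy randomness matters elsewhere in the codebase, those RNGs
# could be seeded as well (left commented out here, since that is untested):
# torch.cuda.manual_seed_all(SEED_VALUE)
# np.random.seed(SEED_VALUE)  # requires `import numpy as np`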


# Train on the full dataset rather than a sample (reflected as "_full"/"_sampled" in MODEL_NAME).
FULL_TRAIN = True

# Dataset locations.
RENDERS_FOLDER_NAME = "BlenderRenders7"
TEST_FOLDER_NAME = "test_data3"
RENDERS_FOLDER = f"/mnt/c/Users/KMult/Desktop/Praca_inzynierska/models/{RENDERS_FOLDER_NAME}"
TEST_FOLDER = f"/mnt/c/Users/KMult/Desktop/Praca_inzynierska/models/{TEST_FOLDER_NAME}"
VALIDATION_FOLDER = "/mnt/c/Users/KMult/Desktop/Praca_inzynierska/models/validation_data"

# Training / validation schedule.
BATCH_PRINT_INTERVAL = 50       # how often (in batches) to print training progress
NUM_EPOCHS = 50
VALIDATION_CHECK_INTERVAL = 50  # batches between validation runs
VALIDATION_PATIENCE = 60        # early-stopping patience on the validation loss
VALIDATION_MIN_DELTA = 0.001    # minimum loss improvement counted as progress
SAMPLE_SIZE = 8192              # sample size handed to get_dataloders

# Whether to drop samples marked as hard from the training data.
FILTER_HARD_DATA = False


# To start from an existing checkpoint instead of training from scratch,
# set LOAD_NETWORK = True and point LOAD_SPECIFIC_MODEL_NAME at the file,
# as in the commented-out example below.
LOAD_NETWORK = False
LOAD_SPECIFIC_MODEL_NAME = None

# LOAD_NETWORK = True
# LOAD_SPECIFIC_MODEL_NAME = "BESTID_2_fused_feature_model.pth_full_clip1_segformer0_midas0_dpt0_gate0_batch64_traintypehardmining_bigfusionhead2_lr2e-07_margin0.8_alpha64.0_datasetsize114272_rendersBlenderRenders7_testdatatest_data3.model"

# Run configuration; the loaded model name (if any) is passed along to parse_and_get_args.
args = parse_and_get_args(LOAD_SPECIFIC_MODEL_NAME)
print(args)

BATCH_SIZE = args.batch
TRAIN_TYPE = args.train_type.lower()

# Which pretrained feature extractors contribute to the fused embedding.
MODELS_USED = {
    'clip': args.clip,
    'segformer': args.segformer,
    'midas': args.midas,
    'dpt': args.dpt,
}

# Cache files for the preprocessed feature vectors.
VECTORS_SAVE_FOLDER = "training/preprocessed_vectors/mega_clip_segformer_midas_dpt4.pth"
TEST_VECTORS_SAVE_FOLDER = "training/preprocessed_vectors/test_mega_clip_segformer_midas_dpt4.pth"
feature_extractor = FeatureExtractor(use_models=MODELS_USED)

# Uncomment to (re)build the preprocessed feature-vector caches from the raw renders:
# save_preprocessed_data_to_output(root_folder=RENDERS_FOLDER, root_test_folder=TEST_FOLDER, output_file_path=VECTORS_SAVE_FOLDER, test_output_file_path=TEST_VECTORS_SAVE_FOLDER, feature_extractor=feature_extractor, processor_flags=MODELS_USED)

# The gate flag is added only for the fusion model (the feature extractor does not take it).
MODELS_USED_FOR_TRAINING = MODELS_USED.copy()
MODELS_USED_FOR_TRAINING['gate'] = args.gate

train_dataloader, validation_dataloader, test_dataloader, id_to_tag = get_dataloders(RENDERS_FOLDER, TEST_FOLDER, preprocessed_file_path=VECTORS_SAVE_FOLDER, test_preprocessed_file_path=TEST_VECTORS_SAVE_FOLDER, validation_root_dir=None, is_full_train=FULL_TRAIN, batch_size=BATCH_SIZE, sample_size=SAMPLE_SIZE, processor_flags=MODELS_USED, filter_hard_data=FILTER_HARD_DATA)
dataset_size = len(train_dataloader.dataset)

# Encode the full run configuration in the checkpoint file name.
MODEL_NAME = (
    "fused_feature_model.pth"
    + ("_full" if FULL_TRAIN else "_sampled")
    + f"_clip{args.clip}"
    + f"_segformer{args.segformer}"
    + f"_midas{args.midas}"
    + f"_dpt{args.dpt}"
    + f"_gate{args.gate}"
    + f"_batch{args.batch}"
    + f"_traintype{args.train_type}"
    + f"_bigfusionhead{args.big_fusion_head}"
    + f"_lr{args.lr}"
    + f"_margin{args.margin}"
    + f"_alpha{args.alpha}"
    + f"_datasetsize{dataset_size}"
    + f"_renders{RENDERS_FOLDER_NAME}"
    + f"_testdata{TEST_FOLDER_NAME}"
    + ".model"
)
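# The commented-out LOAD_SPECIFIC_MODEL_NAME above shows what such a name looks
# like in practice; its "BESTID_2_" prefix is presumably added elsewhere when
# the best checkpoint is saved.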

# Larger fusion heads (big_fusion_head >= 2) produce a wider embedding.
EMBEDDING_DIM = 1024 if args.big_fusion_head >= 2 else 512
NUM_BUILDINGS = len(id_to_tag)
device = "cuda:0" if torch.cuda.is_available() else "cpu"

model = FusedFeatureModel(feature_dims=feature_extractor.feature_dims, embedding_dim=EMBEDDING_DIM, use_gate=bool(args.gate), big_fusion_head=args.big_fusion_head, use_models=MODELS_USED_FOR_TRAINING).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

# criterion = nn.TripletMarginLoss(margin=0.2).to(device)
# Proxy Anchor metric-learning loss: one learnable proxy per building class.
criterion = ProxyAnchorLoss(
    num_classes=NUM_BUILDINGS,
    embedding_dim=EMBEDDING_DIM,
    margin=args.margin,
    alpha=args.alpha,
).to(device)

# Train from scratch (or load an existing checkpoint) and keep the weights with
# the lowest validation loss; the returned prototypes are used for NCM evaluation.
model, prototype_tensor, class_ids, total_batches_processed, min_validation_loss = train_or_load_model(
    model,
    load_network=LOAD_NETWORK,
    specific_model_name=LOAD_SPECIFIC_MODEL_NAME,
    optimizer=optimizer,
    criterion=criterion,
    train_dataloader=train_dataloader,
    device=device,
    num_epochs=NUM_EPOCHS,
    batch_print_interval=BATCH_PRINT_INTERVAL,
    val_dataloader=validation_dataloader,
    validation_interval=VALIDATION_CHECK_INTERVAL,
    model_name=MODEL_NAME,
    patience=VALIDATION_PATIENCE,
    min_delta=VALIDATION_MIN_DELTA,
    train_type=TRAIN_TYPE,
)

print("\nModel with min validation loss: ", min_validation_loss, " loaded.")

# Nearest-class-mean evaluation against the class prototypes on the test set.
final_accuracy, accuracy_top3 = evaluate_ncm_accuracy_top3(model, test_dataloader, prototype_tensor, class_ids, id_to_tag, device)

print(f"\nFinal Test Accuracy (NCM): {final_accuracy:.4f}, Top-3 Accuracy: {accuracy_top3:.4f}")