# training/train_helpers.py
import numpy as np
import torch
from training.test_helpers import generate_class_prototypes
from training.utils import move_to_device
import time
import os
import glob
import re
import copy
MODELS_SAVE_DIR = "training/saved_models/newclip_mega8"
def train_or_load_model(model, load_network=True, retrain_model=False, specific_model_name=None, optimizer=None, criterion=None, train_dataloader=None, device=None, num_epochs=1, batch_print_interval=5, val_dataloader=None, validation_interval=20, model_name="fused_feature_model.pth", patience=5, min_delta=0.05, train_type="standard"):
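    """
    Loads the latest (or a specific) checkpoint from the save directory, or trains
    the model when no usable checkpoint exists (or when retrain_model is set).
    Checkpoints are named "{N}_{basename(model_name)}", where N is an increasing run
    counter; the file with the highest N counts as the latest. A saved checkpoint
    bundles the model state, class prototypes, class ids, total batches processed,
    the best validation loss, and the dataset's id_to_tag mapping.
    Returns:
        (model, prototype_tensor, class_ids, total_batches_processed, min_validation_loss)
    """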
save_dir = os.path.dirname(model_name) if os.path.dirname(model_name) else MODELS_SAVE_DIR
os.makedirs(save_dir, exist_ok=True)
latest_path = None
# 1. Check for a specific checkpoint name provided by the user
if specific_model_name:
latest_path = os.path.join(save_dir, specific_model_name)
if not os.path.exists(latest_path):
print(f"Warning: Specific checkpoint file not found at: {latest_path}. Skipping model load.")
latest_path = None # Do not attempt to load if path is invalid
    # 2. Always resolve the checkpoint numbering; if no specific name was provided
    #    (or it was invalid), fall back to the latest numbered checkpoint.
    base_filename = os.path.basename(model_name)
    max_num, auto_latest_path = get_latest_checkpoint_info(save_dir, base_filename)
    if latest_path is None:
        latest_path = auto_latest_path
# --- Loading model ---
if load_network and latest_path:
try:
print(f"Attempting to load latest model and prototypes: {latest_path}")
checkpoint = torch.load(latest_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
total_batches_processed = checkpoint.get('total_batches_processed', None)
min_validation_loss = checkpoint.get('min_validation_loss', None)
print(f"Model with min validation loss {min_validation_loss} loaded")
if not retrain_model:
model.eval()
if 'prototype_tensor' in checkpoint and 'class_ids' in checkpoint:
prototype_tensor = checkpoint['prototype_tensor']
class_ids = checkpoint['class_ids']
print(f"Model and Prototypes successfully loaded from {latest_path}")
return model, prototype_tensor, class_ids, total_batches_processed, min_validation_loss
else:
print(f"Model state loaded successfully from {latest_path}, but **Prototypes are missing**. Generating prototypes now.")
prototype_tensor, class_ids = generate_class_prototypes(model, train_dataloader, device)
torch.save({
'model_state_dict': model.state_dict(),
'prototype_tensor': prototype_tensor,
'class_ids': class_ids,
'total_batches_processed': total_batches_processed,
"id_to_tag": train_dataloader.dataset.id_to_tag,
}, latest_path)
print(f"Prototypes generated and saved to {latest_path}")
return model, prototype_tensor, class_ids, total_batches_processed, min_validation_loss
except Exception as e:
print(f"Error loading model from {latest_path}: {e}")
print("Proceeding to train from scratch.")
# --- Training ---
print(f"Training from scratch (or resuming failed load).")
if train_type == "standard":
total_batches_processed, min_validation_loss = train_model(
model,
optimizer,
criterion,
train_dataloader,
device,
num_epochs=num_epochs,
batch_print_interval=batch_print_interval,
val_dataloader=val_dataloader,
validation_interval=validation_interval,
patience=patience,
min_delta=min_delta
)
elif train_type == "hardmining":
total_batches_processed, min_validation_loss = train_model_hard_mining(
model,
optimizer,
criterion,
train_dataloader,
device,
num_epochs=num_epochs,
batch_print_interval=batch_print_interval,
val_dataloader=val_dataloader,
validation_interval=validation_interval,
patience=patience,
min_delta=min_delta
)
elif train_type == "curriculum":
total_batches_processed, min_validation_loss = train_model_with_curriculum(
model,
optimizer,
criterion,
train_dataloader,
device,
num_epochs=num_epochs,
batch_print_interval=batch_print_interval,
val_dataloader=val_dataloader,
validation_interval=validation_interval,
patience=patience,
min_delta=min_delta
)
    else:
        raise ValueError(f"Unknown train_type: {train_type!r}. Expected 'standard', 'hardmining' or 'curriculum'.")
    # --- Saving ---
next_num = max_num + 1
new_filename = f"{next_num}_{base_filename}"
save_path = os.path.join(save_dir, new_filename)
print(f"Min Validation Loss after training: {min_validation_loss}")
torch.save({
'model_state_dict': model.state_dict(),
"total_batches_processed": total_batches_processed,
"min_validation_loss": min_validation_loss
}, save_path)
print(f"Model saved (before prototypes) to {save_path}")
prototype_tensor, class_ids = generate_class_prototypes(model, train_dataloader, device)
torch.save({
'model_state_dict': model.state_dict(),
'prototype_tensor': prototype_tensor,
'class_ids': class_ids,
'total_batches_processed': total_batches_processed,
'min_validation_loss': min_validation_loss,
"id_to_tag": train_dataloader.dataset.id_to_tag,
}, save_path)
print(f"Model and Prototypes saved to {save_path}")
return model, prototype_tensor, class_ids, total_batches_processed, min_validation_loss
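# Example (illustrative sketch, not project code): a typical call to train_or_load_model.
# The dataloader, optimizer and criterion variables below are placeholders.
#
#   model, prototypes, class_ids, n_batches, best_val = train_or_load_model(
#       model,
#       load_network=True,
#       optimizer=optimizer,
#       criterion=torch.nn.TripletMarginLoss(margin=0.5),
#       train_dataloader=train_loader,
#       val_dataloader=val_loader,
#       device=device,
#       num_epochs=3,
#       train_type="standard",
#       model_name="fused_feature_model.pth",
#   )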
def train_model(model, optimizer, criterion, train_dataloader, device, num_epochs=1, batch_print_interval=5, val_dataloader=None, validation_interval=20, patience=5, min_delta=0.05):
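    """
    Standard triplet-loss training loop over pre-sampled (anchor, positive, negative)
    batches, with periodic validation and early stopping.
    Returns:
        (total_batches_processed, min_validation_loss)
    """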
early_stopper = EarlyStopper(patience=patience, min_delta=min_delta)
running_loss = 0.0
running_time = 0.0
total_batches_processed = 0
model.train()
for epoch in range(num_epochs):
start_time = time.time()
for i, batch in enumerate(train_dataloader):
batch_idx = i + 1
total_batches_processed += 1
batch = move_to_device(batch, device)
A_anchor = batch['anchor']
A_clip = A_anchor.get('clip')
A_seg = A_anchor.get('segformer')
A_dpt = A_anchor.get('dpt')
A_midas = A_anchor.get('midas')
# --- Safe extraction for Positive (P) ---
P_positive = batch['positive']
P_clip = P_positive.get('clip')
P_seg = P_positive.get('segformer')
P_dpt = P_positive.get('dpt')
P_midas = P_positive.get('midas')
# --- Safe extraction for Negative (N) ---
N_negative = batch['negative']
N_clip = N_negative.get('clip')
N_seg = N_negative.get('segformer')
N_dpt = N_negative.get('dpt')
N_midas = N_negative.get('midas')
if A_clip.device.type != 'cuda' or next(model.parameters()).device.type != 'cuda':
print("\n*** CRITICAL DEVICE SWITCH DETECTED ***")
print(f"Tensor is on {A_clip.device.type}. Training speed will be severely impacted.")
anchor_embed = model(A_clip, A_seg, A_dpt, A_midas)
positive_embed = model(P_clip, P_seg, P_dpt, P_midas)
negative_embed = model(N_clip, N_seg, N_dpt, N_midas)
loss = criterion(anchor_embed, positive_embed, negative_embed)
# Backpropagation
optimizer.zero_grad()
loss.backward()
optimizer.step()
running_loss += loss.item()
if batch_idx % batch_print_interval == 0:
avg_loss = running_loss / batch_print_interval
elapsed_time = time.time() - start_time
print(f'Epoch [{epoch+1}/{num_epochs}], '
f'Batch [{batch_idx}/{len(train_dataloader)}], '
f'Loss: {avg_loss:.4f}, '
                  f'Time/{batch_print_interval} Batches: {elapsed_time:.2f}s')
running_loss = 0.0
running_time += elapsed_time
start_time = time.time()
if val_dataloader is not None and batch_idx % validation_interval == 0:
                val_loss, num_val_batches = validate_model(
                    model, criterion, val_dataloader, device
                )
print(f"[Validation @ Batch {batch_idx}] Checked {num_val_batches} Val Batches. Loss: {val_loss:.4f}\n")
if early_stopper.early_stop(val_loss, model):
print(f"\n*** Early stopping triggered! ***")
print(f"Validation loss has not improved for {early_stopper.patience} validation checks.")
# Load the best weights before exiting
if early_stopper.best_model_state is not None:
model.load_state_dict(early_stopper.best_model_state)
print("Restored best model weights.")
return total_batches_processed, early_stopper.min_validation_loss
model.train()
start_time = time.time()
if batch_idx % validation_interval == 0:
if torch.cuda.is_available():
torch.cuda.empty_cache()
print(f"--- Epoch {epoch+1} finished. ---")
print(f"Total time for epoch: {running_time:.2f}s")
if early_stopper.best_model_state is not None:
model.load_state_dict(early_stopper.best_model_state)
print("Training finished. Restored best model weights based on validation loss.")
return total_batches_processed, early_stopper.min_validation_loss
def validate_model(model, criterion, val_dataloader, device, batches_to_check=None):
"""
Evaluates the model on the validation dataset for a specified number of batches.
If batches_to_check is None, it runs over the entire val_dataloader.
"""
model.eval()
total_val_loss = 0.0
num_batches = 0
start_time = time.time()
with torch.no_grad():
for batch_idx, batch in enumerate(val_dataloader):
if batches_to_check is not None and batch_idx >= batches_to_check:
break
batch = move_to_device(batch, device)
A_anchor = batch['anchor']
A_clip = A_anchor.get('clip')
A_seg = A_anchor.get('segformer')
A_dpt = A_anchor.get('dpt')
A_midas = A_anchor.get('midas')
# --- Safe extraction for Positive (P) ---
P_positive = batch['positive']
P_clip = P_positive.get('clip')
P_seg = P_positive.get('segformer')
P_dpt = P_positive.get('dpt')
P_midas = P_positive.get('midas')
# --- Safe extraction for Negative (N) ---
N_negative = batch['negative']
N_clip = N_negative.get('clip')
N_seg = N_negative.get('segformer')
N_dpt = N_negative.get('dpt')
N_midas = N_negative.get('midas')
anchor_embed = model(A_clip, A_seg, A_dpt, A_midas)
positive_embed = model(P_clip, P_seg, P_dpt, P_midas)
negative_embed = model(N_clip, N_seg, N_dpt, N_midas)
loss = criterion(anchor_embed, positive_embed, negative_embed)
total_val_loss += loss.item()
num_batches += 1
end_time = time.time()
validation_time = end_time - start_time
print(f"Validation took {validation_time:.2f} seconds.")
avg_val_loss = total_val_loss / num_batches if num_batches > 0 else 0.0
return avg_val_loss, num_batches
def get_latest_checkpoint_info(save_dir, base_filename):
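    """
    Scans save_dir for files named "{N}_{base_filename}" and returns
    (max_num, latest_path) for the highest N found, or (0, None) if none exist.
    """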
search_pattern = os.path.join(save_dir, f"*_{base_filename}")
existing_files = glob.glob(search_pattern)
max_num = 0
latest_path = None
pattern = re.compile(r'^(\d+)_')
for file in existing_files:
name = os.path.basename(file)
match = pattern.match(name)
if match:
current_num = int(match.group(1))
if current_num > max_num:
max_num = current_num
latest_path = file
return max_num, latest_path
class EarlyStopper:
"""
Early stopping to stop training when the validation loss does not improve
after a given patience.
"""
def __init__(self, patience=5, min_delta=0):
self.patience = patience
self.min_delta = min_delta
self.counter = 0
self.min_validation_loss = np.inf
self.best_model_state = None
def early_stop(self, validation_loss, model):
"""
Returns True if early stopping criteria are met.
Stores the best model state if the current loss is an improvement.
"""
if validation_loss < self.min_validation_loss - self.min_delta:
self.min_validation_loss = validation_loss
print(f"New minimum validation loss: {self.min_validation_loss:.4f}. Saving best model state.")
self.counter = 0
            self.best_model_state = copy.deepcopy(model.state_dict())  # snapshot, detached from live parameters
elif validation_loss > self.min_validation_loss + self.min_delta:
self.counter += 1
if self.counter >= self.patience:
return True
return False
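# Example (illustrative sketch): how EarlyStopper is meant to be used around a
# validation loop. run_validation is a placeholder, not a function in this project.
#
#   stopper = EarlyStopper(patience=5, min_delta=0.05)
#   for _ in range(max_validation_checks):
#       val_loss = run_validation()
#       if stopper.early_stop(val_loss, model):
#           model.load_state_dict(stopper.best_model_state)
#           break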
def batch_hard_mining(embeddings, labels, margin):
"""
Implements BatchHard Triplet Mining. Finds the hardest positive and negative
for every anchor in the batch.
"""
# Calculate all pairwise distances (Euclidean)
pairwise_dist = torch.cdist(embeddings, embeddings, p=2.0)
# Get masks
# labels_equal[i, j] is True if labels[i] == labels[j]
labels_equal = labels.unsqueeze(0) == labels.unsqueeze(1)
    # 1. Hardest Positive (P_h): largest distance to a sample with the same label.
    # Exclude the diagonal so an anchor is never paired with itself (A == P).
    diagonal = torch.eye(labels.size(0), dtype=torch.bool, device=labels.device)
    mask_anchor_positive = labels_equal & ~diagonal
    # Zero out non-positive entries; distances are non-negative, so they cannot win the max.
    dist_positive = pairwise_dist.clone()
    dist_positive[~mask_anchor_positive] = 0
    # Max distance to a positive for each Anchor (row)
    dist_ap, _ = torch.max(dist_positive, dim=1)
    # 2. Hardest Negative (N_h): smallest distance to a sample with a different label.
    mask_anchor_negative = labels_equal.logical_not()
    # Shift positives and the diagonal above every negative so min() ignores them.
    max_dist = pairwise_dist.max()
    dist_negative = pairwise_dist + max_dist * (1 - mask_anchor_negative.float())
    dist_an, _ = torch.min(dist_negative, dim=1)
# 3. Compute Triplet Loss on the mined triplets
# Loss: max(0, D_ap - D_an + margin)
loss_triplet = torch.relu(dist_ap - dist_an + margin)
    # Guard against an empty batch, then average the triplet loss over all anchors.
    if loss_triplet.numel() == 0:
        return torch.tensor(0.0, device=embeddings.device, requires_grad=True)
    return loss_triplet.mean()
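# Example (illustrative sketch): calling the miner directly on a batch of embeddings.
# The shapes, number of classes and margin below are arbitrary placeholders.
#
#   embeddings = torch.randn(32, 128)      # (batch_size, embedding_dim)
#   labels = torch.randint(0, 8, (32,))    # one class id per sample
#   loss = batch_hard_mining(embeddings, labels, margin=0.5)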
def batch_semi_hard_mining(embeddings, labels, margin):
    """
    Implements Batch Semi-Hard Triplet Mining.
    For every anchor in the batch:
    1. Hardest Positive (P_h): the positive with the largest distance d(a,p).
    2. Semi-Hard Negative (N_sh): the closest negative satisfying
       d(a,p) < d(a,n) < d(a,p) + margin.
       If no negative satisfies this condition for an anchor, fall back to the
       hardest negative overall (the Batch Hard choice).
    """
# Calculate all pairwise distances (Euclidean)
pairwise_dist = torch.cdist(embeddings, embeddings, p=2.0)
# Get masks
# labels_equal[i, j] is True if labels[i] == labels[j]
labels_equal = labels.unsqueeze(0) == labels.unsqueeze(1)
    # --- 1. Hardest Positive (P_h) ---
    # Exclude the diagonal so an anchor is never paired with itself (A == P).
    diagonal = torch.eye(labels.size(0), dtype=torch.bool, device=labels.device)
    mask_anchor_positive = labels_equal & ~diagonal
    # Zero out non-positive entries (distances are non-negative, so 0 cannot win the max)
    dist_positive = pairwise_dist.clone()
    dist_positive[~mask_anchor_positive] = 0
    # Max distance to a positive for each Anchor (row)
    dist_ap, _ = torch.max(dist_positive, dim=1)
# --- 2. Semi-Hard Negative (N_sh) ---
mask_anchor_negative = labels_equal.logical_not()
# Ensure d(a,n) > d(a,p) (Hardness condition)
# dist_ap is (B), expand to (B, B)
dist_ap_expanded = dist_ap.unsqueeze(1)
# Condition 1: d(a,n) > d(a,p)
mask_positive_violating = pairwise_dist > dist_ap_expanded
# Condition 2: d(a,n) < d(a,p) + margin (Semi-Hard condition)
mask_margin_satisfying = pairwise_dist < dist_ap_expanded + margin
# The Semi-Hard Negative Mask:
# Must be a negative, must be harder than the positive, and must satisfy the margin.
mask_semi_hard = mask_anchor_negative & mask_positive_violating & mask_margin_satisfying
    # For each anchor, keep only semi-hard negative distances; everything else is pushed
    # to a huge value so torch.min() ignores it.
    dist_semi_hard = pairwise_dist.clone()
    dist_semi_hard[mask_semi_hard.logical_not()] = 1e9  # non-semi-hard entries become huge
# Find the min distance among the semi-hard negatives for each Anchor (row)
dist_an_semi_hard, _ = torch.min(dist_semi_hard, dim=1)
    # Handle anchors with NO semi-hard negative:
    # if the min distance is still 1e9, no semi-hard negative was found, so we fall back
    # to the hardest negative overall (the Batch Hard choice), even though it may not
    # satisfy the semi-hard condition.
    mask_no_semi_hard = dist_an_semi_hard == 1e9
if mask_no_semi_hard.any():
# Mask for all Negatives (Hardest Negative, d(a,n) < d(a,p) + margin is not required)
dist_all_negatives = pairwise_dist.clone()
dist_all_negatives[mask_anchor_negative.logical_not()] = 1e9
# Find the actual hardest negative for all anchors
dist_an_hard, _ = torch.min(dist_all_negatives, dim=1)
# Replace the 1e9 with the actual hardest negative distance
dist_an_semi_hard[mask_no_semi_hard] = dist_an_hard[mask_no_semi_hard]
dist_an = dist_an_semi_hard # Final negative distance to use
# --- 3. Compute Triplet Loss on the mined triplets ---
# Loss: max(0, D_ap - D_an + margin)
loss_triplet = torch.relu(dist_ap - dist_an + margin)
    # Guard against a batch with no valid positive pairs (dist_ap all zero) or an empty batch.
    if loss_triplet.numel() == 0 or dist_ap.sum() == 0:
        return torch.tensor(0.0, device=embeddings.device, requires_grad=True)
    # Average over all anchors; anchors whose fallback negative already satisfies the
    # margin contribute 0 through the relu (a common, simple choice).
return loss_triplet.mean()
def validate_model_mining(model, criterion, val_dataloader, device):
"""
Calculates validation loss using Online Hard Mining (BatchHard).
Args:
model: The FusedFeatureModel.
criterion: The loss criterion (used primarily to extract the margin).
val_dataloader: DataLoader using the MultiModalDataset.
device: 'cuda' or 'cpu'.
Returns:
(float, int): Average validation loss and number of batches checked.
"""
model.eval()
total_val_loss = 0.0
num_val_batches = 0
    # If mining manually instead (see the commented-out call below), the margin can be
    # taken from a torch.nn.TripletMarginLoss via criterion.margin:
    # MARGIN = criterion.margin
with torch.no_grad():
for batch in val_dataloader:
# Move batch to device (assuming move_to_device is defined)
# You must ensure the move_to_device helper moves nested dicts correctly
batch = move_to_device(batch, device)
inputs = batch['anchor']
labels = batch['y'] # True class labels (Shape: Batch Size)
# 1. Forward Pass: Compute all Embeddings in the batch
embeddings = model(**inputs)
# 2. Loss Calculation: Online Hard Mining Loss
# The loss is computed only on the hardest triplets found in the batch.
# loss = batch_semi_hard_mining(embeddings, labels, MARGIN)
loss = criterion(embeddings, labels)
total_val_loss += loss.item()
num_val_batches += 1
if num_val_batches == 0:
return 0.0, 0
avg_val_loss = total_val_loss / num_val_batches
return avg_val_loss, num_val_batches
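# Note (illustrative sketch): validate_model_mining and train_model_hard_mining call
# criterion(embeddings, labels), so the criterion there must be a mining-style loss
# rather than torch.nn.TripletMarginLoss. A minimal wrapper around batch_hard_mining
# that satisfies this interface could look like the class below; it is an assumption
# about the intended criterion, not code used elsewhere in this project.
class BatchHardTripletCriterion(torch.nn.Module):
    """Mining-style criterion: forward(embeddings, labels) -> scalar loss (sketch)."""
    def __init__(self, margin=0.5):
        super().__init__()
        self.margin = margin
    def forward(self, embeddings, labels):
        return batch_hard_mining(embeddings, labels, self.margin)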
# --- REVISED TRAINING LOOP ---
def train_model_hard_mining(model, optimizer, criterion, train_dataloader, device, num_epochs=1, batch_print_interval=5, val_dataloader=None, validation_interval=20, patience=5, min_delta=0.05):
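    """
    Training loop for online mining: each batch provides a dict of 'anchor' inputs plus
    class labels 'y', and the criterion is called as criterion(embeddings, labels).
    Returns:
        (total_batches_processed, min_validation_loss)
    """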
# Extract margin from criterion (assuming it's TripletMarginLoss)
# MARGIN = criterion.margin
early_stopper = EarlyStopper(patience=patience, min_delta=min_delta)
running_loss = 0.0
running_time = 0.0
total_batches_processed = 0
model.train()
for epoch in range(num_epochs):
start_time = time.time()
for i, batch in enumerate(train_dataloader):
batch_idx = i + 1
total_batches_processed += 1
batch = move_to_device(batch, device)
inputs = batch['anchor']
labels = batch['y'] # True class labels (Shape: Batch Size)
            # --- 1. Compute all embeddings in the batch ---
            # batch['anchor'] is a dict of per-backbone tensors, so it can be unpacked
            # directly into the model's forward pass.
embeddings = model(**inputs) # Embeddings shape: (Batch Size, Embedding Dim)
            # --- 2. Online mining loss ---
            # The criterion mines triplets within the batch from (embeddings, labels).
            # Manual alternative: loss = batch_semi_hard_mining(embeddings, labels, MARGIN)
loss = criterion(embeddings, labels)
# Backpropagation
optimizer.zero_grad()
loss.backward()
optimizer.step()
running_loss += loss.item()
if batch_idx % batch_print_interval == 0:
avg_loss = running_loss / batch_print_interval
elapsed_time = time.time() - start_time
print(f'Epoch [{epoch+1}/{num_epochs}], '
f'Batch [{batch_idx}/{len(train_dataloader)}], '
f'Online Hard Mining Loss: {avg_loss:.4f}, '
f'Time/{batch_print_interval} Batches: {elapsed_time:.2f}s')
running_loss = 0.0
running_time += elapsed_time
start_time = time.time()
if val_dataloader is not None and batch_idx % validation_interval == 0:
val_loss, num_val_batches = validate_model_mining(
model, criterion, val_dataloader, device
)
print(f"[Validation @ Batch {batch_idx}] Checked {num_val_batches} Val Batches. Loss: {val_loss:.4f}\n")
if early_stopper.early_stop(val_loss, model):
print(f"\n*** Early stopping triggered! ***")
if early_stopper.best_model_state is not None:
model.load_state_dict(early_stopper.best_model_state)
print("Restored best model weights.")
return total_batches_processed, early_stopper.min_validation_loss
model.train()
start_time = time.time()
if batch_idx % validation_interval == 0 and torch.cuda.is_available():
torch.cuda.empty_cache()
print(f"--- Epoch {epoch+1} finished. ---")
print(f"Total time for epoch: {running_time:.2f}s")
if early_stopper.best_model_state is not None:
model.load_state_dict(early_stopper.best_model_state)
print("Training finished. Restored best model weights based on validation loss.")
return total_batches_processed, early_stopper.min_validation_loss
def train_model_with_curriculum(model, optimizer, criterion, train_dataloader, device, num_epochs=1, batch_print_interval=5, val_dataloader=None, validation_interval=20, patience=5, min_delta=0.05):
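    """
    Curriculum training loop:
      Epoch 1: standard triplet loss on pre-sampled (anchor, positive, negative) triplets.
      Epoch 2: switches the dataset to hard-mining mode and uses the online mining criterion.
      Epoch 3+: additionally unfreezes the last layers of selected backbones and lowers the LR.
    Returns:
        (total_batches_processed, min_validation_loss)
    """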
# CRITICAL: Extract the margin for Batch Hard Mining
# Assuming criterion is torch.nn.TripletMarginLoss
# MARGIN = criterion.margin
early_stopper = EarlyStopper(patience=patience, min_delta=min_delta)
FINE_TUNE_LR_FACTOR = 0.1 # e.g., drop LR by 10x
CLIP_LAYERS_TO_UNFREEZE = 1
SEGFORMER_LAYERS_TO_UNFREEZE = 1
DPT_LAYERS_TO_UNFREEZE = 0
MIDAS_LAYERS_TO_UNFREEZE = 0
# You may add other backbones here, e.g., 'segformer': 2, 'dpt': 1
BACKBONES_TO_UNFREEZE = {'clip': CLIP_LAYERS_TO_UNFREEZE}
# Get the current base LR
initial_lr = optimizer.param_groups[0]['lr']
running_loss = 0.0
total_batches_processed = 0
model.train()
for epoch in range(num_epochs):
start_time = time.time()
# Determine the mining strategy for the current epoch
# Epoch 1 (index 0) uses standard pre-sampled triplets (Random/Semi-hard)
# Epoch 2+ (index 1+) uses Online Batch Hard Mining
is_hard_mining_epoch = epoch >= 1
        if is_hard_mining_epoch and epoch == 1:
            print("\n--- Switching to Hard Mining Mode for Training Dataset and loading Best Model from Triplet Loss ---")
            train_dataloader.dataset.hard_mining_mode = True
            if early_stopper.best_model_state is not None:
                model.load_state_dict(early_stopper.best_model_state)
if epoch == 2:
print("\n--- PHASE 2: Starting Fine-Tuning (Epoch 3). Unfreezing last layers and dropping LR. ---")
# 1. Unfreeze the last N layers of selected backbones
for backbone_name, n_layers in BACKBONES_TO_UNFREEZE.items():
                # 'unfreeze_last_n_layers' is assumed to be a method on the model
                # (an illustrative sketch appears at the end of this module).
model.unfreeze_last_n_layers(backbone_name, n=n_layers)
# 2. Drop the learning rate for stable fine-tuning
new_lr = initial_lr * FINE_TUNE_LR_FACTOR
            adjust_learning_rate(optimizer, new_lr)  # helper defined at the bottom of this module
print(f"Learning Rate adjusted for fine-tuning: {initial_lr:.6f} -> {new_lr:.6f}")
mining_strategy = "Hard Mining" if is_hard_mining_epoch else "Standard Triplet Loss"
if epoch >= 2:
mining_strategy += " + Fine-Tuning"
print(f"\n--- Epoch {epoch+1}/{num_epochs} | Using {mining_strategy} ---")
for i, batch in enumerate(train_dataloader):
batch_idx = i + 1
total_batches_processed += 1
batch = move_to_device(batch, device)
if not is_hard_mining_epoch:
# --- STANDARD TRIPLET LOSS (Epoch 1) ---
# Input structure is Anchor/Positive/Negative dicts
A_anchor = batch['anchor']
P_positive = batch['positive']
N_negative = batch['negative']
# Extract multimodal inputs (A_clip, A_seg, etc. from A_anchor)
A_clip, A_seg, A_dpt, A_midas = (A_anchor.get('clip'), A_anchor.get('segformer'), A_anchor.get('dpt'), A_anchor.get('midas'))
P_clip, P_seg, P_dpt, P_midas = (P_positive.get('clip'), P_positive.get('segformer'), P_positive.get('dpt'), P_positive.get('midas'))
N_clip, N_seg, N_dpt, N_midas = (N_negative.get('clip'), N_negative.get('segformer'), N_negative.get('dpt'), N_negative.get('midas'))
anchor_embed = model(A_clip, A_seg, A_dpt, A_midas)
positive_embed = model(P_clip, P_seg, P_dpt, P_midas)
negative_embed = model(N_clip, N_seg, N_dpt, N_midas)
loss = criterion(anchor_embed, positive_embed, negative_embed)
else:
# --- ONLINE BATCH HARD MINING (Epoch 2+) ---
# Input structure is 'anchor' inputs and 'y' labels
inputs = batch['anchor']
labels = batch['y'] # True class labels
# Extract inputs for the model's forward pass
clip_inputs = inputs.get('clip')
segformer_inputs = inputs.get('segformer')
dpt_inputs = inputs.get('dpt')
midas_inputs = inputs.get('midas')
# Compute all Embeddings in the batch
embeddings = model(clip_inputs, segformer_inputs, dpt_inputs, midas_inputs)
# Use the BatchHard miner to find the hardest triplets and calculate loss
# loss = batch_hard_mining(embeddings, labels, MARGIN)
loss = criterion(embeddings, labels)
# --- BACKPROPAGATION ---
optimizer.zero_grad()
loss.backward()
optimizer.step()
running_loss += loss.item()
# --- PRINTING ---
if batch_idx % batch_print_interval == 0:
avg_loss = running_loss / batch_print_interval
elapsed_time = time.time() - start_time
print(f'Epoch [{epoch+1}/{num_epochs}], '
f'Batch [{batch_idx}/{len(train_dataloader)}], '
f'{mining_strategy} Loss: {avg_loss:.4f}, '
f'Time/{batch_print_interval} Batches: {elapsed_time:.2f}s')
running_loss = 0.0
start_time = time.time()
# --- VALIDATION ---
if val_dataloader is not None and batch_idx % validation_interval == 0:
# **IMPORTANT:** Always use the more robust Hard Mining validation
# to get a real assessment of the embedding space's quality.
val_loss, num_val_batches = validate_model_mining(
model, criterion, val_dataloader, device
)
print(f"[Validation @ Batch {batch_idx}] Checked {num_val_batches} Val Batches. Loss: {val_loss:.4f}\n")
                if early_stopper.early_stop(val_loss, model):
print(f"\n*** Early stopping triggered! ***")
if early_stopper.best_model_state is not None:
model.load_state_dict(early_stopper.best_model_state)
print("Restored best model weights.")
return total_batches_processed, early_stopper.min_validation_loss
model.train()
start_time = time.time()
if torch.cuda.is_available():
torch.cuda.empty_cache()
print(f"--- Epoch {epoch+1} finished. ---")
if early_stopper.best_model_state is not None:
model.load_state_dict(early_stopper.best_model_state)
print("Training finished. Restored best model weights based on validation loss.")
return total_batches_processed, early_stopper.min_validation_loss
def adjust_learning_rate(optimizer, new_lr):
"""
Sets the learning rate for all parameter groups in the optimizer.
Args:
optimizer (torch.optim.Optimizer): The optimizer whose learning rate to adjust.
new_lr (float): The new learning rate value.
"""
for param_group in optimizer.param_groups:
param_group['lr'] = new_lr
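# Note (illustrative sketch): train_model_with_curriculum assumes the model exposes an
# unfreeze_last_n_layers(backbone_name, n) method. A free-standing version of that idea,
# assuming the backbones live in a dict-like attribute model.backbones, might look like
# this; the attribute layout is an assumption, not the actual model API.
def _unfreeze_last_n_layers_sketch(model, backbone_name, n):
    """Enable gradients for the last n child modules of one backbone (sketch)."""
    if n <= 0:
        return
    backbone = model.backbones[backbone_name]  # assumed attribute layout
    layers = list(backbone.children())
    for layer in layers[-n:]:
        for param in layer.parameters():
            param.requires_grad = True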