import numpy as np import torch from training.test_helpers import generate_class_prototypes from training.utils import move_to_device import time import os import glob import re MODELS_SAVE_DIR = "training/saved_models/newclip_mega8" def train_or_load_model(model, load_network=True, retrain_model=False, specific_model_name=None, optimizer=None, criterion=None, train_dataloader=None, device=None, num_epochs=1, batch_print_interval=5, val_dataloader=None, validation_interval=20, model_name="fused_feature_model.pth", patience=5, min_delta=0.05, train_type="standard"): save_dir = os.path.dirname(model_name) if os.path.dirname(model_name) else MODELS_SAVE_DIR os.makedirs(save_dir, exist_ok=True) latest_path = None # 1. Check for a specific checkpoint name provided by the user if specific_model_name: latest_path = os.path.join(save_dir, specific_model_name) if not os.path.exists(latest_path): print(f"Warning: Specific checkpoint file not found at: {latest_path}. Skipping model load.") latest_path = None # Do not attempt to load if path is invalid # 2. If no specific name was provided or it was invalid, look for the latest checkpoint if latest_path is None: base_filename = os.path.basename(model_name) max_num, latest_path = get_latest_checkpoint_info(save_dir, base_filename) # --- Loading model --- if load_network and latest_path: try: print(f"Attempting to load latest model and prototypes: {latest_path}") checkpoint = torch.load(latest_path, map_location=device) model.load_state_dict(checkpoint['model_state_dict']) total_batches_processed = checkpoint.get('total_batches_processed', None) min_validation_loss = checkpoint.get('min_validation_loss', None) print(f"Model with min validation loss {min_validation_loss} loaded") if not retrain_model: model.eval() if 'prototype_tensor' in checkpoint and 'class_ids' in checkpoint: prototype_tensor = checkpoint['prototype_tensor'] class_ids = checkpoint['class_ids'] print(f"Model and Prototypes successfully loaded from {latest_path}") return model, prototype_tensor, class_ids, total_batches_processed, min_validation_loss else: print(f"Model state loaded successfully from {latest_path}, but **Prototypes are missing**. Generating prototypes now.") prototype_tensor, class_ids = generate_class_prototypes(model, train_dataloader, device) torch.save({ 'model_state_dict': model.state_dict(), 'prototype_tensor': prototype_tensor, 'class_ids': class_ids, 'total_batches_processed': total_batches_processed, "id_to_tag": train_dataloader.dataset.id_to_tag, }, latest_path) print(f"Prototypes generated and saved to {latest_path}") return model, prototype_tensor, class_ids, total_batches_processed, min_validation_loss except Exception as e: print(f"Error loading model from {latest_path}: {e}") print("Proceeding to train from scratch.") # --- Training --- print(f"Training from scratch (or resuming failed load).") if train_type == "standard": total_batches_processed, min_validation_loss = train_model( model, optimizer, criterion, train_dataloader, device, num_epochs=num_epochs, batch_print_interval=batch_print_interval, val_dataloader=val_dataloader, validation_interval=validation_interval, patience=patience, min_delta=min_delta ) elif train_type == "hardmining": total_batches_processed, min_validation_loss = train_model_hard_mining( model, optimizer, criterion, train_dataloader, device, num_epochs=num_epochs, batch_print_interval=batch_print_interval, val_dataloader=val_dataloader, validation_interval=validation_interval, patience=patience, min_delta=min_delta ) elif train_type == "curriculum": total_batches_processed, min_validation_loss = train_model_with_curriculum( model, optimizer, criterion, train_dataloader, device, num_epochs=num_epochs, batch_print_interval=batch_print_interval, val_dataloader=val_dataloader, validation_interval=validation_interval, patience=patience, min_delta=min_delta ) # --- Saving --- next_num = max_num + 1 new_filename = f"{next_num}_{base_filename}" save_path = os.path.join(save_dir, new_filename) print(f"Min Validation Loss after training: {min_validation_loss}") torch.save({ 'model_state_dict': model.state_dict(), "total_batches_processed": total_batches_processed, "min_validation_loss": min_validation_loss }, save_path) print(f"Model saved (before prototypes) to {save_path}") prototype_tensor, class_ids = generate_class_prototypes(model, train_dataloader, device) torch.save({ 'model_state_dict': model.state_dict(), 'prototype_tensor': prototype_tensor, 'class_ids': class_ids, 'total_batches_processed': total_batches_processed, 'min_validation_loss': min_validation_loss, "id_to_tag": train_dataloader.dataset.id_to_tag, }, save_path) print(f"Model and Prototypes saved to {save_path}") return model, prototype_tensor, class_ids, total_batches_processed, min_validation_loss def train_model(model, optimizer, criterion, train_dataloader, device, num_epochs=1, batch_print_interval=5, val_dataloader=None, validation_interval=20, patience=5, min_delta=0.05): early_stopper = EarlyStopper(patience=patience, min_delta=min_delta) running_loss = 0.0 running_time = 0.0 total_batches_processed = 0 model.train() for epoch in range(num_epochs): start_time = time.time() for i, batch in enumerate(train_dataloader): batch_idx = i + 1 total_batches_processed += 1 batch = move_to_device(batch, device) A_anchor = batch['anchor'] A_clip = A_anchor.get('clip') A_seg = A_anchor.get('segformer') A_dpt = A_anchor.get('dpt') A_midas = A_anchor.get('midas') # --- Safe extraction for Positive (P) --- P_positive = batch['positive'] P_clip = P_positive.get('clip') P_seg = P_positive.get('segformer') P_dpt = P_positive.get('dpt') P_midas = P_positive.get('midas') # --- Safe extraction for Negative (N) --- N_negative = batch['negative'] N_clip = N_negative.get('clip') N_seg = N_negative.get('segformer') N_dpt = N_negative.get('dpt') N_midas = N_negative.get('midas') if A_clip.device.type != 'cuda' or next(model.parameters()).device.type != 'cuda': print("\n*** CRITICAL DEVICE SWITCH DETECTED ***") print(f"Tensor is on {A_clip.device.type}. Training speed will be severely impacted.") anchor_embed = model(A_clip, A_seg, A_dpt, A_midas) positive_embed = model(P_clip, P_seg, P_dpt, P_midas) negative_embed = model(N_clip, N_seg, N_dpt, N_midas) loss = criterion(anchor_embed, positive_embed, negative_embed) # Backpropagation optimizer.zero_grad() loss.backward() optimizer.step() running_loss += loss.item() if batch_idx % batch_print_interval == 0: avg_loss = running_loss / batch_print_interval elapsed_time = time.time() - start_time print(f'Epoch [{epoch+1}/{num_epochs}], ' f'Batch [{batch_idx}/{len(train_dataloader)}], ' f'Loss: {avg_loss:.4f}, ' f'Time/5 Batches: {elapsed_time:.2f}s') running_loss = 0.0 running_time += elapsed_time start_time = time.time() if val_dataloader is not None and batch_idx % validation_interval == 0: val_loss, num_val_batches = validate_model_mining( model, criterion, val_dataloader, device ) print(f"[Validation @ Batch {batch_idx}] Checked {num_val_batches} Val Batches. Loss: {val_loss:.4f}\n") if early_stopper.early_stop(val_loss, model): print(f"\n*** Early stopping triggered! ***") print(f"Validation loss has not improved for {early_stopper.patience} validation checks.") # Load the best weights before exiting if early_stopper.best_model_state is not None: model.load_state_dict(early_stopper.best_model_state) print("Restored best model weights.") return total_batches_processed, early_stopper.min_validation_loss model.train() start_time = time.time() if batch_idx % validation_interval == 0: if torch.cuda.is_available(): torch.cuda.empty_cache() print(f"--- Epoch {epoch+1} finished. ---") print(f"Total time for epoch: {running_time:.2f}s") if early_stopper.best_model_state is not None: model.load_state_dict(early_stopper.best_model_state) print("Training finished. Restored best model weights based on validation loss.") return total_batches_processed, early_stopper.min_validation_loss def validate_model(model, criterion, val_dataloader, device, batches_to_check=None): """ Evaluates the model on the validation dataset for a specified number of batches. If batches_to_check is None, it runs over the entire val_dataloader. """ model.eval() total_val_loss = 0.0 num_batches = 0 start_time = time.time() with torch.no_grad(): for batch_idx, batch in enumerate(val_dataloader): if batches_to_check is not None and batch_idx >= batches_to_check: break batch = move_to_device(batch, device) A_anchor = batch['anchor'] A_clip = A_anchor.get('clip') A_seg = A_anchor.get('segformer') A_dpt = A_anchor.get('dpt') A_midas = A_anchor.get('midas') # --- Safe extraction for Positive (P) --- P_positive = batch['positive'] P_clip = P_positive.get('clip') P_seg = P_positive.get('segformer') P_dpt = P_positive.get('dpt') P_midas = P_positive.get('midas') # --- Safe extraction for Negative (N) --- N_negative = batch['negative'] N_clip = N_negative.get('clip') N_seg = N_negative.get('segformer') N_dpt = N_negative.get('dpt') N_midas = N_negative.get('midas') anchor_embed = model(A_clip, A_seg, A_dpt, A_midas) positive_embed = model(P_clip, P_seg, P_dpt, P_midas) negative_embed = model(N_clip, N_seg, N_dpt, N_midas) loss = criterion(anchor_embed, positive_embed, negative_embed) total_val_loss += loss.item() num_batches += 1 end_time = time.time() validation_time = end_time - start_time print(f"Validation took {validation_time:.2f} seconds.") avg_val_loss = total_val_loss / num_batches if num_batches > 0 else 0.0 return avg_val_loss, num_batches def get_latest_checkpoint_info(save_dir, base_filename): search_pattern = os.path.join(save_dir, f"*_{base_filename}") existing_files = glob.glob(search_pattern) max_num = 0 latest_path = None pattern = re.compile(r'^(\d+)_') for file in existing_files: name = os.path.basename(file) match = pattern.match(name) if match: current_num = int(match.group(1)) if current_num > max_num: max_num = current_num latest_path = file return max_num, latest_path class EarlyStopper: """ Early stopping to stop training when the validation loss does not improve after a given patience. """ def __init__(self, patience=5, min_delta=0): self.patience = patience self.min_delta = min_delta self.counter = 0 self.min_validation_loss = np.inf self.best_model_state = None def early_stop(self, validation_loss, model): """ Returns True if early stopping criteria are met. Stores the best model state if the current loss is an improvement. """ if validation_loss < self.min_validation_loss - self.min_delta: self.min_validation_loss = validation_loss print(f"New minimum validation loss: {self.min_validation_loss:.4f}. Saving best model state.") self.counter = 0 self.best_model_state = model.state_dict() elif validation_loss > self.min_validation_loss + self.min_delta: self.counter += 1 if self.counter >= self.patience: return True return False def batch_hard_mining(embeddings, labels, margin): """ Implements BatchHard Triplet Mining. Finds the hardest positive and negative for every anchor in the batch. """ # Calculate all pairwise distances (Euclidean) pairwise_dist = torch.cdist(embeddings, embeddings, p=2.0) # Get masks # labels_equal[i, j] is True if labels[i] == labels[j] labels_equal = labels.unsqueeze(0) == labels.unsqueeze(1) # 1. Hardest Positive (P_h): Max distance among positives mask_anchor_positive = labels_equal.triu(diagonal=1) # Upper triangle, exclude diagonal (A=P) # Set non-positives to a very small number (or 0) so max() finds the hardest positive max_dist = pairwise_dist.max() dist_positive = pairwise_dist * mask_anchor_positive.float() # Find the max (hardest) positive distance for each row (Anchor) # We set non-positive distances to a small value so they don't affect the max dist_positive[mask_anchor_positive.logical_not()] = 0 # Find the hardest positive distance for each anchor (row) # This requires looking across all positive pairs that include that anchor. # It's computationally simpler to find the max distance for all P in the batch # for each A (row). # Max distance to a positive for each Anchor (row) dist_ap, _ = torch.max(dist_positive, dim=1) # 2. Hardest Negative (N_h): Min distance among negatives mask_anchor_negative = labels_equal.logical_not() # Set positives and diagonal to a very large number (inf) so min() finds the hardest negative dist_negative = pairwise_dist + max_dist * (1 - mask_anchor_negative.float()) # Find the min (hardest) negative distance for each Anchor (row) dist_an, _ = torch.min(dist_negative, dim=1) # 3. Compute Triplet Loss on the mined triplets # Loss: max(0, D_ap - D_an + margin) loss_triplet = torch.relu(dist_ap - dist_an + margin) # Return the average non-zero loss if loss_triplet.numel() == 0: return torch.tensor(0.0, device=embeddings.device, requires_grad=True) # Only average over the triplets that contribute to the loss (loss > 0) return loss_triplet.mean() def batch_semi_hard_mining(embeddings, labels, margin): """ Implements Batch Semi-Hard Triplet Mining. Finds the hardest *violating* positive and the negative that is *outside* the margin but *closer* than the hardest negative (or the one that is closest to d(a,p)) for every anchor in the batch. A more robust way is to select the negative that violates the margin but is closest to d(a,p), or simply select the hardest negative among those that satisfy the semi-hard condition: d(a,p) < d(a,n) < d(a,p) + margin. This implementation will strictly follow the original formulation: 1. Hardest Positive (P_h): Max distance among positives (d_ap). 2. Semi-Hard Negative (N_sh): Negative distance d_an such that d_ap < d_an, but d_an < d_ap + margin. If multiple satisfy this, we take the one closest to d_ap (the "hardest" semi-hard). """ # Calculate all pairwise distances (Euclidean) pairwise_dist = torch.cdist(embeddings, embeddings, p=2.0) # Get masks # labels_equal[i, j] is True if labels[i] == labels[j] labels_equal = labels.unsqueeze(0) == labels.unsqueeze(1) # --- 1. Hardest Positive (P_h) --- mask_anchor_positive = labels_equal.triu(diagonal=1) # Set non-positives to 0 (since dists are positive, 0 won't be max) dist_positive = pairwise_dist * mask_anchor_positive.float() dist_positive[mask_anchor_positive.logical_not()] = 0 # Max distance to a positive for each Anchor (row) dist_ap, _ = torch.max(dist_positive, dim=1) # --- 2. Semi-Hard Negative (N_sh) --- mask_anchor_negative = labels_equal.logical_not() # Ensure d(a,n) > d(a,p) (Hardness condition) # dist_ap is (B), expand to (B, B) dist_ap_expanded = dist_ap.unsqueeze(1) # Condition 1: d(a,n) > d(a,p) mask_positive_violating = pairwise_dist > dist_ap_expanded # Condition 2: d(a,n) < d(a,p) + margin (Semi-Hard condition) mask_margin_satisfying = pairwise_dist < dist_ap_expanded + margin # The Semi-Hard Negative Mask: # Must be a negative, must be harder than the positive, and must satisfy the margin. mask_semi_hard = mask_anchor_negative & mask_positive_violating & mask_margin_satisfying # If no semi-hard negative exists for an anchor, we must find a valid substitute # to avoid a zero-distance result, which could lead to loss=0 inappropriately. # Create a distance matrix for MIN operation: # 1. Start with the original pairwise_dist. dist_negative = pairwise_dist.clone() # 2. For non-semi-hard triplets, set the distance to a large value (Max + margin) # so the torch.min() operation will choose a semi-hard one, if it exists. # If *no* semi-hard negative exists for an anchor, we want to choose the # hardest negative that *violates* the margin (i.e., the hardest negative). # Temporarily set non-negatives to a large number dist_negative[mask_anchor_negative.logical_not()] = 1e9 # The distance to minimize is: d(a,n) - d(a,p) # We want the negative that is closest to d(a,p) but still satisfies the semi-hard condition. # We will choose the hardest *non-violating* negative that is still a negative (i.e., d(a,n) > d(a,p)) # If a semi-hard negative exists, its mask is True. # If a semi-hard negative *doesn't* exist, the common practice is to fall back to the # hardest negative (which would violate the margin $d(a,n) < d(a,p)+\alpha$). # Use the Semi-Hard Mask to define the relevant distances dist_semi_hard = pairwise_dist.clone() dist_semi_hard[mask_semi_hard.logical_not()] = 1e9 # Non-semi-hard dists are huge # Find the min distance among the semi-hard negatives for each Anchor (row) dist_an_semi_hard, _ = torch.min(dist_semi_hard, dim=1) # Handle Anchors with NO Semi-Hard Negative: # If the min distance is still 1e9, it means no semi-hard negative was found. # In this case, we fall back to the HARDEST negative (closest d(a,n) > d(a,p) but d(a,n) < d(a,p)+margin is NOT met). # Mask for anchors that found no semi-hard negative (distance is 1e9) mask_no_semi_hard = dist_an_semi_hard == 1e9 # For those anchors, fall back to the hardest negative (the original Batch Hard negative) if mask_no_semi_hard.any(): # Mask for all Negatives (Hardest Negative, d(a,n) < d(a,p) + margin is not required) dist_all_negatives = pairwise_dist.clone() dist_all_negatives[mask_anchor_negative.logical_not()] = 1e9 # Find the actual hardest negative for all anchors dist_an_hard, _ = torch.min(dist_all_negatives, dim=1) # Replace the 1e9 with the actual hardest negative distance dist_an_semi_hard[mask_no_semi_hard] = dist_an_hard[mask_no_semi_hard] dist_an = dist_an_semi_hard # Final negative distance to use # --- 3. Compute Triplet Loss on the mined triplets --- # Loss: max(0, D_ap - D_an + margin) loss_triplet = torch.relu(dist_ap - dist_an + margin) # Only average over the triplets that contribute to the loss (loss > 0) # Note: We must check for at least one positive triplet being present in the batch if loss_triplet.numel() == 0 or dist_ap.sum() == 0: return torch.tensor(0.0, device=embeddings.device, requires_grad=True) # Note: If we fall back to the hardest negative, the loss contribution might be 0 # (if d_an > d_ap + margin), but we still include it in the average (a common implementation choice). # Since we are using the `torch.relu` here, the final loss will only be averaged over *all* anchors # for which the loss calculation is > 0. # Final check: only average over anchors that actually have a hard positive (dist_ap > 0) # The most common implementation just uses the mean over the entire batch, which is simpler and less prone to edge cases. return loss_triplet.mean() def validate_model_mining(model, criterion, val_dataloader, device): """ Calculates validation loss using Online Hard Mining (BatchHard). Args: model: The FusedFeatureModel. criterion: The loss criterion (used primarily to extract the margin). val_dataloader: DataLoader using the MultiModalDataset. device: 'cuda' or 'cpu'. Returns: (float, int): Average validation loss and number of batches checked. """ model.eval() total_val_loss = 0.0 num_val_batches = 0 # CRITICAL: Extract the margin used in the criterion # Assuming criterion is torch.nn.TripletMarginLoss # MARGIN = criterion.margin with torch.no_grad(): for batch in val_dataloader: # Move batch to device (assuming move_to_device is defined) # You must ensure the move_to_device helper moves nested dicts correctly batch = move_to_device(batch, device) inputs = batch['anchor'] labels = batch['y'] # True class labels (Shape: Batch Size) # 1. Forward Pass: Compute all Embeddings in the batch embeddings = model(**inputs) # 2. Loss Calculation: Online Hard Mining Loss # The loss is computed only on the hardest triplets found in the batch. # loss = batch_semi_hard_mining(embeddings, labels, MARGIN) loss = criterion(embeddings, labels) total_val_loss += loss.item() num_val_batches += 1 if num_val_batches == 0: return 0.0, 0 avg_val_loss = total_val_loss / num_val_batches return avg_val_loss, num_val_batches # --- REVISED TRAINING LOOP --- def train_model_hard_mining(model, optimizer, criterion, train_dataloader, device, num_epochs=1, batch_print_interval=5, val_dataloader=None, validation_interval=20, patience=5, min_delta=0.05): # Extract margin from criterion (assuming it's TripletMarginLoss) # MARGIN = criterion.margin early_stopper = EarlyStopper(patience=patience, min_delta=min_delta) running_loss = 0.0 running_time = 0.0 total_batches_processed = 0 model.train() for epoch in range(num_epochs): start_time = time.time() for i, batch in enumerate(train_dataloader): batch_idx = i + 1 total_batches_processed += 1 batch = move_to_device(batch, device) inputs = batch['anchor'] labels = batch['y'] # True class labels (Shape: Batch Size) # --- 1. Compute all Embeddings in the batch --- # Note: We need to pass the tensors out of the dict structure for the model call # This is complex when inputs are dicts. We'll extract only the required tensors: # The model call needs to be simplified to handle the batch of inputs embeddings = model(**inputs) # Embeddings shape: (Batch Size, Embedding Dim) # --- 2. Online Hard Mining --- # Use the BatchHard miner to find the hardest triplets and calculate loss # loss = batch_semi_hard_mining(embeddings, labels, MARGIN) loss = criterion(embeddings, labels) # Backpropagation optimizer.zero_grad() loss.backward() optimizer.step() running_loss += loss.item() if batch_idx % batch_print_interval == 0: avg_loss = running_loss / batch_print_interval elapsed_time = time.time() - start_time print(f'Epoch [{epoch+1}/{num_epochs}], ' f'Batch [{batch_idx}/{len(train_dataloader)}], ' f'Online Hard Mining Loss: {avg_loss:.4f}, ' f'Time/{batch_print_interval} Batches: {elapsed_time:.2f}s') running_loss = 0.0 running_time += elapsed_time start_time = time.time() if val_dataloader is not None and batch_idx % validation_interval == 0: val_loss, num_val_batches = validate_model_mining( model, criterion, val_dataloader, device ) print(f"[Validation @ Batch {batch_idx}] Checked {num_val_batches} Val Batches. Loss: {val_loss:.4f}\n") if early_stopper.early_stop(val_loss, model): print(f"\n*** Early stopping triggered! ***") if early_stopper.best_model_state is not None: model.load_state_dict(early_stopper.best_model_state) print("Restored best model weights.") return total_batches_processed, early_stopper.min_validation_loss model.train() start_time = time.time() if batch_idx % validation_interval == 0 and torch.cuda.is_available(): torch.cuda.empty_cache() print(f"--- Epoch {epoch+1} finished. ---") print(f"Total time for epoch: {running_time:.2f}s") if early_stopper.best_model_state is not None: model.load_state_dict(early_stopper.best_model_state) print("Training finished. Restored best model weights based on validation loss.") return total_batches_processed, early_stopper.min_validation_loss def train_model_with_curriculum(model, optimizer, criterion, train_dataloader, device, num_epochs=1, batch_print_interval=5, val_dataloader=None, validation_interval=20, patience=5, min_delta=0.05): # CRITICAL: Extract the margin for Batch Hard Mining # Assuming criterion is torch.nn.TripletMarginLoss # MARGIN = criterion.margin early_stopper = EarlyStopper(patience=patience, min_delta=min_delta) FINE_TUNE_LR_FACTOR = 0.1 # e.g., drop LR by 10x CLIP_LAYERS_TO_UNFREEZE = 1 SEGFORMER_LAYERS_TO_UNFREEZE = 1 DPT_LAYERS_TO_UNFREEZE = 0 MIDAS_LAYERS_TO_UNFREEZE = 0 # You may add other backbones here, e.g., 'segformer': 2, 'dpt': 1 BACKBONES_TO_UNFREEZE = {'clip': CLIP_LAYERS_TO_UNFREEZE} # Get the current base LR initial_lr = optimizer.param_groups[0]['lr'] running_loss = 0.0 total_batches_processed = 0 model.train() for epoch in range(num_epochs): start_time = time.time() # Determine the mining strategy for the current epoch # Epoch 1 (index 0) uses standard pre-sampled triplets (Random/Semi-hard) # Epoch 2+ (index 1+) uses Online Batch Hard Mining is_hard_mining_epoch = epoch >= 1 if is_hard_mining_epoch and epoch == 1: print("\n--- Switching to Hard Mining Mode for Training Dataset and loading Best Model from Triplet Loss ---") train_dataloader.dataset.hard_mining_mode = True model.load_state_dict(early_stopper.best_model_state) if epoch == 2: print("\n--- PHASE 2: Starting Fine-Tuning (Epoch 3). Unfreezing last layers and dropping LR. ---") # 1. Unfreeze the last N layers of selected backbones for backbone_name, n_layers in BACKBONES_TO_UNFREEZE.items(): # The 'unfreeze_last_n_layers' function is assumed to be part of the model model.unfreeze_last_n_layers(backbone_name, n=n_layers) # 2. Drop the learning rate for stable fine-tuning new_lr = initial_lr * FINE_TUNE_LR_FACTOR adjust_learning_rate(optimizer, new_lr) # You need to define this helper function print(f"Learning Rate adjusted for fine-tuning: {initial_lr:.6f} -> {new_lr:.6f}") mining_strategy = "Hard Mining" if is_hard_mining_epoch else "Standard Triplet Loss" if epoch >= 2: mining_strategy += " + Fine-Tuning" print(f"\n--- Epoch {epoch+1}/{num_epochs} | Using {mining_strategy} ---") for i, batch in enumerate(train_dataloader): batch_idx = i + 1 total_batches_processed += 1 batch = move_to_device(batch, device) if not is_hard_mining_epoch: # --- STANDARD TRIPLET LOSS (Epoch 1) --- # Input structure is Anchor/Positive/Negative dicts A_anchor = batch['anchor'] P_positive = batch['positive'] N_negative = batch['negative'] # Extract multimodal inputs (A_clip, A_seg, etc. from A_anchor) A_clip, A_seg, A_dpt, A_midas = (A_anchor.get('clip'), A_anchor.get('segformer'), A_anchor.get('dpt'), A_anchor.get('midas')) P_clip, P_seg, P_dpt, P_midas = (P_positive.get('clip'), P_positive.get('segformer'), P_positive.get('dpt'), P_positive.get('midas')) N_clip, N_seg, N_dpt, N_midas = (N_negative.get('clip'), N_negative.get('segformer'), N_negative.get('dpt'), N_negative.get('midas')) anchor_embed = model(A_clip, A_seg, A_dpt, A_midas) positive_embed = model(P_clip, P_seg, P_dpt, P_midas) negative_embed = model(N_clip, N_seg, N_dpt, N_midas) loss = criterion(anchor_embed, positive_embed, negative_embed) else: # --- ONLINE BATCH HARD MINING (Epoch 2+) --- # Input structure is 'anchor' inputs and 'y' labels inputs = batch['anchor'] labels = batch['y'] # True class labels # Extract inputs for the model's forward pass clip_inputs = inputs.get('clip') segformer_inputs = inputs.get('segformer') dpt_inputs = inputs.get('dpt') midas_inputs = inputs.get('midas') # Compute all Embeddings in the batch embeddings = model(clip_inputs, segformer_inputs, dpt_inputs, midas_inputs) # Use the BatchHard miner to find the hardest triplets and calculate loss # loss = batch_hard_mining(embeddings, labels, MARGIN) loss = criterion(embeddings, labels) # --- BACKPROPAGATION --- optimizer.zero_grad() loss.backward() optimizer.step() running_loss += loss.item() # --- PRINTING --- if batch_idx % batch_print_interval == 0: avg_loss = running_loss / batch_print_interval elapsed_time = time.time() - start_time print(f'Epoch [{epoch+1}/{num_epochs}], ' f'Batch [{batch_idx}/{len(train_dataloader)}], ' f'{mining_strategy} Loss: {avg_loss:.4f}, ' f'Time/{batch_print_interval} Batches: {elapsed_time:.2f}s') running_loss = 0.0 start_time = time.time() # --- VALIDATION --- if val_dataloader is not None and batch_idx % validation_interval == 0: # **IMPORTANT:** Always use the more robust Hard Mining validation # to get a real assessment of the embedding space's quality. val_loss, num_val_batches = validate_model_mining( model, criterion, val_dataloader, device ) print(f"[Validation @ Batch {batch_idx}] Checked {num_val_batches} Val Batches. Loss: {val_loss:.4f}\n") early_stopper.early_stop(val_loss, model) if early_stopper.early_stop(val_loss, model): print(f"\n*** Early stopping triggered! ***") if early_stopper.best_model_state is not None: model.load_state_dict(early_stopper.best_model_state) print("Restored best model weights.") return total_batches_processed, early_stopper.min_validation_loss model.train() start_time = time.time() if torch.cuda.is_available(): torch.cuda.empty_cache() print(f"--- Epoch {epoch+1} finished. ---") if early_stopper.best_model_state is not None: model.load_state_dict(early_stopper.best_model_state) print("Training finished. Restored best model weights based on validation loss.") return total_batches_processed, early_stopper.min_validation_loss def adjust_learning_rate(optimizer, new_lr): """ Sets the learning rate for all parameter groups in the optimizer. Args: optimizer (torch.optim.Optimizer): The optimizer whose learning rate to adjust. new_lr (float): The new learning rate value. """ for param_group in optimizer.param_groups: param_group['lr'] = new_lr