import copy
import glob
import os
import re
import time

import numpy as np
import torch

from training.test_helpers import generate_class_prototypes
from training.utils import move_to_device

MODELS_SAVE_DIR = "training/saved_models/newclip_mega8"

def train_or_load_model(model, load_network=True, retrain_model=False, specific_model_name=None,
                        optimizer=None, criterion=None, train_dataloader=None, device=None,
                        num_epochs=1, batch_print_interval=5, val_dataloader=None,
                        validation_interval=20, model_name="fused_feature_model.pth",
                        patience=5, min_delta=0.05, train_type="standard"):
    save_dir = os.path.dirname(model_name) if os.path.dirname(model_name) else MODELS_SAVE_DIR
    os.makedirs(save_dir, exist_ok=True)

    base_filename = os.path.basename(model_name)
    # Always scan for the highest existing checkpoint number so the save step below
    # can pick the next one, even when a specific checkpoint name is loaded.
    max_num, auto_latest_path = get_latest_checkpoint_info(save_dir, base_filename)

    latest_path = None
    # 1. Check for a specific checkpoint name provided by the user.
    if specific_model_name:
        latest_path = os.path.join(save_dir, specific_model_name)
        if not os.path.exists(latest_path):
            print(f"Warning: Specific checkpoint file not found at: {latest_path}. Skipping model load.")
            latest_path = None  # Do not attempt to load if the path is invalid.

    # 2. If no specific name was provided or it was invalid, fall back to the latest checkpoint.
    if latest_path is None:
        latest_path = auto_latest_path

    # --- Loading model ---
    if load_network and latest_path:
        try:
            print(f"Attempting to load latest model and prototypes: {latest_path}")
            checkpoint = torch.load(latest_path, map_location=device)
            model.load_state_dict(checkpoint['model_state_dict'])
            total_batches_processed = checkpoint.get('total_batches_processed', None)
            min_validation_loss = checkpoint.get('min_validation_loss', None)
            print(f"Model with min validation loss {min_validation_loss} loaded")

            if not retrain_model:
                model.eval()
                if 'prototype_tensor' in checkpoint and 'class_ids' in checkpoint:
                    prototype_tensor = checkpoint['prototype_tensor']
                    class_ids = checkpoint['class_ids']
                    print(f"Model and prototypes successfully loaded from {latest_path}")
                    return model, prototype_tensor, class_ids, total_batches_processed, min_validation_loss
                else:
                    print(f"Model state loaded successfully from {latest_path}, but prototypes are missing. Generating prototypes now.")
                    prototype_tensor, class_ids = generate_class_prototypes(model, train_dataloader, device)
                    torch.save({
                        'model_state_dict': model.state_dict(),
                        'prototype_tensor': prototype_tensor,
                        'class_ids': class_ids,
                        'total_batches_processed': total_batches_processed,
                        'min_validation_loss': min_validation_loss,
                        'id_to_tag': train_dataloader.dataset.id_to_tag,
                    }, latest_path)
                    print(f"Prototypes generated and saved to {latest_path}")
                    return model, prototype_tensor, class_ids, total_batches_processed, min_validation_loss
        except Exception as e:
            print(f"Error loading model from {latest_path}: {e}")
            print("Proceeding to train from scratch.")
    # --- Training ---
    print("Training from scratch (or retraining after a failed load).")
    if train_type == "standard":
        total_batches_processed, min_validation_loss = train_model(
            model,
            optimizer,
            criterion,
            train_dataloader,
            device,
            num_epochs=num_epochs,
            batch_print_interval=batch_print_interval,
            val_dataloader=val_dataloader,
            validation_interval=validation_interval,
            patience=patience,
            min_delta=min_delta,
        )
    elif train_type == "hardmining":
        total_batches_processed, min_validation_loss = train_model_hard_mining(
            model,
            optimizer,
            criterion,
            train_dataloader,
            device,
            num_epochs=num_epochs,
            batch_print_interval=batch_print_interval,
            val_dataloader=val_dataloader,
            validation_interval=validation_interval,
            patience=patience,
            min_delta=min_delta,
        )
    elif train_type == "curriculum":
        total_batches_processed, min_validation_loss = train_model_with_curriculum(
            model,
            optimizer,
            criterion,
            train_dataloader,
            device,
            num_epochs=num_epochs,
            batch_print_interval=batch_print_interval,
            val_dataloader=val_dataloader,
            validation_interval=validation_interval,
            patience=patience,
            min_delta=min_delta,
        )
    else:
        raise ValueError(f"Unknown train_type: {train_type!r}. Expected 'standard', 'hardmining' or 'curriculum'.")
    # --- Saving ---
    next_num = max_num + 1
    new_filename = f"{next_num}_{base_filename}"
    save_path = os.path.join(save_dir, new_filename)
    print(f"Min validation loss after training: {min_validation_loss}")

    torch.save({
        'model_state_dict': model.state_dict(),
        'total_batches_processed': total_batches_processed,
        'min_validation_loss': min_validation_loss,
    }, save_path)
    print(f"Model saved (before prototypes) to {save_path}")

    prototype_tensor, class_ids = generate_class_prototypes(model, train_dataloader, device)
    torch.save({
        'model_state_dict': model.state_dict(),
        'prototype_tensor': prototype_tensor,
        'class_ids': class_ids,
        'total_batches_processed': total_batches_processed,
        'min_validation_loss': min_validation_loss,
        'id_to_tag': train_dataloader.dataset.id_to_tag,
    }, save_path)
    print(f"Model and prototypes saved to {save_path}")
    return model, prototype_tensor, class_ids, total_batches_processed, min_validation_loss

def train_model(model, optimizer, criterion, train_dataloader, device, num_epochs=1,
                batch_print_interval=5, val_dataloader=None, validation_interval=20,
                patience=5, min_delta=0.05):
    early_stopper = EarlyStopper(patience=patience, min_delta=min_delta)
    running_loss = 0.0
    running_time = 0.0
    total_batches_processed = 0
    model.train()

    for epoch in range(num_epochs):
        start_time = time.time()
        for i, batch in enumerate(train_dataloader):
            batch_idx = i + 1
            total_batches_processed += 1
            batch = move_to_device(batch, device)

            # --- Safe extraction for Anchor (A) ---
            A_anchor = batch['anchor']
            A_clip = A_anchor.get('clip')
            A_seg = A_anchor.get('segformer')
            A_dpt = A_anchor.get('dpt')
            A_midas = A_anchor.get('midas')
            # --- Safe extraction for Positive (P) ---
            P_positive = batch['positive']
            P_clip = P_positive.get('clip')
            P_seg = P_positive.get('segformer')
            P_dpt = P_positive.get('dpt')
            P_midas = P_positive.get('midas')
            # --- Safe extraction for Negative (N) ---
            N_negative = batch['negative']
            N_clip = N_negative.get('clip')
            N_seg = N_negative.get('segformer')
            N_dpt = N_negative.get('dpt')
            N_midas = N_negative.get('midas')

            if A_clip.device.type != 'cuda' or next(model.parameters()).device.type != 'cuda':
                print("\n*** CRITICAL DEVICE SWITCH DETECTED ***")
                print(f"Tensor is on {A_clip.device.type}. Training speed will be severely impacted.")

            anchor_embed = model(A_clip, A_seg, A_dpt, A_midas)
            positive_embed = model(P_clip, P_seg, P_dpt, P_midas)
            negative_embed = model(N_clip, N_seg, N_dpt, N_midas)
            loss = criterion(anchor_embed, positive_embed, negative_embed)

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

            if batch_idx % batch_print_interval == 0:
                avg_loss = running_loss / batch_print_interval
                elapsed_time = time.time() - start_time
                print(f'Epoch [{epoch+1}/{num_epochs}], '
                      f'Batch [{batch_idx}/{len(train_dataloader)}], '
                      f'Loss: {avg_loss:.4f}, '
                      f'Time/{batch_print_interval} Batches: {elapsed_time:.2f}s')
                running_loss = 0.0
                running_time += elapsed_time
                start_time = time.time()

            if val_dataloader is not None and batch_idx % validation_interval == 0:
                val_loss, num_val_batches = validate_model_mining(
                    model, criterion, val_dataloader, device
                )
                print(f"[Validation @ Batch {batch_idx}] Checked {num_val_batches} Val Batches. Loss: {val_loss:.4f}\n")
                if early_stopper.early_stop(val_loss, model):
                    print("\n*** Early stopping triggered! ***")
                    print(f"Validation loss has not improved for {early_stopper.patience} validation checks.")
                    # Load the best weights before exiting.
                    if early_stopper.best_model_state is not None:
                        model.load_state_dict(early_stopper.best_model_state)
                        print("Restored best model weights.")
                    return total_batches_processed, early_stopper.min_validation_loss
                model.train()
                start_time = time.time()

            if batch_idx % validation_interval == 0 and torch.cuda.is_available():
                torch.cuda.empty_cache()

        print(f"--- Epoch {epoch+1} finished. ---")
        print(f"Total time for epoch: {running_time:.2f}s")

    if early_stopper.best_model_state is not None:
        model.load_state_dict(early_stopper.best_model_state)
        print("Training finished. Restored best model weights based on validation loss.")
    return total_batches_processed, early_stopper.min_validation_loss
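
# Illustrative sketch (not part of the original training code): how the pieces that
# train_model expects fit together, using a toy linear layer in place of the real fused
# feature model and torch.nn.TripletMarginLoss as the (anchor, positive, negative)
# criterion. All names, shapes, and hyperparameters below are assumptions for
# demonstration only.
def _demo_standard_triplet_step():
    import torch.nn as nn

    toy_model = nn.Linear(16, 8)                   # stand-in for the fused feature model
    criterion = nn.TripletMarginLoss(margin=0.2)   # 3-argument triplet criterion
    optimizer = torch.optim.Adam(toy_model.parameters(), lr=1e-4)

    # One synthetic anchor/positive/negative batch of 4 samples with 16 features each.
    anchor = toy_model(torch.randn(4, 16))
    positive = toy_model(torch.randn(4, 16))
    negative = toy_model(torch.randn(4, 16))

    loss = criterion(anchor, positive, negative)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(f"Demo standard triplet loss: {loss.item():.4f}")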
def validate_model(model, criterion, val_dataloader, device, batches_to_check=None):
    """
    Evaluates the model on the validation dataset for a specified number of batches.
    If batches_to_check is None, it runs over the entire val_dataloader.
    """
    model.eval()
    total_val_loss = 0.0
    num_batches = 0
    start_time = time.time()
    with torch.no_grad():
        for batch_idx, batch in enumerate(val_dataloader):
            if batches_to_check is not None and batch_idx >= batches_to_check:
                break
            batch = move_to_device(batch, device)

            # --- Safe extraction for Anchor (A) ---
            A_anchor = batch['anchor']
            A_clip = A_anchor.get('clip')
            A_seg = A_anchor.get('segformer')
            A_dpt = A_anchor.get('dpt')
            A_midas = A_anchor.get('midas')
            # --- Safe extraction for Positive (P) ---
            P_positive = batch['positive']
            P_clip = P_positive.get('clip')
            P_seg = P_positive.get('segformer')
            P_dpt = P_positive.get('dpt')
            P_midas = P_positive.get('midas')
            # --- Safe extraction for Negative (N) ---
            N_negative = batch['negative']
            N_clip = N_negative.get('clip')
            N_seg = N_negative.get('segformer')
            N_dpt = N_negative.get('dpt')
            N_midas = N_negative.get('midas')

            anchor_embed = model(A_clip, A_seg, A_dpt, A_midas)
            positive_embed = model(P_clip, P_seg, P_dpt, P_midas)
            negative_embed = model(N_clip, N_seg, N_dpt, N_midas)
            loss = criterion(anchor_embed, positive_embed, negative_embed)
            total_val_loss += loss.item()
            num_batches += 1

    validation_time = time.time() - start_time
    print(f"Validation took {validation_time:.2f} seconds.")
    avg_val_loss = total_val_loss / num_batches if num_batches > 0 else 0.0
    return avg_val_loss, num_batches

def get_latest_checkpoint_info(save_dir, base_filename):
    search_pattern = os.path.join(save_dir, f"*_{base_filename}")
    existing_files = glob.glob(search_pattern)
    max_num = 0
    latest_path = None
    pattern = re.compile(r'^(\d+)_')
    for file in existing_files:
        name = os.path.basename(file)
        match = pattern.match(name)
        if match:
            current_num = int(match.group(1))
            if current_num > max_num:
                max_num = current_num
                latest_path = file
    return max_num, latest_path
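
# Illustrative sketch (not part of the original code): the '<number>_<base_filename>'
# naming convention that get_latest_checkpoint_info scans for, demonstrated on a
# temporary directory with empty placeholder files.
def _demo_checkpoint_naming():
    import tempfile

    with tempfile.TemporaryDirectory() as tmp_dir:
        for num in (1, 2, 10):
            open(os.path.join(tmp_dir, f"{num}_fused_feature_model.pth"), "w").close()
        max_num, latest_path = get_latest_checkpoint_info(tmp_dir, "fused_feature_model.pth")
        # Expected: max_num == 10 and latest_path points at '10_fused_feature_model.pth'.
        print(max_num, os.path.basename(latest_path))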

class EarlyStopper:
    """
    Early stopping to stop training when the validation loss does not improve
    after a given patience.
    """
    def __init__(self, patience=5, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_loss = np.inf
        self.best_model_state = None

    def early_stop(self, validation_loss, model):
        """
        Returns True if early stopping criteria are met.
        Stores the best model state if the current loss is an improvement.
        """
        if validation_loss < self.min_validation_loss - self.min_delta:
            self.min_validation_loss = validation_loss
            print(f"New minimum validation loss: {self.min_validation_loss:.4f}. Saving best model state.")
            self.counter = 0
            # Deep-copy the state dict: state_dict() returns references to the live
            # parameters, which would otherwise keep changing as training continues.
            self.best_model_state = copy.deepcopy(model.state_dict())
        elif validation_loss > self.min_validation_loss + self.min_delta:
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False
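
# Illustrative sketch (not part of the original code): EarlyStopper on a synthetic
# sequence of validation losses with a tiny stand-in model. With patience=2 and
# min_delta=0.0, the two consecutive worsening losses after the minimum trigger the stop.
def _demo_early_stopper():
    import torch.nn as nn

    toy_model = nn.Linear(4, 2)
    stopper = EarlyStopper(patience=2, min_delta=0.0)
    for val_loss in [1.0, 0.8, 0.7, 0.9, 0.95]:
        if stopper.early_stop(val_loss, toy_model):
            print(f"Stopped at validation loss {val_loss}; best was {stopper.min_validation_loss}")
            break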

def batch_hard_mining(embeddings, labels, margin):
    """
    Implements BatchHard triplet mining: finds the hardest positive and hardest
    negative for every anchor in the batch and computes the triplet loss on them.
    """
    # All pairwise Euclidean distances, shape (B, B).
    pairwise_dist = torch.cdist(embeddings, embeddings, p=2.0)

    # labels_equal[i, j] is True if labels[i] == labels[j].
    labels_equal = labels.unsqueeze(0) == labels.unsqueeze(1)

    # --- 1. Hardest Positive (P_h): max distance among positives ---
    # Same label, but exclude the diagonal (an anchor is not its own positive).
    diagonal = torch.eye(labels.size(0), dtype=torch.bool, device=labels.device)
    mask_anchor_positive = labels_equal & ~diagonal
    # Zero out non-positive entries so they cannot win the max (distances are >= 0).
    dist_positive = pairwise_dist * mask_anchor_positive.float()
    # Hardest (largest) positive distance per anchor (row).
    dist_ap, _ = torch.max(dist_positive, dim=1)

    # --- 2. Hardest Negative (N_h): min distance among negatives ---
    mask_anchor_negative = labels_equal.logical_not()
    # Push positives and the diagonal up by the max distance so min() only picks negatives.
    max_dist = pairwise_dist.max()
    dist_negative = pairwise_dist + max_dist * (1 - mask_anchor_negative.float())
    dist_an, _ = torch.min(dist_negative, dim=1)

    # --- 3. Triplet loss on the mined triplets: max(0, D_ap - D_an + margin) ---
    loss_triplet = torch.relu(dist_ap - dist_an + margin)
    if loss_triplet.numel() == 0:
        return torch.tensor(0.0, device=embeddings.device, requires_grad=True)
    # Average over all anchors in the batch (anchors with zero loss contribute zero).
    return loss_triplet.mean()
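
# Illustrative sketch (not part of the original code): batch_hard_mining on a tiny
# synthetic batch of random embeddings. The batch must contain at least two samples
# per class so every anchor has a positive; the labels and margin here are arbitrary.
def _demo_batch_hard_mining():
    torch.manual_seed(0)
    embeddings = torch.randn(8, 32)                   # 8 samples, 32-dim embeddings
    labels = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3])   # two samples per class
    loss = batch_hard_mining(embeddings, labels, margin=0.2)
    print(f"Demo BatchHard loss: {loss.item():.4f}")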

def batch_semi_hard_mining(embeddings, labels, margin):
    """
    Implements Batch Semi-Hard triplet mining.
    For every anchor in the batch:
      1. Hardest Positive (P_h): the maximum distance among positives, d_ap.
      2. Semi-Hard Negative (N_sh): the negative closest to the anchor whose
         distance satisfies d_ap < d_an < d_ap + margin.
    If no negative satisfies the semi-hard condition for an anchor, fall back
    to the hardest negative (the BatchHard choice).
    """
    # All pairwise Euclidean distances, shape (B, B).
    pairwise_dist = torch.cdist(embeddings, embeddings, p=2.0)

    # labels_equal[i, j] is True if labels[i] == labels[j].
    labels_equal = labels.unsqueeze(0) == labels.unsqueeze(1)

    # --- 1. Hardest Positive (P_h) ---
    # Same label, but exclude the diagonal (an anchor is not its own positive).
    diagonal = torch.eye(labels.size(0), dtype=torch.bool, device=labels.device)
    mask_anchor_positive = labels_equal & ~diagonal
    # Zero out non-positive entries so they cannot win the max (distances are >= 0).
    dist_positive = pairwise_dist * mask_anchor_positive.float()
    dist_ap, _ = torch.max(dist_positive, dim=1)

    # --- 2. Semi-Hard Negative (N_sh) ---
    mask_anchor_negative = labels_equal.logical_not()
    # dist_ap is (B,); expand to (B, B) for element-wise comparison.
    dist_ap_expanded = dist_ap.unsqueeze(1)
    # Condition 1 (hardness): d(a,n) > d(a,p).
    mask_harder_than_positive = pairwise_dist > dist_ap_expanded
    # Condition 2 (semi-hard): d(a,n) < d(a,p) + margin.
    mask_within_margin = pairwise_dist < dist_ap_expanded + margin
    # A semi-hard negative must be a negative, farther than the hardest positive,
    # and still inside the margin.
    mask_semi_hard = mask_anchor_negative & mask_harder_than_positive & mask_within_margin

    # Mask out everything that is not semi-hard with a huge distance, then take the min per anchor.
    dist_semi_hard = pairwise_dist.clone()
    dist_semi_hard[mask_semi_hard.logical_not()] = 1e9
    dist_an_semi_hard, _ = torch.min(dist_semi_hard, dim=1)

    # Anchors whose minimum is still 1e9 found no semi-hard negative;
    # fall back to the hardest negative (the BatchHard negative) for those anchors.
    mask_no_semi_hard = dist_an_semi_hard == 1e9
    if mask_no_semi_hard.any():
        dist_all_negatives = pairwise_dist.clone()
        dist_all_negatives[mask_anchor_negative.logical_not()] = 1e9
        dist_an_hard, _ = torch.min(dist_all_negatives, dim=1)
        dist_an_semi_hard[mask_no_semi_hard] = dist_an_hard[mask_no_semi_hard]

    dist_an = dist_an_semi_hard  # Final negative distance to use.

    # --- 3. Triplet loss on the mined triplets: max(0, D_ap - D_an + margin) ---
    loss_triplet = torch.relu(dist_ap - dist_an + margin)
    # Guard against an empty batch or a batch with no valid positive pairs.
    if loss_triplet.numel() == 0 or dist_ap.sum() == 0:
        return torch.tensor(0.0, device=embeddings.device, requires_grad=True)
    # Average over all anchors; anchors that fell back to the hardest negative may
    # contribute zero loss (if d_an > d_ap + margin) but are still included.
    return loss_triplet.mean()
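
# Illustrative sketch (not part of the original code): batch_semi_hard_mining on the
# same kind of synthetic batch as the BatchHard demo, so the two mined losses can be
# compared side by side. The values are random; only the call pattern matters.
def _demo_batch_semi_hard_mining():
    torch.manual_seed(0)
    embeddings = torch.randn(8, 32)
    labels = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3])
    hard_loss = batch_hard_mining(embeddings, labels, margin=0.2)
    semi_hard_loss = batch_semi_hard_mining(embeddings, labels, margin=0.2)
    print(f"Demo BatchHard loss: {hard_loss.item():.4f} | Semi-Hard loss: {semi_hard_loss.item():.4f}")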

def validate_model_mining(model, criterion, val_dataloader, device):
    """
    Calculates validation loss using online mining over each validation batch.
    Args:
        model: The FusedFeatureModel.
        criterion: A loss taking (embeddings, labels), e.g. an online-mining triplet loss.
        val_dataloader: DataLoader using the MultiModalDataset.
        device: 'cuda' or 'cpu'.
    Returns:
        (float, int): Average validation loss and number of batches checked.
    """
    model.eval()
    total_val_loss = 0.0
    num_val_batches = 0
    with torch.no_grad():
        for batch in val_dataloader:
            # move_to_device must move nested dicts of tensors correctly.
            batch = move_to_device(batch, device)
            inputs = batch['anchor']
            labels = batch['y']  # True class labels, shape (batch_size,).

            # 1. Forward pass: compute embeddings for every sample in the batch.
            embeddings = model(**inputs)

            # 2. Loss: online mining over the batch, e.g. batch_hard_mining /
            #    batch_semi_hard_mining or an equivalent criterion accepting (embeddings, labels).
            loss = criterion(embeddings, labels)
            total_val_loss += loss.item()
            num_val_batches += 1

    if num_val_batches == 0:
        return 0.0, 0
    avg_val_loss = total_val_loss / num_val_batches
    return avg_val_loss, num_val_batches

# --- Revised training loop: online hard mining ---
def train_model_hard_mining(model, optimizer, criterion, train_dataloader, device, num_epochs=1,
                            batch_print_interval=5, val_dataloader=None, validation_interval=20,
                            patience=5, min_delta=0.05):
    early_stopper = EarlyStopper(patience=patience, min_delta=min_delta)
    running_loss = 0.0
    running_time = 0.0
    total_batches_processed = 0
    model.train()

    for epoch in range(num_epochs):
        start_time = time.time()
        for i, batch in enumerate(train_dataloader):
            batch_idx = i + 1
            total_batches_processed += 1
            batch = move_to_device(batch, device)
            inputs = batch['anchor']
            labels = batch['y']  # True class labels, shape (batch_size,).

            # --- 1. Compute embeddings for every sample in the batch ---
            embeddings = model(**inputs)  # Shape: (batch_size, embedding_dim)

            # --- 2. Online mining loss over the batch ---
            # e.g. batch_hard_mining / batch_semi_hard_mining, or an equivalent
            # criterion that accepts (embeddings, labels).
            loss = criterion(embeddings, labels)

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

            if batch_idx % batch_print_interval == 0:
                avg_loss = running_loss / batch_print_interval
                elapsed_time = time.time() - start_time
                print(f'Epoch [{epoch+1}/{num_epochs}], '
                      f'Batch [{batch_idx}/{len(train_dataloader)}], '
                      f'Online Hard Mining Loss: {avg_loss:.4f}, '
                      f'Time/{batch_print_interval} Batches: {elapsed_time:.2f}s')
                running_loss = 0.0
                running_time += elapsed_time
                start_time = time.time()

            if val_dataloader is not None and batch_idx % validation_interval == 0:
                val_loss, num_val_batches = validate_model_mining(
                    model, criterion, val_dataloader, device
                )
                print(f"[Validation @ Batch {batch_idx}] Checked {num_val_batches} Val Batches. Loss: {val_loss:.4f}\n")
                if early_stopper.early_stop(val_loss, model):
                    print("\n*** Early stopping triggered! ***")
                    if early_stopper.best_model_state is not None:
                        model.load_state_dict(early_stopper.best_model_state)
                        print("Restored best model weights.")
                    return total_batches_processed, early_stopper.min_validation_loss
                model.train()
                start_time = time.time()

            if batch_idx % validation_interval == 0 and torch.cuda.is_available():
                torch.cuda.empty_cache()

        print(f"--- Epoch {epoch+1} finished. ---")
        print(f"Total time for epoch: {running_time:.2f}s")

    if early_stopper.best_model_state is not None:
        model.load_state_dict(early_stopper.best_model_state)
        print("Training finished. Restored best model weights based on validation loss.")
    return total_batches_processed, early_stopper.min_validation_loss

def train_model_with_curriculum(model, optimizer, criterion, train_dataloader, device, num_epochs=1,
                                batch_print_interval=5, val_dataloader=None, validation_interval=20,
                                patience=5, min_delta=0.05):
    early_stopper = EarlyStopper(patience=patience, min_delta=min_delta)

    FINE_TUNE_LR_FACTOR = 0.1  # e.g., drop the LR by 10x for fine-tuning
    CLIP_LAYERS_TO_UNFREEZE = 1
    SEGFORMER_LAYERS_TO_UNFREEZE = 1
    DPT_LAYERS_TO_UNFREEZE = 0
    MIDAS_LAYERS_TO_UNFREEZE = 0
    # Other backbones can be added here, e.g. 'segformer': SEGFORMER_LAYERS_TO_UNFREEZE.
    BACKBONES_TO_UNFREEZE = {'clip': CLIP_LAYERS_TO_UNFREEZE}

    # Remember the current base LR so it can be scaled for fine-tuning.
    initial_lr = optimizer.param_groups[0]['lr']
    running_loss = 0.0
    total_batches_processed = 0
    model.train()

    for epoch in range(num_epochs):
        start_time = time.time()

        # Mining strategy for the current epoch:
        # Epoch 1 (index 0) uses standard pre-sampled triplets (random/semi-hard).
        # Epoch 2+ (index 1+) uses online batch hard mining.
        is_hard_mining_epoch = epoch >= 1

        if is_hard_mining_epoch and epoch == 1:
            print("\n--- Switching to Hard Mining Mode for Training Dataset and loading Best Model from Triplet Loss ---")
            train_dataloader.dataset.hard_mining_mode = True
            if early_stopper.best_model_state is not None:
                model.load_state_dict(early_stopper.best_model_state)

        if epoch == 2:
            print("\n--- PHASE 2: Starting Fine-Tuning (Epoch 3). Unfreezing last layers and dropping LR. ---")
            # 1. Unfreeze the last N layers of the selected backbones.
            for backbone_name, n_layers in BACKBONES_TO_UNFREEZE.items():
                # 'unfreeze_last_n_layers' is assumed to be implemented on the model.
                model.unfreeze_last_n_layers(backbone_name, n=n_layers)
            # 2. Drop the learning rate for stable fine-tuning.
            new_lr = initial_lr * FINE_TUNE_LR_FACTOR
            adjust_learning_rate(optimizer, new_lr)
            print(f"Learning Rate adjusted for fine-tuning: {initial_lr:.6f} -> {new_lr:.6f}")

        mining_strategy = "Hard Mining" if is_hard_mining_epoch else "Standard Triplet Loss"
        if epoch >= 2:
            mining_strategy += " + Fine-Tuning"
        print(f"\n--- Epoch {epoch+1}/{num_epochs} | Using {mining_strategy} ---")

        for i, batch in enumerate(train_dataloader):
            batch_idx = i + 1
            total_batches_processed += 1
            batch = move_to_device(batch, device)

            if not is_hard_mining_epoch:
                # --- STANDARD TRIPLET LOSS (Epoch 1) ---
                # Input structure is anchor/positive/negative dicts.
                A_anchor = batch['anchor']
                P_positive = batch['positive']
                N_negative = batch['negative']
                # Extract the multimodal inputs for each branch.
                A_clip, A_seg, A_dpt, A_midas = (A_anchor.get('clip'), A_anchor.get('segformer'), A_anchor.get('dpt'), A_anchor.get('midas'))
                P_clip, P_seg, P_dpt, P_midas = (P_positive.get('clip'), P_positive.get('segformer'), P_positive.get('dpt'), P_positive.get('midas'))
                N_clip, N_seg, N_dpt, N_midas = (N_negative.get('clip'), N_negative.get('segformer'), N_negative.get('dpt'), N_negative.get('midas'))

                anchor_embed = model(A_clip, A_seg, A_dpt, A_midas)
                positive_embed = model(P_clip, P_seg, P_dpt, P_midas)
                negative_embed = model(N_clip, N_seg, N_dpt, N_midas)
                loss = criterion(anchor_embed, positive_embed, negative_embed)
            else:
                # --- ONLINE BATCH HARD MINING (Epoch 2+) ---
                # Input structure is the 'anchor' inputs plus the 'y' class labels.
                inputs = batch['anchor']
                labels = batch['y']
                clip_inputs = inputs.get('clip')
                segformer_inputs = inputs.get('segformer')
                dpt_inputs = inputs.get('dpt')
                midas_inputs = inputs.get('midas')

                # Compute embeddings for every sample in the batch.
                embeddings = model(clip_inputs, segformer_inputs, dpt_inputs, midas_inputs)
                # Online mining loss over the batch (e.g. batch_hard_mining or an
                # equivalent criterion that accepts (embeddings, labels)).
                loss = criterion(embeddings, labels)

            # --- BACKPROPAGATION ---
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

            # --- PRINTING ---
            if batch_idx % batch_print_interval == 0:
                avg_loss = running_loss / batch_print_interval
                elapsed_time = time.time() - start_time
                print(f'Epoch [{epoch+1}/{num_epochs}], '
                      f'Batch [{batch_idx}/{len(train_dataloader)}], '
                      f'{mining_strategy} Loss: {avg_loss:.4f}, '
                      f'Time/{batch_print_interval} Batches: {elapsed_time:.2f}s')
                running_loss = 0.0
                start_time = time.time()

            # --- VALIDATION ---
            if val_dataloader is not None and batch_idx % validation_interval == 0:
                # Always use the mining-based validation to get a consistent
                # assessment of the embedding space's quality across phases.
                val_loss, num_val_batches = validate_model_mining(
                    model, criterion, val_dataloader, device
                )
                print(f"[Validation @ Batch {batch_idx}] Checked {num_val_batches} Val Batches. Loss: {val_loss:.4f}\n")
                # Call early_stop exactly once per validation so the patience counter is not double-incremented.
                if early_stopper.early_stop(val_loss, model):
                    print("\n*** Early stopping triggered! ***")
                    if early_stopper.best_model_state is not None:
                        model.load_state_dict(early_stopper.best_model_state)
                        print("Restored best model weights.")
                    return total_batches_processed, early_stopper.min_validation_loss
                model.train()
                start_time = time.time()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

        print(f"--- Epoch {epoch+1} finished. ---")

    if early_stopper.best_model_state is not None:
        model.load_state_dict(early_stopper.best_model_state)
        print("Training finished. Restored best model weights based on validation loss.")
    return total_batches_processed, early_stopper.min_validation_loss

def adjust_learning_rate(optimizer, new_lr):
    """
    Sets the learning rate for all parameter groups in the optimizer.
    Args:
        optimizer (torch.optim.Optimizer): The optimizer whose learning rate to adjust.
        new_lr (float): The new learning rate value.
    """
    for param_group in optimizer.param_groups:
        param_group['lr'] = new_lr
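
# Illustrative usage sketch (not part of the original module): how train_or_load_model
# might be wired up. 'FusedFeatureModel', 'build_dataloaders', and the chosen criterion
# are hypothetical placeholders for the project's real model, data pipeline, and loss;
# only the keyword arguments shown are defined by this module, so the sketch is left
# commented out rather than executed.
if __name__ == "__main__":
    # from training.models import FusedFeatureModel          # hypothetical import
    # from training.data import build_dataloaders            # hypothetical import
    #
    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # model = FusedFeatureModel().to(device)
    # train_dl, val_dl = build_dataloaders()
    # criterion = torch.nn.TripletMarginLoss(margin=0.2)      # for train_type="standard"
    # optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    #
    # model, prototypes, class_ids, n_batches, best_val = train_or_load_model(
    #     model,
    #     load_network=True,
    #     optimizer=optimizer,
    #     criterion=criterion,
    #     train_dataloader=train_dl,
    #     device=device,
    #     num_epochs=3,
    #     val_dataloader=val_dl,
    #     model_name="fused_feature_model.pth",
    #     train_type="standard",
    # )
    pass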