from collections import defaultdict

import torch
import torch.nn.functional as F

from training.utils import move_to_device

# ANSI escape codes for emphasized terminal output.
BOLD_GREEN = '\033[1;32m'
BOLD_RED = '\033[1;31m'
RESET = '\033[0m'


def generate_class_prototypes(model, dataloader, device, max_batches=None):
    """Compute an L2-normalized mean embedding ("prototype") per class.

    Args:
        model: embedding model, invoked as ``model(**inputs)``; expected to
            return a ``(batch, dim)`` tensor of embeddings.
        dataloader: yields dict batches with keys ``'anchor'`` (kwargs for
            the model) and ``'y'`` (integer class labels).
        device: torch device batches are moved to and prototypes returned on.
        max_batches: optional cap on the number of batches processed.

    Returns:
        tuple: ``(prototype_tensor, class_ids)`` where ``prototype_tensor``
        is a ``(num_classes, dim)`` tensor of unit-norm prototypes on
        ``device`` and ``class_ids`` is the sorted list of class ids whose
        order matches the tensor rows.
    """
    model.eval()
    class_embeddings = defaultdict(list)
    print("Generating class prototypes...")

    with torch.no_grad():
        for i, batch in enumerate(dataloader):
            # Check the cap BEFORE paying for the device transfer; the
            # original moved the batch first and then discarded it.
            if max_batches is not None and i >= max_batches:
                print(f"Stopping prototype generation after {i} batches.")
                break
            batch = move_to_device(batch, device)
            inputs = batch['anchor']
            labels = batch['y']
            embeddings = model(**inputs)
            # Group embeddings by their integer class id.
            for embed, label in zip(embeddings, labels):
                class_embeddings[label.item()].append(embed)
            if i % 100 == 0:
                # DataLoader.batch_size is None when a batch_sampler is used;
                # fall back to the observed batch length so the progress
                # print never crashes.
                batch_size = dataloader.batch_size or labels.size(0)
                print(f"Processed {i * batch_size} samples...")

    prototypes = {}
    for class_id, embed_list in class_embeddings.items():
        mean_prototype = torch.mean(torch.stack(embed_list), dim=0)
        # Normalize so that a plain dot product equals cosine similarity.
        # dim=0 on the 1-D mean replaces the unsqueeze/normalize/squeeze dance.
        prototypes[class_id] = F.normalize(mean_prototype, p=2, dim=0)

    print(f"Generated {len(prototypes)} class prototypes.")
    class_ids = sorted(prototypes.keys())
    prototype_tensor = torch.stack([prototypes[c] for c in class_ids]).to(device)
    return prototype_tensor, class_ids


def evaluate_ncm_accuracy_top3(model, dataloader, prototypes, class_ids,
                               id_to_tag, device):
    """Evaluate with a Nearest-Class-Mean classifier (Top-1 and Top-3).

    Assigns each test embedding to the prototype(s) with the highest
    similarity and reports overall and per-class accuracy summaries.

    Args:
        model: embedding model, invoked as ``model(**inputs)``.
        dataloader: yields dict batches with ``'anchor'`` and ``'y'`` keys.
        prototypes: ``(num_classes, dim)`` tensor of class prototypes
            (rows aligned with ``class_ids``).
        class_ids: list mapping prototype row index -> class id.
        id_to_tag: dict mapping class id -> human-readable tag for reports.
        device: torch device used for evaluation.

    Returns:
        tuple: ``(total_top1_accuracy, total_top3_accuracy)`` as floats.
    """
    model.eval()
    class_top1_correct_predictions = defaultdict(int)
    class_top3_correct_predictions = defaultdict(int)
    class_total_samples = defaultdict(int)
    total_top1_correct_predictions = 0
    total_top3_correct_predictions = 0
    total_samples = 0

    prototypes = prototypes.to(device)
    # Hoisted lookup table: prototype row index -> class id, so per-batch
    # label mapping is a single tensor gather instead of a Python loop.
    id_lookup = torch.tensor(class_ids, device=device)

    print("Starting evaluation for Top-1 and Top-3 accuracy...")
    with torch.no_grad():
        for batch in dataloader:
            batch = move_to_device(batch, device)
            true_labels = batch['y']
            test_embeddings = model(**batch['anchor'])

            # NOTE(review): treats the dot product as the similarity score —
            # this equals cosine similarity only if the model's embeddings
            # are L2-normalized; confirm against the model definition.
            similarity_matrix = test_embeddings @ prototypes.T

            # Cap K at the number of classes so torch.topk cannot raise when
            # fewer than 3 prototypes exist.
            k = min(3, similarity_matrix.size(1))
            topk_indices = torch.topk(similarity_matrix, k, dim=1).indices
            topk_labels = id_lookup[topk_indices]          # (batch, k)

            # Top-1 is the first Top-K column (topk is sorted descending).
            is_correct_top1 = topk_labels[:, 0] == true_labels
            is_correct_top3 = (topk_labels == true_labels.unsqueeze(1)).any(dim=1)

            total_samples += true_labels.size(0)
            total_top1_correct_predictions += is_correct_top1.sum().item()
            total_top3_correct_predictions += is_correct_top3.sum().item()

            for true_label, corr1, corr3 in zip(true_labels.tolist(),
                                                is_correct_top1.tolist(),
                                                is_correct_top3.tolist()):
                class_total_samples[true_label] += 1
                if corr1:
                    class_top1_correct_predictions[true_label] += 1
                if corr3:
                    class_top3_correct_predictions[true_label] += 1

    total_top1_accuracy = (total_top1_correct_predictions / total_samples
                           if total_samples > 0 else 0.0)
    total_top3_accuracy = (total_top3_correct_predictions / total_samples
                           if total_samples > 0 else 0.0)

    print("\n--- Summary of NCM Evaluation ---")
    print(f"Total NCM **Top-1** Accuracy: {total_top1_accuracy:.4f} "
          f"({total_top1_correct_predictions}/{total_samples} samples)")
    print(f"Total NCM **Top-3** Accuracy: {total_top3_accuracy:.4f} "
          f"({total_top3_correct_predictions}/{total_samples} samples)")
    print("-----------------------------------")

    all_class_ids = sorted(class_total_samples.keys())
    # Guard against an empty evaluation set (no classes observed).
    num_classes = len(all_class_ids)

    top1_count_above_50 = 0
    for class_id in all_class_ids:
        total = class_total_samples[class_id]
        accuracy_top1 = (class_top1_correct_predictions[class_id] / total
                         if total > 0 else 0.0)
        if accuracy_top1 >= 0.50:
            top1_count_above_50 += 1

    top1_ratio = top1_count_above_50 / num_classes if num_classes > 0 else 0.0
    print("-----------------------------------")
    print(f"Total number of classes with a **Top-1** accuracy of 50% or more: "
          f"{BOLD_GREEN}{top1_ratio:.4f} "
          f"({top1_count_above_50}/{num_classes}){RESET}")
    print("-----------------------------------")

    print("\n--- Class Accuracy Analysis (Top-3 Threshold: 50%) ---")
    count_above_50 = 0
    for class_id in all_class_ids:
        total = class_total_samples[class_id]
        accuracy_top3 = (class_top3_correct_predictions[class_id] / total
                         if total > 0 else 0.0)
        if accuracy_top3 >= 0.50:
            count_above_50 += 1

    top3_ratio = count_above_50 / num_classes if num_classes > 0 else 0.0
    print("-----------------------------------")
    print(f"Total number of classes with a **Top-3** accuracy of 50% or more: "
          f"{BOLD_GREEN} {top3_ratio:.4f} "
          f"({count_above_50}/{num_classes}) {RESET}")
    print("-----------------------------------")
    print("--- End of Evaluation ---")

    return total_top1_accuracy, total_top3_accuracy