#!/usr/bin/env python
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Inspect MagpieTTS Checkpoint

A diagnostic script to check the contents of a MagpieTTS checkpoint:
- Whether it has context_encoder weights
- Whether it has baked context embeddings
- Shape of baked embeddings if present

Usage:
    python scripts/magpietts/inspect_checkpoint.py --checkpoint /path/to/checkpoint.ckpt
"""
from __future__ import annotations

import argparse
import os

import torch

# Horizontal rule printed above and below every section title.
_RULE = '=' * 60
# Reported sizes assume 4-byte (float32) elements, as the printed labels state.
_FLOAT32_BYTES = 4
# Number of context_encoder keys listed before the output is truncated.
_MAX_SAMPLE_KEYS = 5


def _print_section(title: str) -> None:
    """Print a section title framed by two horizontal rules."""
    print(f"\n{_RULE}")
    print(title)
    print(_RULE)


def _size_mb(num_elements: int) -> float:
    """Return the size in MiB of ``num_elements`` values stored as float32."""
    return num_elements * _FLOAT32_BYTES / 1024 / 1024


def _report_context_encoder(state_dict: dict, context_encoder_keys: list[str]) -> None:
    """Print the count, approximate size and a sample of context_encoder entries."""
    _print_section("CONTEXT ENCODER WEIGHTS")
    if not context_encoder_keys:
        print("✗ No context_encoder weights found")
        return

    print(f"✓ Found {len(context_encoder_keys)} context_encoder parameters")
    total_params = sum(state_dict[k].numel() for k in context_encoder_keys)
    print(f" Total parameters: {total_params:,}")
    print(f" Approximate size: {_size_mb(total_params):.2f} MB (float32)")

    print("\n Sample keys:")
    for key in context_encoder_keys[:_MAX_SAMPLE_KEYS]:
        print(f" - {key}: {state_dict[key].shape}")
    remaining = len(context_encoder_keys) - _MAX_SAMPLE_KEYS
    if remaining > 0:
        print(f" ... and {remaining} more")


def _report_baked_embedding(state_dict: dict) -> bool:
    """Print details of the baked context embedding entries.

    Returns:
        True only when ``baked_context_embedding`` is present, is not None and
        holds at least one element; False otherwise.
    """
    _print_section("BAKED CONTEXT EMBEDDING")

    has_usable_embedding = False
    if 'baked_context_embedding' in state_dict:
        embedding = state_dict['baked_context_embedding']
        has_usable_embedding = embedding is not None and embedding.numel() > 0
        if has_usable_embedding:
            print("✓ Found baked_context_embedding")
            print(f" Shape: {embedding.shape}")
            print(f" Dtype: {embedding.dtype}")
            print(f" Parameters: {embedding.numel():,}")
            print(f" Size: {_size_mb(embedding.numel()):.4f} MB (float32)")
        else:
            print("✗ baked_context_embedding key exists but is None or empty")
    else:
        print("✗ No baked_context_embedding found")

    if 'baked_context_embedding_len' in state_dict:
        embedding_len = state_dict['baked_context_embedding_len']
        if embedding_len is not None:
            print(f"✓ Found baked_context_embedding_len: {embedding_len.item()}")
        else:
            print("✗ baked_context_embedding_len key exists but is None")
    else:
        print("✗ No baked_context_embedding_len found")

    return has_usable_embedding


def _print_summary(has_context_encoder: bool, has_usable_embedding: bool) -> None:
    """Print the one-paragraph classification of the checkpoint.

    The classification is driven by whether the baked embedding is *usable*,
    not merely whether its key exists: a checkpoint with context_encoder
    weights and an empty/None baked embedding is a standard checkpoint.
    """
    _print_section("SUMMARY")
    if has_usable_embedding:
        if has_context_encoder:
            print("→ This checkpoint has BOTH context_encoder AND baked embedding")
            print(" This is unusual - consider removing context_encoder weights")
        else:
            print("→ This is a BAKED checkpoint")
            print(" Will always use the baked voice, ignoring input context audio")
    elif has_context_encoder:
        print("→ This is a STANDARD checkpoint with context_encoder")
        print(" Can be used for any voice cloning with dynamic context audio")
    else:
        print("→ This checkpoint has NEITHER context_encoder NOR baked embedding")
        print(" This may indicate an issue or a different model type")


def inspect_checkpoint(checkpoint_path: str) -> None:
    """Inspect a MagpieTTS checkpoint for context_encoder and baked embeddings.

    Loads the checkpoint on CPU and prints three reports to stdout: the
    context_encoder weights, the baked context embedding entries, and a
    summary classifying the checkpoint.

    Args:
        checkpoint_path: Path to the checkpoint file (.ckpt).

    Raises:
        FileNotFoundError: If ``checkpoint_path`` does not exist.
    """
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    print(f"Loading checkpoint: {checkpoint_path}")
    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # Python objects, so only inspect checkpoints from trusted sources.
    ckpt = torch.load(checkpoint_path, weights_only=False, map_location='cpu')

    # The weights are either nested under 'state_dict' or the file is the
    # state dict itself.
    if 'state_dict' in ckpt:
        state_dict = ckpt['state_dict']
        print("Found 'state_dict' key in checkpoint")
    else:
        state_dict = ckpt
        print("Checkpoint is a raw state_dict (no 'state_dict' wrapper)")

    print(f"\nTotal keys in state_dict: {len(state_dict)}")

    context_encoder_keys = [k for k in state_dict if 'context_encoder' in k]
    _report_context_encoder(state_dict, context_encoder_keys)
    has_usable_embedding = _report_baked_embedding(state_dict)
    _print_summary(bool(context_encoder_keys), has_usable_embedding)


def main():
    """Parse the command line and inspect the requested checkpoint."""
    parser = argparse.ArgumentParser(
        description="Inspect MagpieTTS checkpoint for context_encoder and baked embeddings",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    parser.add_argument(
        '--checkpoint',
        type=str,
        required=True,
        help='Path to the checkpoint file (.ckpt)',
    )
    args = parser.parse_args()

    inspect_checkpoint(args.checkpoint)


if __name__ == '__main__':
    main()