#!/usr/bin/env python
# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Bake Context Embedding into MagpieTTS Checkpoint

This script converts a MagpieTTS decoder_ce checkpoint by:
1. Loading a reference audio file
2. Running it through the context_encoder to get the embedding
3. Saving a new checkpoint with:
   - The baked context embedding as a buffer
   - All original weights EXCEPT context_encoder weights
4. Saving a modified config without context_encoder settings

Usage:
    python scripts/magpietts/bake_context_embedding.py \
        --input_checkpoint /path/to/original.ckpt \
        --config_path /path/to/config.yaml \
        --output_checkpoint /path/to/baked.ckpt \
        --context_audio /path/to/reference.wav

The resulting checkpoint will be smaller (no context_encoder weights) and will always use the
baked reference audio embedding for voice cloning, regardless of what context audio is provided
at inference time. A modified config file will be saved alongside the output checkpoint with
context_encoder settings removed.
"""

from __future__ import annotations

import argparse
import os
from copy import deepcopy
from typing import Tuple

import soundfile as sf
import torch
from omegaconf import OmegaConf, open_dict

from examples.tts.magpietts.utils import update_config_for_inference
from nemo.collections.tts.models import MagpieTTSModel
from nemo.collections.tts.parts.utils.helpers import get_mask_from_lengths
from nemo.utils import logging


def load_audio(audio_path: str, target_sample_rate: int = 24000) -> Tuple[torch.Tensor, torch.Tensor]:
    """Load audio file and return tensor with length.

    Args:
        audio_path: Path to audio file.
        target_sample_rate: Expected sample rate. Audio will be resampled if needed.

    Returns:
        Tuple of (audio_tensor, audio_lens) with shapes (1, T) and (1,).

    Raises:
        ValueError: If audio file cannot be loaded or has wrong sample rate.
    """
    audio, sr = sf.read(audio_path, dtype='float32')

    # Downmix multi-channel audio to mono so the returned tensor has shape (1, T).
    if audio.ndim > 1:
        audio = audio.mean(axis=1)

    if sr != target_sample_rate:
        try:
            import librosa

            audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sample_rate)
            logging.info(f"Resampled audio from {sr}Hz to {target_sample_rate}Hz")
        except ImportError:
            raise ValueError(
                f"Audio sample rate {sr} does not match target {target_sample_rate}. "
                "Install librosa for automatic resampling: pip install librosa"
            )

    # Convert to tensor: (T,) -> (1, T)
    audio_tensor = torch.tensor(audio).unsqueeze(0)
    audio_lens = torch.tensor([audio_tensor.shape[1]])
    return audio_tensor, audio_lens


def bake_model_context_embedding(
    model: MagpieTTSModel,
    context_audio: torch.Tensor,
    context_audio_lens: torch.Tensor,
) -> None:
    """Compute and store the context embedding from reference audio into the model.

    This function runs the context audio through the model's context_encoder and stores the
    resulting embedding as a buffer. After baking, the context_encoder weights can be removed
    from the checkpoint to reduce model size.

    Only supported for decoder_ce model type.

    Args:
        model: MagpieTTSModel instance with context_encoder.
        context_audio: Reference audio waveform. Shape: (1, T_samples).
        context_audio_lens: Length of audio in samples. Shape: (1,).

    Raises:
        ValueError: If model type is not decoder_ce.
        RuntimeError: If context_encoder is not available.
    """
    if model.model_type != 'decoder_ce':
        raise ValueError(
            f"Baking context embedding is only supported for decoder_ce model type, got {model.model_type}"
        )

    if not hasattr(model, 'context_encoder'):
        raise RuntimeError("context_encoder not found. Cannot bake embedding.")

    with torch.no_grad():
        # Convert audio to codec tokens
        context_audio_codes, context_audio_codes_lens = model.audio_to_codes(
            context_audio, context_audio_lens, audio_type='context'
        )
        context_audio_codes = model.pad_audio_codes(
            context_audio_codes, model.frame_stacking_factor, pad_token=0
        )
        context_audio_embedded = model.embed_audio_tokens(context_audio_codes)

        # Compute context length after frame stacking
        context_input_lens = torch.ceil(
            context_audio_codes_lens / model.frame_stacking_factor
        ).to(context_audio_codes_lens.dtype)
        context_mask = get_mask_from_lengths(context_input_lens)

        # Run through context encoder
        context_embedding = model.context_encoder(
            context_audio_embedded, context_mask, cond=None, cond_mask=None
        )['output']

    # Store as buffers (squeeze batch dim since we store a single embedding)
    model.baked_context_embedding = context_embedding.squeeze(0)  # (T, E)
    model.baked_context_embedding_len = context_input_lens.squeeze(0)  # scalar

    logging.info(
        f"Baked context embedding with shape {model.baked_context_embedding.shape}, "
        f"length {model.baked_context_embedding_len.item()}"
    )


def bake_context_embedding(
    input_checkpoint: str,
    config_path: str,
    output_checkpoint: str,
    context_audio: str,
    device: str = 'cuda',
) -> None:
    """Bake context embedding into checkpoint.

    Args:
        input_checkpoint: Path to original MagpieTTS checkpoint (.ckpt).
        config_path: Path to model config file (.yaml).
        output_checkpoint: Path to save the baked checkpoint.
        context_audio: Path to reference audio file for baking.
        device: Device to run inference on ('cuda' or 'cpu').

    Raises:
        ValueError: If model type is not decoder_ce.
        FileNotFoundError: If input files don't exist.
    """
    # Validate inputs
    if not os.path.exists(input_checkpoint):
        raise FileNotFoundError(f"Input checkpoint not found: {input_checkpoint}")
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Config file not found: {config_path}")
    if not os.path.exists(context_audio):
        raise FileNotFoundError(f"Context audio not found: {context_audio}")

    logging.info(f"Loading model from {input_checkpoint}")
    logging.info(f"Using config from {config_path}")

    # Load config
    cfg = OmegaConf.load(config_path)
    if "cfg" in cfg:
        cfg = cfg.cfg
    logging.debug(f"Loaded config:\n{OmegaConf.to_yaml(cfg)}")

    # NOTE: the codec model checkpoint path is currently hardcoded.
    with open_dict(cfg):
        cfg, cfg_sample_rate = update_config_for_inference(
            cfg,
            "/nemo_codec_checkpoints/21fps_causal_codecmodel.nemo",
            False,
            False,
        )

    # Load model
    model = MagpieTTSModel(cfg)
    ckpt = torch.load(input_checkpoint, weights_only=False, map_location=device)
    state_dict = ckpt.get('state_dict', ckpt)
    model.load_state_dict(state_dict, strict=False)
    model = model.to(device)
    model.eval()

    # Validate model type
    if model.model_type != 'decoder_ce':
        raise ValueError(
            f"Baking context embedding is only supported for decoder_ce model type, "
            f"got {model.model_type}"
        )

    # Check that context_encoder exists
    if not hasattr(model, 'context_encoder'):
        raise RuntimeError(
            "Model does not have context_encoder. It may already have a baked embedding."
        )

    # Load reference audio
    logging.info(f"Loading reference audio from {context_audio}")
    sample_rate = model.sample_rate
    audio_tensor, audio_lens = load_audio(context_audio, target_sample_rate=sample_rate)
    audio_tensor = audio_tensor.to(device)
    audio_lens = audio_lens.to(device)
    logging.info(f"Reference audio duration: {audio_lens[0].item() / sample_rate:.2f}s")

    # Bake the embedding
    logging.info("Computing context embedding...")
    bake_model_context_embedding(model, audio_tensor, audio_lens)

    # Verify baking worked
    if not model.has_baked_context_embedding:
        raise RuntimeError("Failed to bake context embedding")
    logging.info(
        f"Baked embedding shape: {model.baked_context_embedding.shape}, "
        f"length: {model.baked_context_embedding_len.item()}"
    )

    # Save the model - state_dict will automatically exclude context_encoder
    logging.info(f"Saving baked checkpoint to {output_checkpoint}")

    # Get state dict (will exclude context_encoder due to has_baked_context_embedding)
    state_dict = model.state_dict()

    # Explicitly remove any remaining context_encoder keys
    context_encoder_keys = [k for k in state_dict.keys() if 'context_encoder' in k]
    for key in context_encoder_keys:
        del state_dict[key]

    # Count excluded keys for reporting
    original_ckpt = torch.load(input_checkpoint, weights_only=False, map_location='cpu')
    original_state_dict = original_ckpt.get('state_dict', original_ckpt)
    excluded_keys = [k for k in original_state_dict.keys() if 'context_encoder' in k]
    logging.info(f"Removed {len(excluded_keys)} context_encoder parameters")

    # Calculate size reduction
    if excluded_keys:
        original_size = sum(original_state_dict[k].numel() for k in excluded_keys)
        logging.info(f"Approximate size reduction: {original_size * 4 / 1024 / 1024:.1f} MB (float32)")

    # Create modified config without context_encoder
    logging.info("Creating modified config without context_encoder...")
    modified_cfg = deepcopy(cfg)
    with open_dict(modified_cfg):
        if 'model' in modified_cfg and 'context_encoder' in modified_cfg.model:
            del modified_cfg.model.context_encoder
            logging.info("Removed 'context_encoder' from config")
        # Add flag to indicate this checkpoint has baked embedding
        if 'model' in modified_cfg:
            modified_cfg.model.has_baked_context_embedding = True

    # Save checkpoint
    output_dir = os.path.dirname(output_checkpoint)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    torch.save({'state_dict': state_dict}, output_checkpoint)
    logging.info(f"Saved baked checkpoint to {output_checkpoint}")

    # Save modified config
    output_config_path = output_checkpoint.replace('.ckpt', '_config.yaml')
    if output_config_path == output_checkpoint:
        output_config_path = output_checkpoint + '_config.yaml'
    OmegaConf.save(modified_cfg, output_config_path)
    logging.info(f"Saved modified config to {output_config_path}")

    # Verify the saved checkpoint
    logging.info("Verifying saved checkpoint...")
    loaded_state = torch.load(output_checkpoint, weights_only=False, map_location='cpu')['state_dict']
    assert 'baked_context_embedding' in loaded_state, "baked_context_embedding not in saved checkpoint"
    assert 'baked_context_embedding_len' in loaded_state, "baked_context_embedding_len not in saved checkpoint"
    assert not any(
        'context_encoder' in k for k in loaded_state.keys()
    ), "context_encoder keys should not be in saved checkpoint"

    # Verify the saved config
    logging.info("Verifying saved config...")
    loaded_cfg = OmegaConf.load(output_config_path)
    assert 'context_encoder' not in loaded_cfg.get('model', {}), "context_encoder should not be in saved config"
    logging.info("Verification successful!")
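

# The helper below is a minimal illustrative sketch (it is not invoked by this script) of how a
# baked checkpoint and the modified config saved above could be loaded back for inference. It only
# reuses calls already made elsewhere in this script; the name `load_baked_model` is hypothetical,
# and it assumes MagpieTTSModel accepts the saved config directly, mirroring the construction above.
def load_baked_model(checkpoint_path: str, config_path: str, device: str = 'cpu') -> MagpieTTSModel:
    """Load a baked checkpoint together with its saved config (illustrative sketch)."""
    cfg = OmegaConf.load(config_path)
    model = MagpieTTSModel(cfg)
    state_dict = torch.load(checkpoint_path, weights_only=False, map_location=device)['state_dict']
    # strict=False because the baked checkpoint intentionally omits context_encoder weights.
    model.load_state_dict(state_dict, strict=False)
    model = model.to(device)
    model.eval()
    return model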


def main():
    parser = argparse.ArgumentParser(
        description="Bake context embedding into MagpieTTS checkpoint",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=__doc__,
    )
    parser.add_argument(
        '--input_checkpoint',
        type=str,
        required=True,
        help='Path to original MagpieTTS checkpoint (.ckpt)',
    )
    parser.add_argument(
        '--config_path',
        type=str,
        required=True,
        help='Path to model config file (.yaml)',
    )
    parser.add_argument(
        '--output_checkpoint',
        type=str,
        required=True,
        help='Path to save the baked checkpoint',
    )
    parser.add_argument(
        '--context_audio',
        type=str,
        required=True,
        help='Path to reference audio file for baking',
    )
    parser.add_argument(
        '--device',
        type=str,
        default='cuda',
        choices=['cuda', 'cpu'],
        help='Device to run inference on (default: cuda)',
    )
    args = parser.parse_args()

    bake_context_embedding(
        input_checkpoint=args.input_checkpoint,
        config_path=args.config_path,
        output_checkpoint=args.output_checkpoint,
        context_audio=args.context_audio,
        device=args.device,
    )


if __name__ == '__main__':
    main()