# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import json import os import random import time from dataclasses import dataclass from functools import partial from typing import Any, Dict, List, Optional, Union import numpy as np import soundfile as sf import torch import wandb from hydra.utils import instantiate from lightning.pytorch import Trainer from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger from omegaconf import DictConfig, OmegaConf, open_dict from torch import nn from torch.utils.data import get_worker_info from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config from nemo.collections.tts.data.text_to_speech_dataset_lhotse import MagpieTTSLhotseDataset, setup_tokenizers from nemo.collections.tts.losses.aligner_loss import ForwardSumLoss from nemo.collections.tts.models import AudioCodecModel from nemo.collections.tts.modules import transformer_2501 from nemo.collections.tts.modules.aligner import AlignmentEncoder from nemo.collections.tts.modules.audio_codec_modules import VectorQuantizerIndexConverter from nemo.collections.tts.modules.magpietts_modules import ( CharAwareSubwordEncoder, EOSDetectionMethod, LocalTransformerType, SpecialAudioToken, cosine_schedule, ) from nemo.collections.tts.parts.utils.helpers import ( binarize_attention_parallel, get_mask_from_lengths, plot_alignment_to_numpy, ) from nemo.core.classes import ModelPT from nemo.core.classes.common import PretrainedModelInfo from nemo.utils import logging @dataclass class InferBatchOutput: """Output dataclass for MagpieTTS infer_batch method. This provides a consistent return type regardless of which optional outputs are requested. Attributes: predicted_audio: Generated audio waveforms. Shape: (B, T_audio). predicted_audio_lens: Length of each audio in samples. Shape: (B,). predicted_codes: Generated audio codec tokens. Shape: (B, num_codebooks, T_frames). predicted_codes_lens: Length of each code sequence in frames. Shape: (B,). rtf_metrics: Dictionary containing real-time factor and timing metrics. cross_attention_maps: Optional cross-attention visualization maps. List of numpy arrays, one per batch item. Only populated if return_cross_attn_probs=True. headwise_cross_attention_maps: Optional per-head cross-attention maps. Only populated if return_cross_attn_probs=True and compute_all_heads_attn_maps=True. 
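    Example:
        A minimal sketch of reading the fields (assumes ``out`` was returned by
        ``MagpieTTSModel.infer_batch``; the call itself is elided here)::

            audio_0 = out.predicted_audio[0, : out.predicted_audio_lens[0]]
            codes_0 = out.predicted_codes[0, :, : out.predicted_codes_lens[0]]
            print(out.rtf_metrics)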
""" predicted_audio: torch.Tensor predicted_audio_lens: torch.Tensor predicted_codes: torch.Tensor predicted_codes_lens: torch.Tensor rtf_metrics: Dict[str, Any] cross_attention_maps: Optional[List[Any]] = None headwise_cross_attention_maps: Optional[List[Any]] = None def worker_init_fn(worker_id): # For mp.set_start_method("spawn", force=True) # The dataset class should be picklable, so we initialize non-picklable objects here logging.info(f"Worker {worker_id} initializing...") worker_info = get_worker_info() dataset = worker_info.dataset # Get the dataset instance in this worker tokenizer = setup_tokenizers(dataset.tokenizer_config, mode=dataset.dataset_type) dataset.text_tokenizer = tokenizer class MagpieTTSModel(ModelPT): """ Magpie-TTS Model Base Class used for training a TTS model that can generate audio codes from transcript and a context audio/text Supports multiple model types: - multi_encoder_context_tts: Transcript and context audio go to different encoders. Transcript encoding feeds to layers given by cfg.model.transcript_decoder_layers and the context encoding feeds into the layers given by context_decoder_layers .Also supports text context which gets encoded by the same encoder as context audio. Only one of context audio or contex text is supported. - decoder_context_tts: Text goes into the encoder; context & target audio go to the decoder. Also supports text context. Supports fixed sized context so we set context_duration_min and context_duration_max to the same value (5 seconds). Text context, which is usually shorter than number of codec frames of 5 second of audio, is padded to the max context duration in this model. - decoder_ce: Same as decoder_context_tts except there is a small neural network between the context tensors and the decoder input. """ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): self.world_size = 1 if trainer is not None: self.world_size = trainer.num_nodes * trainer.num_devices # load codec, disable loading of loss modules not needed during inference codec_model_path = cfg.get('codecmodel_path') if codec_model_path.startswith('nvidia/'): codec_model = AudioCodecModel.from_pretrained(codec_model_path) else: codec_model_cfg = AudioCodecModel.restore_from(codec_model_path, return_config=True) if "use_scl_loss" in codec_model_cfg: codec_model_cfg.use_scl_loss = False codec_model = AudioCodecModel.restore_from( codec_model_path, strict=False, override_config_path=codec_model_cfg ) self.sample_rate = codec_model.sample_rate self.codec_model_samples_per_frame = codec_model.samples_per_frame # del codec discriminator to free memory del codec_model.discriminator # When using FSQ tokens, the codebook structure can be changed at any time. # An FSQ definition can be provided in `vector_quantizer` config to train with a codebook structure # that is different than in the audio codec checkpoint. 
vector_quantizer = cfg.get('vector_quantizer') if vector_quantizer is not None: vector_quantizer = instantiate(vector_quantizer) num_audio_codebooks = vector_quantizer.num_codebooks codebook_size = vector_quantizer.codebook_size codec_converter = VectorQuantizerIndexConverter( vector_quantizer_original=codec_model.vector_quantizer, vector_quantizer_new=vector_quantizer, ) data_num_audio_codebooks = codec_model.vector_quantizer.num_codebooks else: num_audio_codebooks = codec_model.num_codebooks data_num_audio_codebooks = num_audio_codebooks codebook_size = codec_model.codebook_size codec_converter = None # The dataloader needs to know the number of codebooks that the context codes were stored in # In the case where there are no context codes saved, and there is no context audio (in the text context path), # We create a dummy context code tensor that is only [context_BOS, context_EOS] that is repeated for # data_num_audio_codebooks self.data_num_audio_codebooks = data_num_audio_codebooks self.num_audio_codebooks = num_audio_codebooks self.codebook_size = codebook_size # Our codebooks start with actual audio codec tokens, followed by special tokens. # The `forced_*` options are for backward compatibility for models trained with older code. get_token_index = partial(SpecialAudioToken.get_index, base_codebook_size=self.codebook_size) self.audio_bos_id = cfg.get('forced_audio_bos_id', get_token_index(SpecialAudioToken.AUDIO_BOS)) self.audio_eos_id = cfg.get('forced_audio_eos_id', get_token_index(SpecialAudioToken.AUDIO_EOS)) self.context_audio_bos_id = cfg.get( 'forced_context_audio_bos_id', get_token_index(SpecialAudioToken.AUDIO_CONTEXT_BOS) ) self.context_audio_eos_id = cfg.get( 'forced_context_audio_eos_id', get_token_index(SpecialAudioToken.AUDIO_CONTEXT_EOS) ) self.mask_token_id = cfg.get('forced_mask_token_id', get_token_index(SpecialAudioToken.MASK_TOKEN)) self.num_all_tokens_per_codebook = cfg.get( 'forced_num_all_tokens_per_codebook', self.codebook_size + len(SpecialAudioToken) ) self.use_bpe_char_tokenizer = cfg.get('use_bpe_char_tokenizer', False) # The frame stacking factor controls how many consecutive frames are processed together by the base decoder # (and then refined into individual frames by the local transformer). A frame stacking factor of 1 means no # frame stacking. We have a separate embedding table for each of the stacked frames, e.g. for frame stacking # factor of 3, the entries of codebook 0 appear 3 times in the embedding table. self.frame_stacking_factor = cfg.get('frame_stacking_factor', 1) assert 'downsample_factor' not in cfg, '`downsample_factor` is deprecated, use `frame_stacking_factor` instead' # Setup tokenizer if hasattr(cfg, 'text_tokenizer'): # For backward compatibility for English-only models with open_dict(cfg): cfg.text_tokenizers = {"english_phoneme": cfg.text_tokenizer} del cfg['text_tokenizer'] self.use_text_conditioning_encoder = cfg.get('use_text_conditioning_encoder', False) # Using google-t5/t5-small as default text conditioning tokenizer for backward compatibility. 
        self.text_conditioning_tokenizer_name = cfg.get('text_conditioning_tokenizer_name', None)
        self.legacy_text_conditioning = cfg.get('legacy_text_conditioning', False)
        if self.legacy_text_conditioning:
            if self.text_conditioning_tokenizer_name is None:
                self.text_conditioning_tokenizer_name = "google-t5/t5-small"
            tokenizer_target = "AutoTokenizer"
            if self.text_conditioning_tokenizer_name == "google-t5/t5-small":
                tokenizer_target = "T5Tokenizer"
            with open_dict(cfg):
                cfg.text_tokenizers[self.text_conditioning_tokenizer_name] = {
                    '_target_': tokenizer_target,
                    'pretrained_model': self.text_conditioning_tokenizer_name,
                }
        elif self.text_conditioning_tokenizer_name is None:
            # If no text_conditioning_tokenizer_name is specified, use the first one as the default
            # for text context tokenization
            self.text_conditioning_tokenizer_name = list(cfg.text_tokenizers.keys())[0]

        # TODO @xueyang: both tokenizers are only used to get some token ids. We
        # should kill them to save a small amount of mem resources since dataloader will initialize them
        # again after the worker processes are spawned.
        self.tokenizer = setup_tokenizers(
            all_tokenizers_config=cfg.text_tokenizers,
            mode='train',
        )

        num_tokens_tokenizer = len(self.tokenizer.tokens)
        if self.legacy_text_conditioning:
            # Text context tokens are not a part of the regular transcript embedding table in legacy models
            num_tokens_tokenizer -= self.tokenizer.num_tokens_per_tokenizer[self.text_conditioning_tokenizer_name]
        num_tokens = num_tokens_tokenizer + 2  # +2 for BOS and EOS
        self.bos_id = num_tokens - 2
        self.eos_id = num_tokens - 1
        self.model_type = cfg.get('model_type', None)
        self.pad_context_text_to_max_duration = self.model_type in ['decoder_context_tts', 'decoder_ce']
        self.use_kv_cache_for_inference = cfg.get('use_kv_cache_for_inference', False)

        # Below args (text_context_remapping_json, text_context_remapping_prob) are
        # for combining multiple context_texts into a single one during training. E.g.,
if we want to treat Emma_neutral and Emma_conversational as one speaker, # we can create an override dict {'Emma_neutral' : 'Emma', 'Emma_conversational' : 'Emma'} # This dict is saved in a json file given by cfg.model.text_context_remapping_json # If we want to preserve both behaviours i.e (Emma_neutral, Emma_conversational) and just (Emma) # we can do this mapping with a probability during training, as specified by text_context_remapping_prob self.text_context_remapping = None text_context_remapping_json = cfg.get('text_context_remapping_json', None) self.text_context_remapping_prob = cfg.get('text_context_remapping_prob', 0.0) if text_context_remapping_json is not None: with open(text_context_remapping_json, 'r') as f: self.text_context_remapping = json.load(f) super().__init__(cfg=cfg, trainer=trainer) if self.legacy_text_conditioning: tc_tokenizer = self.tokenizer.tokenizers[self.text_conditioning_tokenizer_name] self.context_text_embedding = nn.Embedding(tc_tokenizer.vocab_size, cfg.embedding_dim) # This needs to happen after super().__init__() self._codec_model = codec_model self._codec_model.freeze() # Lightning does requires_grad = False and self.eval() self._codec_converter = codec_converter audio_embeddings = [] for _ in range(self.num_audio_codebooks * self.frame_stacking_factor): audio_embeddings.append(nn.Embedding(self.num_all_tokens_per_codebook, cfg.embedding_dim)) self.audio_embeddings = nn.ModuleList(audio_embeddings) if self.use_bpe_char_tokenizer: # BPE char tokenizer assert len(self.tokenizer.tokenizers) == 1, "BPE char tokenizer should only be used with one tokenizer" tokenizer_name = self.tokenizer.tokenizer_names[0] tokenizer = self.tokenizer.tokenizers[tokenizer_name] subword_vocab = tokenizer.get_vocab() # special tokens will be stored as it is in the char_vocab # Each special token will only be mapped to one char id special_vocab = { '': self.bos_id, '': self.eos_id, } self.cas_encoder = CharAwareSubwordEncoder( d_embed=cfg.embedding_dim, llm_tokenizer_vocab=subword_vocab, subword_padding_idx=self.tokenizer.pad, special_vocab=special_vocab, ) else: # Regular text embedding self.text_embedding = nn.Embedding(num_tokens, cfg.embedding_dim) self.encoder = transformer_2501.Transformer(**dict(cfg.encoder)) self.decoder = transformer_2501.Transformer(**dict(cfg.decoder)) self.final_proj = nn.Linear( cfg.decoder.d_model, self.num_audio_codebooks * self.num_all_tokens_per_codebook * self.frame_stacking_factor, ) self.local_transformer_type = LocalTransformerType(cfg.get('local_transformer_type', 'none').lower()) logging.info(f"Local transformer type: {self.local_transformer_type}") if self.local_transformer_type != LocalTransformerType.NO_LT: local_transformer_hidden_dim = cfg.get('local_transformer_hidden_dim', 256) if local_transformer_hidden_dim != cfg.decoder.d_model: self.local_transformer_in_projection = nn.Linear(cfg.decoder.d_model, local_transformer_hidden_dim) else: self.local_transformer_in_projection = nn.Identity() self.local_transformer = transformer_2501.Transformer( n_layers=self.cfg.get('local_transformer_n_layers', 2), d_model=local_transformer_hidden_dim, d_ffn=local_transformer_hidden_dim * 4, sa_n_heads=self.cfg.get('local_transformer_n_heads', 1), kernel_size=1, is_causal=self.local_transformer_type == LocalTransformerType.AR, max_length_causal_mask=self.frame_stacking_factor * self.num_audio_codebooks + 2, use_learnable_pos_emb=True, ) local_transformer_out_projections = [] for _ in range(self.num_audio_codebooks * self.frame_stacking_factor): # 
Have a separate projection layer for each codebook, to distinguish between them local_transformer_out_projections.append( nn.Linear(local_transformer_hidden_dim, self.num_all_tokens_per_codebook) ) self.local_transformer_out_projections = nn.ModuleList(local_transformer_out_projections) if cfg.get('use_alignment_encoder', False): self.alignment_encoder = AlignmentEncoder( n_mel_channels=cfg.embedding_dim, n_text_channels=cfg.embedding_dim, dist_type="cosine", temperature=15.0, ) if self.model_type == 'multi_encoder_context_tts': logging.warning(f"The multi_encoder_context_tts model type for {self} is deprecated.") # Transcript and context audio/text go to different encoders. # Output of the encoders goes to the decoder through the cross-attention layers self.transcript_decoder_layers = cfg.get('transcript_decoder_layers', [3, 4, 5, 6, 7, 8]) self.context_decoder_layers = cfg.get( 'context_decoder_layers', [0, 1, 2, 9, 10, 11] ) # For backward compatibility multi_encoder_mapping = [None for _ in range(self.decoder.n_layers)] for layer in self.transcript_decoder_layers: multi_encoder_mapping[layer] = 0 # 0 means text goes to this layer, 1 means context goes to this layer for layer in self.context_decoder_layers: multi_encoder_mapping[layer] = 1 self.multi_encoder_mapping = multi_encoder_mapping self.context_encoder = transformer_2501.Transformer(**dict(cfg.context_encoder)) elif self.model_type == 'decoder_context_tts': # Context audio/text goes directly to the decoder (before the target audio codes) self.transcript_decoder_layers = [ idx for idx in range(self.decoder.n_layers) ] # All layers are used for text elif self.model_type == 'decoder_ce': # Similar to decoder_context_tts, but we use context encoder # Decoder gets output from context encoder instead of raw context tokens embeddings self.context_encoder = transformer_2501.Transformer(**dict(cfg.context_encoder)) self.transcript_decoder_layers = [ idx for idx in range(cfg.decoder.n_layers) ] # All layers are used for text # Register buffers for baked context embedding (initially None/empty) # These will be populated when loading a checkpoint with baked embedding self.register_buffer('baked_context_embedding', None) self.register_buffer('baked_context_embedding_len', None) else: raise ValueError(f"Unsupported model type {self.model_type}") self.cross_entropy_loss = nn.CrossEntropyLoss(reduction='none') self.alignment_loss_scale = cfg.get('alignment_loss_scale', 0.0) self.alignment_encoder_loss_scale = cfg.get('alignment_encoder_loss_scale', 0.0) if self.alignment_loss_scale > 0.0: self.alignment_loss = ForwardSumLoss(loss_scale=self.alignment_loss_scale) if self.alignment_encoder_loss_scale > 0.0: self.alignment_encoder_loss = ForwardSumLoss(loss_scale=self.alignment_encoder_loss_scale) # Define cfg parameters into self parameters self.prior_end_step = self.cfg.prior_end_step self.prior_scaledown_start_step = self.cfg.prior_scaledown_start_step self.indefinite_prior_prob = self.cfg.get('indefinite_prior_prob', 0.0) self.ctc_prior_layer_ids = self.cfg.get('ctc_prior_layer_ids', self.transcript_decoder_layers) self.cfg_unconditional_prob = self.cfg.get('cfg_unconditional_prob', 0.0) self.decoder_input_dropout_prob = self.cfg.get('decoder_input_dropout_prob', 0.0) self.binarize_attn_method = self.cfg.get('binarize_attn_method', 'argmax') self.binarize_repeat_audio_factor = self.cfg.get('binarize_repeat_audio_factor', 2) self.prior_future_decay = self.cfg.get('prior_future_decay', 1.0) self.prior_past_decay = 
self.cfg.get('prior_past_decay', 1.0) self.binarized_prior_epsilon = self.cfg.get('binarized_prior_epsilon', 0.0) self.prior_future_context = self.cfg.get('prior_future_context', 1) self.prior_past_context = self.cfg.get('prior_past_context', 1) self.binarize_prior_after_step = self.cfg.get('binarize_prior_after_step', 0) self.codebook_loss_scale = self.cfg.get('codebook_loss_scale', 1.0) self.local_transformer_loss_scale = self.cfg.get('local_transformer_loss_scale', 1.0) self.use_alignment_encoder = self.cfg.get('use_alignment_encoder', False) self.use_prior_for_aligner = self.cfg.get('use_prior_for_aligner', False) self.aligner_encoder_train_steps = self.cfg.get('aligner_encoder_train_steps', float('inf')) self.dec_random_input_max = self.cfg.get('dec_random_input_max', self.num_all_tokens_per_codebook) # Configuration validity checks self.check_frame_stacking_config_validity() def state_dict(self, destination=None, prefix='', keep_vars=False): """ Only used for saving checkpoints. On save, we remove _speaker_verification_model and _codec_model from the checkpoint. The codec model is saved in a separate checkpoint. _speaker_verification_model is only included in older checkpoints with the older single_encoder_sv_tts model_type that is no longer supported and can likely be removed in a future version. If the model has a baked context embedding, the context_encoder weights are also excluded since they are no longer needed for inference. """ if hasattr(self, '_no_state_dict') and self._no_state_dict: return {} # Don't save the speaker verification and codec model in the state dict state_dict = super().state_dict(destination, prefix, keep_vars) keys_substrings_to_exclude = ['_speaker_verification_model', '_codec_model'] # If we have a baked context embedding, exclude context_encoder weights if self.has_baked_context_embedding: keys_substrings_to_exclude.append('context_encoder') for key in list(state_dict.keys()): if any([substring in key for substring in keys_substrings_to_exclude]): del state_dict[key] return state_dict def check_frame_stacking_config_validity(self): """ Check if the configuration is compatible with frame stacking. """ if self.frame_stacking_factor > 1: # The settings below are not supported with frame stacking. # Some of them may work - but they have not been tested. # disallow alignment encoder if self.use_alignment_encoder: raise ValueError("Alignment encoder is not supported for frame stacking") # disallow alignment loss if self.alignment_loss_scale > 0.0: raise ValueError("Alignment loss is not supported for frame stacking") # disallow training prior if self.cfg.prior_scaling_factor is not None and self.cfg.prior_scaling_factor > 0: raise ValueError("Training-time attention prior is not supported for frame stacking") # disallow text conditioning if self.use_text_conditioning_encoder: raise ValueError("Text conditioning is not supported for frame stacking") @property def has_baked_context_embedding(self) -> bool: """Check if the model has a baked context embedding. Returns: True if baked_context_embedding buffer is set, not None, and has elements. """ return ( self.model_type == 'decoder_ce' and hasattr(self, 'baked_context_embedding') and self.baked_context_embedding is not None and self.baked_context_embedding.numel() > 0 ) def update_ckpt(self, state_dict): """ Backward compatibility for checkpoints saved with old model names. 
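        For example, a checkpoint key of the form `t5_encoder.<param>` is renamed to
        `encoder.<param>`, and `t5_decoder.<param>` to `decoder.<param>`.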
""" new_state_dict = {} for key in state_dict.keys(): if 't5_encoder' in key: new_key = key.replace('t5_encoder', 'encoder') new_state_dict[new_key] = state_dict[key] elif 't5_decoder' in key: new_key = key.replace('t5_decoder', 'decoder') new_state_dict[new_key] = state_dict[key] else: new_state_dict[key] = state_dict[key] return new_state_dict def load_state_dict(self, state_dict, strict=True): """ Modify load_state_dict so that we don't restore weights to _speaker_verification_model and _codec_model when strict is True. When strict is False, we can call pytorch's load_state_dict. When strict is True, we loop through all parameters and rename them to enable loading. _speaker_verification_model is only included in older checkpoints with the older single_encoder_sv_tts model_type that is no longer supported and can likely be removed in a future version. Also handles loading baked context embeddings. If the checkpoint contains baked_context_embedding, context_encoder weights are not expected to be present. """ state_dict = self.update_ckpt(state_dict) # Check if checkpoint has baked context embedding has_baked_embedding_in_ckpt = ( 'baked_context_embedding' in state_dict and state_dict['baked_context_embedding'] is not None ) # Load baked embedding buffers if present if has_baked_embedding_in_ckpt: self.baked_context_embedding = state_dict['baked_context_embedding'] self.baked_context_embedding_len = state_dict['baked_context_embedding_len'] logging.info( f"Loaded baked context embedding with shape {self.baked_context_embedding.shape}, " f"length {self.baked_context_embedding_len.item()}" ) if not strict: super().load_state_dict(state_dict, strict=False) # Build list of modules to skip modules_to_skip = [ '_speaker_verification_model', '_codec_model', '_reference_model', 'eval_asr_model', 'eval_speaker_verification_model', 'whisper_model', 'squim_objective_model', ] # Skip context_encoder if checkpoint has baked embedding (weights won't be in checkpoint) if has_baked_embedding_in_ckpt: modules_to_skip.append('context_encoder') for name, child in self.named_children(): if name in modules_to_skip: continue if any(param.numel() > 0 for param in child.parameters()): # If the module has parameters, we want to change the default mapping so that the state_dict gets # loaded. # Ex: state_dict[encoder.position_embeddings.weight] -> new_state_dict[position_embeddings.weight] new_state_dict = {} for key in state_dict.keys(): name_with_dot = f"{name}." if key.startswith(name_with_dot): new_state_dict[key[len(name_with_dot) :]] = state_dict[key] child.load_state_dict(new_state_dict) def audio_to_codes(self, audio, audio_len, audio_type='target'): # audio: (B, T) # audio_len: (B,) if audio_type == 'target': audio_eos_id = self.audio_eos_id audio_bos_id = self.audio_bos_id elif audio_type == 'context': audio_eos_id = self.context_audio_eos_id audio_bos_id = self.context_audio_bos_id else: raise ValueError(f"Received audio_type of {audio_type}. 
Must be `target` or `context`") self._codec_model.eval() with torch.no_grad(), torch.autocast(device_type=audio.device.type, dtype=torch.float32): codes, codes_len = self._codec_model.encode(audio=audio, audio_len=audio_len) if self._codec_converter is not None: codes = self._codec_converter.convert_original_to_new(audio_tokens=codes, audio_lens=codes_len) # Add a timestep to begining and end of codes tensor bos_tensor = torch.full( (codes.size(0), codes.size(1), 1), audio_bos_id, dtype=codes.dtype, device=codes.device ) # pad at the end to make room for the EOS token; the EOS token's actual position # varies per batch element depending on each element's length. pad_tensor = torch.full( (codes.size(0), codes.size(1), 1), 0, dtype=codes.dtype, device=codes.device ) # 0 is the padding token in the audio codebook codes = torch.cat([bos_tensor, codes, pad_tensor], dim=-1) # codes: (B, C, T') # codes_len: (B,) for idx in range(codes.size(0)): codes[idx, :, codes_len[idx] + 1] = audio_eos_id codes_len = codes_len + 2 # +1 for bos and +1 for eos return codes.long(), codes_len.long() def codes_to_audio(self, codes, codes_len): # codes: (B, C, T') # codes_len: (B,) self._codec_model.eval() with torch.no_grad(), torch.autocast(device_type=codes.device.type, dtype=torch.float32): # Make a copy to avoid modifying the original tensor if it's used elsewhere codes_copy = codes.clone() # Replace eos and bos tokens with padding in the copied tensor codes_copy[codes == self.audio_bos_id] = 0 # zero is the padding token codes_copy[codes == self.audio_eos_id] = 0 # Pass the modified integer token IDs if self._codec_converter is not None: codes_copy = self._codec_converter.convert_new_to_original( audio_tokens=codes_copy, audio_lens=codes_len ) audio, audio_len = self._codec_model.decode(tokens=codes_copy, tokens_len=codes_len) # audio: (B, T) # audio_len: (B,) return audio, audio_len def embed_audio_tokens(self, audio_tokens): B, C, T = audio_tokens.shape audio_embedding = None for i in range(self.frame_stacking_factor): for c in range(C): tokens = audio_tokens[:, c, i :: self.frame_stacking_factor] embedding = self.audio_embeddings[c + i * C](tokens) if audio_embedding is None: audio_embedding = embedding else: audio_embedding += embedding audio_embedding = audio_embedding / (C * self.frame_stacking_factor) return audio_embedding def compute_local_transformer_logits(self, dec_out, audio_codes_target, targets_offset_by_one=False): """ Predicts the logits for all codebooks using the local transformer. Used in both autoregressive (AR) and MaskGit (MG) modes. This function is used in training and validation, not inference/sampling. The sequence layout is slightly different between AR and MG modes, as shown in the diagram below, (using an 8-codebook setup as an example): +------------+---------+---------+---------+---------+---------+---------+---------+---------+---------+ | AR target | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | none | | codebook | | | | | | | | | | +------------+---------+---------+---------+---------+---------+---------+---------+---------+---------+ | MG target | none | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | | codebook | | | | | | | | | | +------------+---------+---------+---------+---------+---------+---------+---------+---------+---------+ | input | Magpie | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | | codebook | latent | or MASK | or MASK | or MASK | or MASK | or MASK | or MASK | or MASK | or MASK | +------------+---------+---------+---------+---------+---------+---------+---------+---------+---------+ | seq. 
index | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | +------------+---------+---------+---------+---------+---------+---------+---------+---------+---------+ dec_out: (B, T', E) audio_codes_target: (B, C, T') targets_offset_by_one: bool, if False, the target for index 0 is codebook 0, for index 1 is codebook 1, etc. (autoregressive) if True, the target for index 1 is codebook 0, for index 2 is codebook 1, etc. (MaskGit) """ C = self.num_audio_codebooks dec_out_all = dec_out.reshape(-1, dec_out.size(-1)) # (B*T', E) local_transformer_input = [dec_out_all] # Build the teacher-forced input to the LT. for fs_index in range(self.frame_stacking_factor): for codebook_num in range(C): # Collect ground truth codes for the current codebook and frame stack index combintation. codes = audio_codes_target[:, codebook_num, fs_index :: self.frame_stacking_factor] # (B, T') # Individual timesteps are independently handled by the LT fold time into the batch dimension. codes = codes.reshape(-1) # (B*T',) # Embed the codes codebook_embedding = self.audio_embeddings[codebook_num + fs_index * C](codes) # (B*T', E) local_transformer_input.append(codebook_embedding) # Stack the input codes along dimension 1 (codebooks). This is the dimension along which the LT predicts iteratively. local_transformer_input = torch.stack(local_transformer_input, dim=1) # (B*T', C+1, E) local_transformer_input = self.local_transformer_in_projection(local_transformer_input) # (B*T', C+1, 128) _mask = torch.ones( local_transformer_input.size(0), local_transformer_input.size(1), device=local_transformer_input.device ) local_transformer_output = self.local_transformer(local_transformer_input, _mask)['output'] # (B*T', C+1, E) if not targets_offset_by_one: # for autoregressive local transformer the target for index 0 is codebook 0, for index 1 is codebook 1, etc. local_transformer_output = local_transformer_output[:, :-1, :] # (B*T', C, E) else: # for MaskGit the target for index **1** is codebook 0, for index 2 is codebook 1, etc. local_transformer_output = local_transformer_output[:, 1:, :] # (B*T', C, E) all_code_logits = [] for fs_index in range(self.frame_stacking_factor): for codebook_num in range(audio_codes_target.size(1)): # Using a separate projection layer for each codebook (to distinguish between them) # Checked the time - this loop is not taking much time (compared to the local transformer forward pass) codebook_logits = self.local_transformer_out_projections[codebook_num + fs_index * C]( local_transformer_output[:, codebook_num + fs_index * C, :] ) # (B*T', num_all_tokens_per_codebook) all_code_logits.append(codebook_logits) all_code_logits = torch.cat( all_code_logits, dim=1 ) # (B*T'/frame_stacking_factor, num_codebooks * num_all_tokens_per_codebook * frame_stacking_factor) all_code_logits = all_code_logits.view( audio_codes_target.size(0), audio_codes_target.size(2) // self.frame_stacking_factor, -1 ) # (B, T'/frame_stacking_factor, C * num_all_tokens_per_codebook * frame_stacking_factor) return all_code_logits def maskgit_create_random_mask(self, codes): """ Creates a mask where True indicates the positions that should be replaced with a MASK_TOKEN. """ # Codes: (B, C, T) B, C, T = codes.shape # get a uniform random vector uniformly sampled from [0,1) ## Todo does it need to be inclusive on the right? 
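        # Illustrative sketch of the schedule (assuming the standard MaskGit choice
        # cosine_schedule(u) = cos(u * pi / 2)): u near 0 masks almost all codebooks
        # and u near 1 masks almost none. E.g. for C = 8 codebooks:
        #   u = 0.0 -> n_masked = ceil(8 * 1.000) = 8
        #   u = 0.5 -> n_masked = ceil(8 * 0.707) = 6
        #   u = 0.9 -> n_masked = ceil(8 * 0.156) = 2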
rand_values = torch.rand(B, T, device=codes.device) # apply the cosine schedule frac_masked = cosine_schedule(rand_values) # how many positions to mask n_masked = torch.ceil(frac_masked * C).long() # B,T # The code further below is the vectorized version of this: # for b in range(B): # for t in range(T): # if n_masked[b,t] > 0: # # get a random permutation of the codebook indices # perm = torch.randperm(C) # # mask the top n_masked positions # mask[b, perm[:n_masked[b,t]], t] = True # # Create random permutations random_permutations = torch.argsort(torch.rand(B, C, T, device=codes.device), dim=1) # (B, C, T) # Create a mask tensor where each position indicates if it should be masked mask_indices = torch.arange(C, device=codes.device).view(1, C, 1) mask = mask_indices < n_masked.view(B, 1, T) # (B, C, T) # Apply the random permutations to the mask mask = torch.gather(mask, 1, random_permutations) return mask # (B, C, T) def maskgit_apply_random_mask(self, codes): # Randomly replaces some codes with the MASK_TOKEN with a proportion following the cosine schedule. # Codes: (B, C, T) mask = self.maskgit_create_random_mask(codes) # replace some tokens with MASK_TOKEN codes_with_mask = torch.where(mask, self.mask_token_id, codes) return codes_with_mask, mask def compute_loss(self, logits, audio_codes, audio_codes_lens, mask_tokens_mask=None, frame_stacking_factor=1): """ Computes the audio codebook loss. Used by (1) The main Magpie-TTS transformer (2) The local transformer, for both autoregressive and MaskGit methods logits: (B, T', num_codebooks * num_tokens_per_codebook) audio_codes: (B, C, T') audio_codes_lens: (B,) mask_tokens_mask: (B, C, T') True for tokens that were replaced with the MASK_TOKEN and should therefore be the only ones included in the loss computation (for MaskGit). frame_stacking_factor: int, the stacking factor used in the model """ loss_mask = get_mask_from_lengths(audio_codes_lens, pad_to_factor=frame_stacking_factor) if mask_tokens_mask is not None: # For MaskGit we only compute loss for the masked tokens. # *Both* conditions must be true: # 1. the token is masked # 2. 
the token is not padding loss_mask = loss_mask.unsqueeze(1) * mask_tokens_mask if not loss_mask.any(): # Without this we were very rarely getting NaNs in the loss logging.warning("No tokens valid were found in compute_loss()!") return torch.tensor(0.0, device=loss_mask.device), loss_mask else: # repeat loss mask for each codebook to simplify code below loss_mask = loss_mask.unsqueeze(1).repeat(1, audio_codes.size(1), 1) total_codebook_loss = None for fs_index in range(frame_stacking_factor): for codebook in range(audio_codes.size(1)): si = (codebook + self.num_audio_codebooks * fs_index) * self.num_all_tokens_per_codebook ei = si + self.num_all_tokens_per_codebook codebook_logits = logits[:, :, si:ei] # (B, T', num_tokens_per_codebook) codebook_targets = audio_codes[:, codebook, fs_index::frame_stacking_factor] # (B, T') codebook_loss = self.cross_entropy_loss( codebook_logits.permute(0, 2, 1), codebook_targets # (B, num_tokens_per_codebook, T') ) # (B, T') codebook_loss_mask = loss_mask[:, codebook, fs_index::frame_stacking_factor] codebook_loss = codebook_loss * codebook_loss_mask if codebook_loss_mask.sum() == 0: logging.warning(f"Loss mask for codebook {codebook} is all zeros, global_step: {self.global_step}") continue codebook_loss = codebook_loss.sum() / codebook_loss_mask.sum() if total_codebook_loss is None: total_codebook_loss = codebook_loss else: total_codebook_loss = total_codebook_loss + codebook_loss total_codebook_loss = total_codebook_loss / (audio_codes.size(1) * frame_stacking_factor) return total_codebook_loss, loss_mask def forward(self, dec_input_embedded, dec_input_mask, cond, cond_mask, attn_prior, multi_encoder_mapping): decoder_out = self.decoder( dec_input_embedded, dec_input_mask, cond=cond, cond_mask=cond_mask, attn_prior=attn_prior, multi_encoder_mapping=multi_encoder_mapping, ) attn_probabilities = decoder_out['attn_probabilities'] all_code_logits = self.final_proj(decoder_out['output']) # (B, T', num_codebooks * num_tokens_per_codebook) return all_code_logits, attn_probabilities, decoder_out['output'] def logits_to_audio_codes(self, all_code_logits, audio_codes_lens): # all_code_logits: (B, T', num_codebooks * num_tokens_per_codebook) # audio_codes_lens: (B,) all_preds = [[] for _ in range(self.frame_stacking_factor)] for fs_index in range(self.frame_stacking_factor): for idx in range(self.num_audio_codebooks): si = (idx + self.num_audio_codebooks * fs_index) * self.num_all_tokens_per_codebook ei = si + self.num_all_tokens_per_codebook codebook_logits = all_code_logits[:, :, si:ei] codebook_probs = torch.softmax(codebook_logits, dim=-1) # (B, T', num_tokens_per_codebook) # argmax to get the tokens codebook_preds = torch.argmax(codebook_probs, dim=-1) # (B, T') all_preds[fs_index].append(codebook_preds) all_preds = [ torch.stack(p, dim=1) for p in all_preds ] # list of `frame_stacking_factor`` elements of shape (B,C,T) each all_preds = torch.stack(all_preds, dim=-1) # B, C, T, frame_stacking_factor # undo the frame stacking all_preds = all_preds.reshape(all_preds.size(0), all_preds.size(1), -1) # B, C, T*frame_stacking_factor pred_max_len = all_preds.size(2) real_max_len = audio_codes_lens.max() assert (pred_max_len - real_max_len) < self.frame_stacking_factor # trim padding introduced for frame stacking all_preds = all_preds[:, :, :real_max_len] audio_mask = get_mask_from_lengths(audio_codes_lens) all_preds = all_preds * audio_mask.unsqueeze(1) return all_preds def visualize_codes(self, codes, mask_id=2020, frame_stacking_rate=2): """ Visualize codes for 
analysis purposes codes: (B, C) """ def code_to_str(code): if code == mask_id: return "M " else: return f"{code:04d} " B, C = codes.shape if B > 1: logging.debug("Warning: visualizing only first batch element") codes = codes.clone().detach().cpu().numpy()[0] codes = [code_to_str(c) for c in codes] output_str = "" for i, c in enumerate(codes): if (i) % (C / frame_stacking_rate) == 0: output_str += "|timestep| " output_str += c logging.debug(output_str) def clear_forbidden_logits(self, logits: torch.Tensor, forbid_audio_eos: bool = False) -> torch.Tensor: """ Sets logits of forbidden tokens to `-inf` so they will never be sampled. Specifically, we forbid sampling of all special tokens except AUDIO_EOS which is allowed by default. Args: logits: (B, C, num_audio_tokens_per_codebook) forbid_audio_eos (bool, optional): If True, also forbid AUDIO_EOS tokens from being sampled. Default: False. """ logits[ :, :, SpecialAudioToken.get_forbidden_tokens(self.codebook_size, forbid_audio_eos=forbid_audio_eos), ] = float('-inf') return logits def local_transformer_sample_maskgit( self, dec_output: torch.Tensor, temperature: float = 0.7, topk: int = 80, unfinished_items: Dict[int, bool] = {}, finished_items: Dict[int, bool] = {}, use_cfg: bool = False, cfg_scale: float = 1.0, n_steps: int = 3, noise_scale: float = 0.0, fixed_schedule: Optional[List[int]] = None, dynamic_cfg_scale: bool = False, sampling_type: Optional[str] = None, forbid_audio_eos: bool = False, ) -> torch.Tensor: """ Sample audio codes for the current timestep using MaskGit-like iterative prediction with the local transformer. If frame-stacking is enabled, the codes for all frames in the stack are sampled, treated as one long sequence. The MaskGit process starts with all positions masked and iteratively unmasks the most confident positions over multiple steps. By "masked" we mean that a dedicated MASK token is used (as opposed to attention masking). The LT in this case is a non-causal transformer decoder. At each step the model predicts all positions at once. Of those predictions, a subset of the most confident previously-masked positions is kept and unmasked in the next step. The number of positions that are unmasked at each step is determined by the unmasking schedule. We support a cosine schedule and a fixed schedule provided by the user. Uses multinomial sampling with temperature, top-k, and classifier-free guidance (CFG). Special handling: * forbids special tokens (like AUDIO_BOS, AUDIO_CONTEXT_EOS, etc.) from being sampled * forces / forbids EOS for finished / unfinished items respectively * optionally, globally forbids audio EOS for all items in the batch. This is useful early in the generation process. * supports different unmasking methods, see `sampling_type` argument for details. Args: dec_output (torch.Tensor): Decoder output tensor with shape (B, E) where B is batch size and E is primary decoder's embedding dimension. temperature (float, optional): Sampling temperature topk (int, optional): Number of top-probability tokens to consider in sampling. unfinished_items (dict, optional): Dictionary containing indices of batch items that we are confident have not completed generation. For these items, audio EOS sampling is forbidden. finished_items (dict, optional): Dictionary containing indices of batch items that we are confident are completed. For these items, audio EOS sampling is forced. use_cfg (bool, optional): Whether to use classifier-free guidance. 
If True, expects batch size to be doubled with conditional and unconditional outputs from the primary decoder. cfg_scale (float, optional): Scale factor for classifier-free guidance. Only used if use_cfg=True. n_steps (int, optional): Number of iterative refinement steps for MaskGit sampling. noise_scale (float, optional): Scale factor for noise to add to confidence scores during sampling (experimental). fixed_schedule (list, optional): Fixed schedule for number of tokens to unmask at each step. If None, uses cosine schedule. dynamic_cfg_scale (bool, optional): Whether to dynamically adjust CFG scale during sampling (experimental). sampling_type (str, optional): Type of sampling strategy. Options are: ["default", "causal", "purity_causal", "purity_default"]. * Purity refers to "purity sampling" from https://arxiv.org/abs/2304.01515. If "purity" is not specified, confidence sampling is used as in the original MaskGit paper. * "default"/"causal": Controls the order of unmasking across frames when frame-stacking is enabled. If "causal" is specified, frames are unmasked in causal order. "default" doesn't impose any constraints on the unmasking order. forbid_audio_eos (bool, optional): Whether to globally forbid audio EOS for the entire batch. Returns: torch.Tensor: Sampled audio codes with shape (B, num_codebooks, frame_stacking_factor) """ # dec_output: (B, E) device = dec_output.device # disable KV cache since our transformer is not causal self.local_transformer.reset_cache(use_cache=False) dec_output = dec_output.unsqueeze(1) # (B, 1, E) local_transformer_input_init = self.local_transformer_in_projection( dec_output ) # (B, 1, D) where D is the dimension of the local transformer codebook_seq_len = self.num_audio_codebooks * self.frame_stacking_factor B = dec_output.size(0) min_confidence = 0 # this needs to be large enough that unmasked items will always remain unmasked (even after noise addition) # Setting it smaller could allow "regret", i.e. 
re-masking a codebook that was previously unmasked; we might want to try that max_confidence = 5 confidences = min_confidence * torch.ones(B, codebook_seq_len, device=device) # initialize to all masked codes = self.mask_token_id * torch.ones((B, codebook_seq_len), device=device, dtype=torch.long) sampled_codes = codes.clone() topk_indices = None if fixed_schedule is not None: n_steps = len(fixed_schedule) for step in range(n_steps): # how far along we are in the unmasking process progress = step / n_steps # get mask fraction frac_masked = cosine_schedule(torch.tensor(progress)) if sampling_type == "causal" or sampling_type == "purity_causal": frac_masked = torch.ones_like(frac_masked) * (1.0 - progress) # how many codebooks to mask if fixed_schedule is None: n_masked = torch.ceil(codebook_seq_len * frac_masked).long() else: n_masked = codebook_seq_len - fixed_schedule[step] n_unmasked = codebook_seq_len - n_masked if ( sampling_type == "causal" or sampling_type == "purity_causal" ): # and n_unmasked <= self.num_audio_codebooks: # force second frame not to be unmasked n_frames_to_allow = int(np.floor(progress * self.frame_stacking_factor + 1)) confidences[:, n_frames_to_allow * self.num_audio_codebooks :] = ( min_confidence - 1 ) # only tested for frame_stacking_factor=2 # pick top-confidence codebooks up to n_unmasked _, topk_indices = torch.topk(confidences, k=n_unmasked, dim=1) if use_cfg: actual_batch_size = topk_indices.size(0) // 2 assert ( topk_indices[actual_batch_size:] == topk_indices[:actual_batch_size] ).all(), "Topk indices are not the same for conditional and unconditional codes" # replace masks of the top-k confident codebooks with the codes that were sampled for them unmasked_codes = torch.gather(sampled_codes, dim=1, index=topk_indices) codes.scatter_(dim=1, index=topk_indices, src=unmasked_codes) # build transformer input local_transformer_input = local_transformer_input_init for codebook_num in range(codebook_seq_len): next_local_transformer_input = self.audio_embeddings[codebook_num](codes[:, codebook_num]).unsqueeze( 1 ) # (B, 1, 768) next_local_transformer_input = self.local_transformer_in_projection( next_local_transformer_input ) # (B, 1, d_local) local_transformer_input = torch.cat( [local_transformer_input, next_local_transformer_input], dim=1 ) # (B, codebook_num+1, d_local) # run transformer _mask = torch.ones(B, codebook_seq_len + 1, device=device) local_transformer_output = self.local_transformer(local_transformer_input, _mask)[ 'output' ] # (B, C+1, d_local) # get logits logits = [] for codebook_num in range(codebook_seq_len): # The `codebook_num+1` is to drop first position which corresponds to the magpie latent codebook_logits = self.local_transformer_out_projections[codebook_num]( local_transformer_output[:, codebook_num + 1, :] ) # (B, num_audio_tokens_per_codebook) logits.append(codebook_logits) logits = torch.stack(logits, dim=1) # (B, C*frame_stacking_factor, num_audio_tokens_per_codebook) # apply CFG if use_cfg: actual_batch_size = logits.size(0) // 2 conditional_logits = logits[:actual_batch_size] unconditional_logits = logits[actual_batch_size:] if not dynamic_cfg_scale: current_cfg_scale = cfg_scale else: # gradually increase the scale until mid point through sampling, then reduce it again progress = step / (n_steps - 1) # interp = -abs(progress-0.5)+0.5 # increase from 0..1 in the interval from start to midpoint and then go back to zero # interp = 1.0 - progress # decrease from 1 to 0 interp = progress # gradually increase from 0 to 1 
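                        # Illustrative numbers (assumed, not from the config): with the active
                        # `interp = progress` choice the guidance scale ramps linearly from 1.0 at
                        # the first step to cfg_scale at the last step, e.g. cfg_scale=2.5 with
                        # n_steps=4 gives 1.0, 1.5, 2.0, 2.5.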
current_cfg_scale = (cfg_scale - 1) * interp + 1.0 # 1.0 --> cfg_scale --> 1.0 cfg_logits = current_cfg_scale * conditional_logits + (1.0 - current_cfg_scale) * unconditional_logits logits[:actual_batch_size] = cfg_logits # Disallow generation of special tokens logits = self.clear_forbidden_logits(logits, forbid_audio_eos=forbid_audio_eos) # handle unfinished and finished items for item_idx in unfinished_items: logits[item_idx, self.audio_eos_id] = float('-inf') for item_idx in finished_items: logits[item_idx, :, :] = float('-inf') logits[item_idx, :, self.audio_eos_id] = 0.0 # sample with top-k logits_topk = torch.topk(logits, topk, dim=-1)[0] # (B, C, topk) indices_to_remove = logits < logits_topk[:, :, -1].unsqueeze(-1) # (B, C, num_audio_tokens_per_codebook) logits_rescored = logits.clone() logits_rescored[indices_to_remove] = float('-inf') probs = torch.softmax(logits_rescored / temperature, dim=-1) # (B, C, num_audio_tokens_per_codebook) sampled_codes = torch.multinomial(probs.view(B * codebook_seq_len, -1), 1).view(B, codebook_seq_len) if use_cfg: sampled_codes[actual_batch_size:] = sampled_codes[:actual_batch_size] probs[actual_batch_size:] = probs[:actual_batch_size] if sampling_type != "purity_causal" and sampling_type != "purity_default": confidences = torch.gather(probs, dim=2, index=sampled_codes.unsqueeze(-1)).squeeze(-1) else: # use the max probability across all tokens for each codebook as the confidence for each codebook; known as "purity sampling" confidences = probs.max(dim=2)[0] # replace entries in sampled_codes with previously unmasked codebooks sampled_codes.scatter_(dim=1, index=topk_indices, src=unmasked_codes) # add noise to confidences (as in token-critic paper, https://arxiv.org/abs/2209.04439) if noise_scale > 0.0: # get noise from uniform distribution in the interval [-0.5, 0.5), scale it by `noise_scale`, # and anneal it to 0 as we approach the end of the unmasking process noise = ( (torch.rand_like(confidences) - 0.5) * noise_scale * (1 - (step + 2) / n_steps) ) # the +2 makes sure that by the last iteration the noise is exactly 0 confidences += noise # the conditional and unconditional get different noise and must be fixed to be the same again confidences[actual_batch_size:] = confidences[:actual_batch_size] confidence_eps = 0.1 assert ( confidences.max() + confidence_eps < max_confidence ), f"Predicted confidence is approaching max_confidence: {confidences.max()}" # for unmasked codebooks, set confidence to max so that they will remain unmasked confidences.scatter_( index=topk_indices, dim=1, src=max_confidence * torch.ones_like(topk_indices, dtype=torch.float) ) codes = sampled_codes assert not ( codes == self.mask_token_id ).any(), "Codes contain mask tokens after completion of MaskGit sampling" # break stacked groups of frames into individual frames codes = codes.reshape(B, self.frame_stacking_factor, self.num_audio_codebooks).permute( 0, 2, 1 ) # B, C, frame_stacking_factor if use_cfg: # drop unconditional codes codes = codes[:actual_batch_size] return codes def local_transformer_sample_autoregressive( self, dec_output: torch.Tensor, temperature: float = 0.7, topk: int = 80, unfinished_items: Dict[int, bool] = {}, finished_items: Dict[int, bool] = {}, use_cfg: bool = False, cfg_scale: float = 1.0, use_kv_cache: bool = True, forbid_audio_eos: bool = False, ) -> torch.Tensor: """ Sample audio codes autoregressively across codebooks using the local transformer. Uses multinomial sampling with temperature, top-k, and classifier-free guidance (CFG). 
The sequence is initialized with the primary decoder's hidden output as the only input and is gradually extended a code for one codebook at a time, appending the sampled code as input sequence for the next step. At the last step the sequence is `num_codebooks` long. If frame stacking is enabled, codes for all frames in the stack are sampled as one long sequence and the final sequence length is `num_codebooks * frame_stacking_factor` codes long. Special handling: * forbids special tokens (like AUDIO_BOS, AUDIO_CONTEXT_EOS, etc.) from being sampled * forces / forbids EOS for finished / unfinished items respectively * optionally, globally forbids audio EOS (useful early in the generation process) Args: dec_output (torch.Tensor): Decoder output tensor with shape (B, E) where B is batch size and E is primary decoder's embedding dimension. temperature (float, optional): Sampling temperature. topk (int, optional): Number of top-probability tokens to consider in sampling. unfinished_items (dict, optional): Dictionary containing indices of batch items that we are confident have not completed generation. For these items, audio EOS sampling is forbidden. finished_items (dict, optional): Dictionary containing indices of batch items that we are confident are completed. For these items, audio EOS sampling is forced. use_cfg (bool, optional): Whether to use classifier-free guidance. If True, expects batch size to be doubled with conditional and unconditional outputs from the primary decoder. cfg_scale (float, optional): Scale factor for classifier-free guidance. Only used if use_cfg=True. use_kv_cache (bool, optional): Whether to use key-value caching in the transformer. forbid_audio_eos (bool, optional): Whether to globally forbid audio EOS for the entire batch. Returns: torch.Tensor: Sampled audio codes with shape (B, num_codebooks, frame_stacking_factor) where B is batch size (or actual_batch_size if use_cfg=True). 
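        Example:
            A minimal sketch (assumes `model` is a loaded MagpieTTSModel and `dec_output` is the
            primary decoder's (B, E) hidden state for the current timestep)::

                codes_t = model.local_transformer_sample_autoregressive(
                    dec_output, temperature=0.7, topk=80
                )  # (B, num_codebooks, frame_stacking_factor)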
""" self.local_transformer.reset_cache(use_cache=use_kv_cache) dec_output = dec_output.unsqueeze(1) # (B, 1, E) local_transformer_input = self.local_transformer_in_projection(dec_output) # (B, 1, 128) all_preds = [] for codebook_num in range(self.num_audio_codebooks * self.frame_stacking_factor): _mask = torch.ones( local_transformer_input.size(0), local_transformer_input.size(1), device=local_transformer_input.device ) local_transformer_output = self.local_transformer(local_transformer_input, _mask)['output'] # (B, T, 128) codebook_logits = self.local_transformer_out_projections[codebook_num]( local_transformer_output[:, -1, :] ) # (B, num_all_tokens_per_codebook) if use_cfg: actual_batch_size = codebook_logits.size(0) // 2 conditional_logits = codebook_logits[:actual_batch_size] unconditional_logits = codebook_logits[actual_batch_size:] cfg_logits = cfg_scale * conditional_logits + (1.0 - cfg_scale) * unconditional_logits codebook_logits[:actual_batch_size] = cfg_logits for item_idx in unfinished_items: codebook_logits[item_idx, self.audio_eos_id] = float('-inf') for item_idx in finished_items: codebook_logits[item_idx, :] = float('-inf') codebook_logits[item_idx, self.audio_eos_id] = 0.0 # Disallow generation of special tokens codebook_logits = self.clear_forbidden_logits( codebook_logits.unsqueeze(1), forbid_audio_eos=forbid_audio_eos ).squeeze(1) codebook_logits_topk = torch.topk(codebook_logits, topk, dim=-1)[0] # (B, topk) indices_to_remove = codebook_logits < codebook_logits_topk[:, -1].unsqueeze( -1 ) # (B, num_tokens_per_codebook) codebook_logits_rescored = codebook_logits.clone() codebook_logits_rescored[indices_to_remove] = float('-inf') codebook_probs = torch.softmax( codebook_logits_rescored / temperature, dim=-1 ) # (B, num_tokens_per_codebook) codebook_preds = torch.multinomial(codebook_probs, 1) # (B, 1) if use_cfg: codebook_preds[actual_batch_size:] = codebook_preds[:actual_batch_size] all_preds.append(codebook_preds) next_local_transformer_input = self.audio_embeddings[codebook_num](codebook_preds.squeeze(-1)).unsqueeze( 1 ) # (B, 1, 128) next_local_transformer_input = self.local_transformer_in_projection( next_local_transformer_input ) # (B, 1, 128) local_transformer_input = torch.cat( [local_transformer_input, next_local_transformer_input], dim=1 ) # (B, T+1, 128) all_preds = torch.cat(all_preds, dim=1).long() # (B, num_codebooks * frame_stacking_factor) all_preds = all_preds.reshape(-1, self.frame_stacking_factor, self.num_audio_codebooks).permute( 0, 2, 1 ) # (B, num_codebooks, frame_stacking_factor) if use_cfg: all_preds = all_preds[:actual_batch_size] return all_preds def sample_codes_from_logits( self, all_code_logits_t: torch.Tensor, temperature: float = 0.7, topk: int = 80, unfinished_items: Dict[int, bool] = {}, finished_items: Dict[int, bool] = {}, forbid_audio_eos: bool = False, ) -> torch.Tensor: """ Sample codes for all codebooks at a given timestep. Uses multinomial sampling with temperature and top-k. If frame stacking is on (i.e. `frame_stacking_factor > 1`), this function will sample across the entire frame stack. Special handling: * forbids special tokens (like AUDIO_BOS, AUDIO_CONTEXT_EOS, etc.) 
from being sampled * forces / forbids EOS for finished / unfinished items respectively * optionally, globally forbids audio EOS (useful early in the generation process) Args: all_code_logits_t (torch.Tensor): Logits at a given timestep with shape (B, num_tokens_per_codebook * num_codebooks * frame_stacking_factor) temperature (float, optional): Sampling temperature topk (int, optional): Number of top-probability tokens to consider in sampling. unfinished_items (dict, optional): Dictionary containing indices of batch items that we are confident have not completed generation. For these items, audio EOS sampling is forbidden. finished_items (dict, optional): Dictionary containing indices of batch items that we are confident are completed. For these items, audio EOS sampling is forced. forbid_audio_eos (bool, optional): Whether to globally forbid audio EOS for the entire batch. Returns: torch.Tensor: Sampled audio codes with shape (B, num_codebooks, frame_stacking_factor). """ all_preds = [[] for _ in range(self.frame_stacking_factor)] for fs_index in range(self.frame_stacking_factor): for idx in range(self.num_audio_codebooks): si = (idx + self.num_audio_codebooks * fs_index) * self.num_all_tokens_per_codebook ei = si + self.num_all_tokens_per_codebook codebook_logits = all_code_logits_t[:, si:ei] # (B, num_tokens_per_codebook) for item_idx in unfinished_items: codebook_logits[item_idx, self.audio_eos_id] = float('-inf') for item_idx in finished_items: codebook_logits[item_idx, :] = float('-inf') codebook_logits[item_idx, self.audio_eos_id] = 0.0 # Disallow generation of special tokens codebook_logits = self.clear_forbidden_logits( codebook_logits.unsqueeze(1), forbid_audio_eos=forbid_audio_eos ).squeeze(1) codebook_logits_topk = torch.topk(codebook_logits, topk, dim=-1)[0] # (B, topk) indices_to_remove = codebook_logits < codebook_logits_topk[:, -1].unsqueeze( -1 ) # (B, num_tokens_per_codebook) codebook_logits_rescored = codebook_logits.clone() codebook_logits_rescored[indices_to_remove] = float('-inf') codebook_probs = torch.softmax( codebook_logits_rescored / temperature, dim=-1 ) # (B, num_tokens_per_codebook) codebook_preds = torch.multinomial(codebook_probs, 1) # (B, 1) all_preds[fs_index].append(codebook_preds) all_preds = [ torch.cat(ds_preds, dim=1).long() for ds_preds in all_preds ] # list of `frame_stacking_factor` elements, each of shape (B, num_codebooks) all_preds = torch.stack(all_preds, dim=2) # (B, num_codebooks, frame_stacking_factor) return all_preds def log_attention_probs(self, attention_prob_matrix, audio_codes_lens, text_lens, prefix="", dec_context_size=0): # attention_prob_matrix List of (B, C, audio_timesteps, text_timesteps) wandb_images_log = {} with torch.no_grad(): attention_prob_matrix = torch.cat(attention_prob_matrix, dim=1) # (B, C, audio_timesteps, text_timesteps) attention_prob_matrix_mean = attention_prob_matrix.mean(dim=1) # (B, audio_timesteps, text_timesteps) for logger in self.loggers: is_wandb = isinstance(logger, WandbLogger) is_tb = isinstance(logger, TensorBoardLogger) if not is_wandb and not is_tb: raise ValueError( f"Invalid logger type for image logging: {type(logger)}. Only `WandbLogger` and `TensorBoardLogger` are supported." 
) wandb_images_log[f"Image/{prefix}/attention_matrix"] = list() for idx in range(min(3, attention_prob_matrix_mean.size(0))): item_attn_matrix = attention_prob_matrix_mean[idx][ dec_context_size : dec_context_size + audio_codes_lens[idx], : text_lens[idx] ] item_attn_matrix = item_attn_matrix.detach().cpu().numpy() img_np = plot_alignment_to_numpy(item_attn_matrix.T) if is_wandb: wandb_images_log[f"Image/{prefix}/attention_matrix"].append( wandb.Image(img_np, caption=f"Example_{idx}") ) if is_tb: logger.experiment.add_image( f'{prefix}/attention_matrix/Example_{idx}', img_np, global_step=self.global_step, dataformats="HWC", ) return wandb_images_log def log_val_audio_example( self, logits, target_audio_codes, audio_codes_lens_target, context_audio_codes=None, context_audio_codes_lens=None, ): wandb_audio_log = {} pred_audio_codes = self.logits_to_audio_codes(logits, audio_codes_lens_target) pred_audio, pred_audio_lens = self.codes_to_audio(pred_audio_codes, audio_codes_lens_target) target_audio, target_audio_lens = self.codes_to_audio(target_audio_codes, audio_codes_lens_target) context_audio, context_audio_lens = None, None if context_audio_codes is not None and context_audio_codes.shape[2] > 3: # > 3 ensures, it is a valid context audio tensor (and not dummy tensor used in text context) context_audio, context_audio_lens = self.codes_to_audio(context_audio_codes, context_audio_codes_lens) for logger in self.loggers: is_wandb = isinstance(logger, WandbLogger) is_tb = isinstance(logger, TensorBoardLogger) if not is_wandb and not is_tb: raise ValueError( f"Invalid logger type for audio logging: {type(logger)}. Only `WandbLogger` and `TensorBoardLogger` are supported." ) for idx in range(min(3, pred_audio.size(0))): pred_audio_np = pred_audio[idx].float().detach().cpu().numpy() target_audio_np = target_audio[idx].float().detach().cpu().numpy() pred_audio_np = pred_audio_np[: pred_audio_lens[idx]] target_audio_np = target_audio_np[: target_audio_lens[idx]] context_audio_np = None if context_audio is not None: context_audio_np = context_audio[idx].float().detach().cpu().numpy() context_audio_np = context_audio_np[: context_audio_lens[idx]] if is_wandb: wandb_audio_log[f"Audio/Example_{idx}"] = list() if context_audio_np is not None: wandb_audio_log[f"Audio/Example_{idx}"].append( wandb.Audio(context_audio_np, sample_rate=self.sample_rate, caption="context") ) wandb_audio_log[f"Audio/Example_{idx}"].append( wandb.Audio(pred_audio_np, sample_rate=self.sample_rate, caption="prediction") ) wandb_audio_log[f"Audio/Example_{idx}"].append( wandb.Audio(target_audio_np, sample_rate=self.sample_rate, caption="target") ) if is_tb: if context_audio_np is not None: logger.experiment.add_audio( f'Example_{idx}/context', context_audio_np, global_step=self.global_step, sample_rate=self.sample_rate, ) logger.experiment.add_audio( f'Example_{idx}/prediction', pred_audio_np, global_step=self.global_step, sample_rate=self.sample_rate, ) logger.experiment.add_audio( f'Example_{idx}/target', target_audio_np, global_step=self.global_step, sample_rate=self.sample_rate, ) return wandb_audio_log def scale_prior(self, prior, global_step): if prior is None: return None if global_step < self.prior_scaledown_start_step: return prior elif global_step >= self.prior_end_step: if random.random() < self.indefinite_prior_prob: print("Using Prior") return prior else: print("Not using Prior") return None else: with torch.no_grad(): # Interpolate between all ones and the prior residual = 1.0 - prior new_prior = prior + ( residual * 
(global_step - self.prior_scaledown_start_step) / (self.prior_end_step - self.prior_scaledown_start_step) ) return new_prior def embed_text(self, text, text_mask): if self.use_bpe_char_tokenizer: text_embedded = self.cas_encoder(text, subword_mask=text_mask) else: text_embedded = self.text_embedding(text) return text_embedded def compute_alignment_loss(self, attention_scores, text_lens, audio_lens, dec_context_size=0): # attention scores: List of (B, C, audio_timesteps, text_timesteps) attention_scores_combined = torch.cat(attention_scores, dim=1) # (B, C, audio_timesteps, text_timesteps) attention_scores_mean = attention_scores_combined.mean( dim=1, keepdim=True ) # (B, 1, audio_timesteps, text_timesteps) attention_scores_mean = attention_scores_mean[ :, :, dec_context_size:, : ] # Remove the context audio embeddings from the attention scores alignment_loss = self.alignment_loss( attn_logprob=attention_scores_mean, in_lens=text_lens, out_lens=audio_lens ) return alignment_loss def pad_audio_codes(self, audio_codes: torch.Tensor, frame_stacking_factor: int = 1, pad_token: int = 0): """ Pads the time dimension of the audio codes to a multiple of the frame stacking factor. Args: audio_codes (torch.Tensor): B, C, T frame_stacking_factor (int): The factor that frames will be stacked by. pad_token (int): The token ID to pad with. Returns: B, C, T_padded """ T = audio_codes.size(2) T_padded = int(np.ceil(T / frame_stacking_factor) * frame_stacking_factor) if T_padded > T: padding = pad_token * torch.ones( audio_codes.size(0), audio_codes.size(1), T_padded - T, device=audio_codes.device, dtype=audio_codes.dtype, ) audio_codes = torch.cat([audio_codes, padding], dim=2) return audio_codes def embed_context_text(self, context_text_tokens): if self.legacy_text_conditioning: context_text_tokens = ( context_text_tokens - self.tokenizer.tokenizer_offsets[self.text_conditioning_tokenizer_name] ) context_text_embedded = self.context_text_embedding(context_text_tokens) # (B, L, E) else: context_text_embedded = self.text_embedding(context_text_tokens) # (B, L, E) return context_text_embedded def prepare_context_tensors(self, batch): dec_context_size = 0 additional_decoder_input = None additional_decoder_mask = None context_audio_codes = None context_audio_codes_lens = None _attn_prior = None attn_prior = None cond = None cond_mask = None multi_encoder_mapping = None text = None text_lens = None # self.model_type must be one of [multi_encoder_context_tts, decoder_context_tts, decoder_ce] text = batch['text'] text_lens = batch['text_lens'] text_mask = get_mask_from_lengths(text_lens) # (B, T) text_embedded = self.embed_text(text, text_mask) # (B, T, E) text_encoder_out = self.encoder(text_embedded, text_mask, cond=None, cond_mask=None)['output'] # (B, T, E) _attn_prior = batch.get('align_prior_matrix', None) _attn_prior = self.scale_prior(_attn_prior, self.global_step) if self.model_type in ['multi_encoder_context_tts', 'decoder_context_tts', 'decoder_ce']: if 'context_audio_codes' in batch: context_audio_codes = batch['context_audio_codes'] context_audio_codes_lens = batch['context_audio_codes_lens'] if self._codec_converter is not None: context_audio_codes = self._codec_converter.convert_original_to_new( audio_tokens=context_audio_codes, audio_lens=context_audio_codes_lens ).long() else: context_audio_codes, context_audio_codes_lens = self.audio_to_codes( batch['context_audio'], batch['context_audio_lens'], audio_type='context' ) context_audio_codes = self.pad_audio_codes(context_audio_codes, 
self.frame_stacking_factor, pad_token=0) context_audio_embedded = self.embed_audio_tokens(context_audio_codes) # (B, T/frame_stacking_factor, E) if self.use_text_conditioning_encoder: context_text_tokens = batch['context_text_tokens'] context_text_lens = batch['context_text_tokens_lens'] context_text_embedded = self.embed_context_text(context_text_tokens) # (B, L, E) # Pad context_audio_embedded or context_text_embedded so that they have same number of timesteps if context_audio_embedded.size(1) < context_text_embedded.size(1): padding = torch.zeros( context_audio_embedded.size(0), context_text_embedded.size(1) - context_audio_embedded.size(1), context_audio_embedded.size(2), device=context_audio_embedded.device, ) context_audio_embedded = torch.cat([context_audio_embedded, padding], dim=1) elif context_audio_embedded.size(1) > context_text_embedded.size(1): padding = torch.zeros( context_text_embedded.size(0), context_audio_embedded.size(1) - context_text_embedded.size(1), context_text_embedded.size(2), device=context_text_embedded.device, ) context_text_embedded = torch.cat([context_text_embedded, padding], dim=1) # (B, T, E) has_text_context = batch['has_text_context'].unsqueeze(-1).unsqueeze(-1).float() # (B, 1, 1) context_input_embedded = ( has_text_context * context_text_embedded + (1 - has_text_context) * context_audio_embedded ) context_input_lens = ( batch['has_text_context'].float() * context_text_lens + (1 - batch['has_text_context'].float()) * context_audio_codes_lens ) # (B,) else: context_input_embedded = context_audio_embedded context_input_lens = context_audio_codes_lens context_input_lens = torch.ceil(context_input_lens / self.frame_stacking_factor).to( context_input_lens.dtype ) context_mask = get_mask_from_lengths(context_input_lens) if self.model_type == 'multi_encoder_context_tts': context_embeddings = self.context_encoder( context_input_embedded, context_mask, cond=None, cond_mask=None )['output'] cond = [text_encoder_out, context_embeddings] cond_mask = [text_mask, context_mask] multi_encoder_mapping = self.multi_encoder_mapping attn_prior = [_attn_prior, None] elif self.model_type in ['decoder_context_tts', 'decoder_ce']: context_embeddings = None # Address CodeQL if self.model_type == 'decoder_context_tts': context_embeddings = context_input_embedded elif self.model_type == 'decoder_ce': # Check for baked context embedding first if self.has_baked_context_embedding: # self.baked_context_embedding is a fixed context embedding that is baked into the model. # This is used when we do not want users to generate speech with context audio or context text. # This is done to disable zero-shot inference. Users can only generate speech in 1 voice chosen # by the model development team. 
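                        # For illustration (shapes are examples): a baked embedding stored as (T_ctx, E) is
                        # expanded below to (B, T_ctx, E); torch.Tensor.expand() returns a broadcast view, so
                        # no per-item copy of the embedding is made.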
batch_size = text.size(0) # Expand baked embedding to batch size: (T, E) -> (B, T, E) context_embeddings = self.baked_context_embedding.unsqueeze(0).expand(batch_size, -1, -1) # Create context mask from baked length context_input_lens = ( self.baked_context_embedding_len.unsqueeze(0).expand(batch_size).to(text.device) ) context_mask = get_mask_from_lengths(context_input_lens) else: context_embeddings = self.context_encoder( context_input_embedded, context_mask, cond=None, cond_mask=None )['output'] dec_context_size = context_mask.size(1) attn_prior = _attn_prior if attn_prior is not None: # B, audio_timesteps, text_timesteps padding_zeros = torch.zeros( attn_prior.size(0), dec_context_size, attn_prior.size(2), device=attn_prior.device ) attn_prior = torch.cat([padding_zeros, attn_prior], dim=1) cond = text_encoder_out cond_mask = text_mask multi_encoder_mapping = None additional_decoder_input = context_embeddings additional_decoder_mask = context_mask else: raise ValueError(f"Unsupported model type {self.model_type}") if attn_prior is not None and self.ctc_prior_layer_ids is not None: # Convert prior to a list of tensors, one for each layer # Set None for layers not in ctc_prior_layer_ids if self.model_type == 'multi_encoder_context_tts': text_attn_prior = [ attn_prior[0] if layer_idx in self.ctc_prior_layer_ids else None for layer_idx in range(self.decoder.n_layers) ] attn_prior = [text_attn_prior, attn_prior[1]] else: attn_prior = [ attn_prior if layer_idx in self.ctc_prior_layer_ids else None for layer_idx in range(self.decoder.n_layers) ] return { 'beta_binomial_attn_prior': batch.get('align_prior_matrix', None), 'text_encoder_out': text_encoder_out, 'cond': cond, 'cond_mask': cond_mask, 'attn_prior': attn_prior, 'prior_used': _attn_prior is not None, 'multi_encoder_mapping': multi_encoder_mapping, 'additional_decoder_input': additional_decoder_input, 'additional_decoder_mask': additional_decoder_mask, 'dec_context_size': dec_context_size, 'text': text, 'text_embedded': text_embedded, 'text_mask': text_mask, 'text_lens': text_lens, 'context_audio_codes': context_audio_codes, 'context_audio_codes_lens': context_audio_codes_lens, } def replace_beta_binomial_prior_with_binarized(self, attn_prior, aligner_attn_hard): # aligner_attn_hard B, audio_timesteps, text_timesteps if self.model_type == 'multi_encoder_context_tts': text_attn_prior = attn_prior[0] else: text_attn_prior = attn_prior assert text_attn_prior is not None, "Prior is None" if isinstance(text_attn_prior, list): # Layer wise prior prior_updated = False for idx, prior in enumerate(text_attn_prior): if prior is not None: text_attn_prior[idx][:, -aligner_attn_hard.size(1) :, :] = aligner_attn_hard prior_updated = True assert prior_updated, "Did not find any prior to update" else: # Same prior for all layers text_attn_prior[:, -aligner_attn_hard.size(1) :, :] = aligner_attn_hard if self.model_type == 'multi_encoder_context_tts': attn_prior[0] = text_attn_prior else: attn_prior = text_attn_prior return attn_prior def get_binarized_prior_matrix(self, aligner_attn_soft, audio_lens, text_lens): # aligner_attn_soft B, 1, audio_timesteps, text_timesteps if self.binarize_attn_method == 'nemo_binarize': logging.debug("Binarizing attention using nemo_binarize") binarize_repeat_audio_factor = self.binarize_repeat_audio_factor aligner_attn_soft_repeated = aligner_attn_soft.repeat_interleave( binarize_repeat_audio_factor, dim=2 ) # B, 1, 2*audio_timesteps, text_timesteps aligner_attn_hard = binarize_attention_parallel( 
aligner_attn_soft_repeated, text_lens, audio_lens * binarize_repeat_audio_factor ).squeeze( 1 ) # B, 2*audio_timesteps, text_timesteps aligner_attn_hard = aligner_attn_hard[:, ::2, :] # B, audio_timesteps, text_timesteps elif self.binarize_attn_method == 'argmax': logging.debug("Binarizing attention using argmax") aligner_attn_hard = torch.argmax(aligner_attn_soft.squeeze(1), dim=-1) aligner_attn_hard = torch.nn.functional.one_hot( aligner_attn_hard, num_classes=aligner_attn_soft.size(-1) ).float() else: raise ValueError( f"self.binarize_attn_method '{self.binarize_attn_method}' must be one of 'nemo_binarize' or 'argmax'." ) aligner_attn_hard_wider = aligner_attn_hard + self.binarized_prior_epsilon for future_timestep in range(self.prior_future_context): decay_factor = self.prior_future_decay ** (future_timestep + 1) aligner_attn_hard_wider[:, :, future_timestep + 1 :] += ( decay_factor * aligner_attn_hard[:, :, : -(future_timestep + 1)] ) for past_timestep in range(self.prior_past_context): decay_factor = self.prior_past_decay ** (past_timestep + 1) aligner_attn_hard_wider[:, :, : -past_timestep - 1] += ( decay_factor * aligner_attn_hard[:, :, past_timestep + 1 :] ) aligner_attn_hard_wider = torch.clamp(aligner_attn_hard_wider, 0.0, 1.0) return aligner_attn_hard_wider def prepare_dummy_cond_for_cfg(self, cond, cond_mask, additional_decoder_input, additional_dec_mask): dummy_additional_decoder_input = None dummy_additional_dec_mask = None if additional_decoder_input is not None: dummy_additional_decoder_input = torch.zeros_like(additional_decoder_input) # all ones mask means dont ignore any timesteps (so that it is consistent with usual decoder mask) dummy_additional_dec_mask = torch.ones_like(additional_dec_mask) if isinstance(cond, list): # multi encoder conditioning dummy_cond = [torch.zeros_like(cond_item) for cond_item in cond] attn_prior = [None for _ in cond] dummy_mask = [] for mask_item in cond_mask: # ignore all timesteps except the first one mask = torch.zeros_like(mask_item) mask[:, 0] = 1 # Make first timestep all zeros dummy_mask.append(mask) elif isinstance(cond, torch.Tensor): # single encoder conditioning dummy_cond = torch.zeros_like(cond) dummy_mask = torch.zeros_like(cond_mask) dummy_mask[:, 0] = 1 # ignore all timesteps except the first one attn_prior = None else: raise ValueError(f"Unsupported type for cond {type(cond)}") return dummy_cond, dummy_mask, dummy_additional_decoder_input, dummy_additional_dec_mask, attn_prior def process_batch(self, batch, mode="train"): context_tensors = self.prepare_context_tensors(batch) disable_alignment_loss = False if 'audio_codes' not in batch: audio_codes, audio_codes_lens = self.audio_to_codes(batch['audio'], batch['audio_lens']) else: audio_codes = batch['audio_codes'] audio_codes_lens = batch['audio_codes_lens'] if self._codec_converter: audio_codes = self._codec_converter.convert_original_to_new( audio_tokens=audio_codes, audio_lens=audio_codes_lens ).long() if self.frame_stacking_factor > 1: # repeat the BOS token to frame_stacking_factor times. This is necessary since at inference # we need to start autoregressive generation from a full stack indicating BOS. 
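            # For illustration (example values): with frame_stacking_factor=3 and codes [BOS, c1, c2, ...],
            # two extra BOS tokens are prepended so the first stacked frame is [BOS, BOS, BOS]; inference can
            # then start autoregressive generation from a stack made entirely of BOS tokens.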
# TODO: @rfejgin: this assert might be slow due to GPU/CPU sync assert (audio_codes[:, :, 0] == self.audio_bos_id).all(), "Audio codes do not start with BOS token" audio_codes = torch.cat( [ torch.full( (audio_codes.size(0), audio_codes.size(1), self.frame_stacking_factor - 1), self.audio_bos_id, device=audio_codes.device, dtype=audio_codes.dtype, ), audio_codes, ], dim=2, ) audio_codes_lens += self.frame_stacking_factor - 1 # account for BOS repeat audio_codes = self.pad_audio_codes(audio_codes, self.frame_stacking_factor, pad_token=0) # Note: if a tensor lacks the `_unstacked` suffix, it can be assumed to be in the frame-stacked domain # drop last (stacked) frame since it is not part of *input* audio_codes_input_unstacked = audio_codes[:, :, : -self.frame_stacking_factor] # B, C, T' # drop first (stacked) frame which contains BOS token(s) which are not part of *target* audio_codes_target_unstacked = audio_codes[:, :, self.frame_stacking_factor :] audio_codes_lens_input_unstacked = audio_codes_lens - 1 # don't count EOS for input audio_codes_lens_target_unstacked = audio_codes_lens - self.frame_stacking_factor # don't count BOS for target audio_codes_lens_input = torch.floor(audio_codes_lens_input_unstacked / self.frame_stacking_factor).long() audio_codes_embedded_all = self.embed_audio_tokens( audio_codes ) # (B, T, E) # Computing this to be used in the alignment encoder audio_codes_embedded = audio_codes_embedded_all[ :, :-1, : ] # (B, T', E) Input to the decoder; this is already in the frame-stacked domain, hence the -1 (not `frame_stacking_factor`) audio_codes_mask = get_mask_from_lengths(audio_codes_lens_input) use_cfg = (self.cfg_unconditional_prob > 0.0) and (mode == "train") and (context_tensors['cond'] is not None) if use_cfg and torch.rand(1).item() < self.cfg_unconditional_prob: cond, cond_mask, additional_decoder_input, additional_decoder_mask, attn_prior = ( self.prepare_dummy_cond_for_cfg( context_tensors['cond'], context_tensors['cond_mask'], context_tensors['additional_decoder_input'], context_tensors['additional_decoder_mask'], ) ) disable_alignment_loss = True else: cond = context_tensors['cond'] cond_mask = context_tensors['cond_mask'] additional_decoder_input = context_tensors['additional_decoder_input'] additional_decoder_mask = context_tensors['additional_decoder_mask'] attn_prior = context_tensors['attn_prior'] if mode == "train" and self.decoder_input_dropout_prob > 0.0 and torch.rand(1).item() < 0.5: # For some batches (half of them), replace decoder_input_dropout_prob of the timesteps with random tokens max_codebook_val = self.dec_random_input_max # @pneekhara: Keeping dec_random_input_max configurable since num_all_tokens_per_codebook usually has padding tokens # which can cause errors when doing codes_to_audio for audio_codes_input. We are not currently calling codes_to_audio on # audio_codes_input so should not matter if we don't supply dec_random_input_max.
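            # Sketch of the corruption applied below (fractions are approximate): dec_dropout_mask has shape
            # (1, 1, T'), so the same ~decoder_input_dropout_prob fraction of timesteps is replaced with random
            # tokens for every batch item and every codebook, e.g. prob=0.2 corrupts roughly 20% of frames.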
random_audio_tokens = torch.randint( 0, max_codebook_val, audio_codes_input_unstacked.size(), device=audio_codes_input_unstacked.device ) random_audio_tokens = random_audio_tokens * audio_codes_mask.unsqueeze(1) dec_dropout_mask = ( torch.rand((1, 1, audio_codes_input_unstacked.size(2)), device=audio_codes_input_unstacked.device) > self.decoder_input_dropout_prob ) # timestep_mask is True for timesteps to be kept audio_codes_input_unstacked = audio_codes_input_unstacked * dec_dropout_mask + random_audio_tokens * ( ~dec_dropout_mask ) audio_codes_embedded = self.embed_audio_tokens(audio_codes_input_unstacked) # (B, T', E) if context_tensors['additional_decoder_input'] is not None: dec_input_embedded = torch.cat([additional_decoder_input, audio_codes_embedded], dim=1) dec_input_mask = torch.cat([additional_decoder_mask, audio_codes_mask], dim=1) else: dec_input_embedded = audio_codes_embedded dec_input_mask = audio_codes_mask aligner_encoder_loss = None aligner_attn_soft = None aligner_attn_hard = None if self.use_alignment_encoder and not disable_alignment_loss: aligner_prior = None if self.use_prior_for_aligner: aligner_prior = context_tensors['beta_binomial_attn_prior'] # Passing target audio embeddings to the alignment encoder if self.global_step < self.aligner_encoder_train_steps: aligner_attn_soft, aligner_attn_logprobs = self.alignment_encoder( queries=audio_codes_embedded_all[:, 1:, :].permute(0, 2, 1), # B, E, T' keys=context_tensors['text_encoder_out'].permute(0, 2, 1), # B, E, T mask=~context_tensors['text_mask'].unsqueeze(-1), attn_prior=aligner_prior, ) aligner_encoder_loss = self.alignment_encoder_loss( attn_logprob=aligner_attn_logprobs, in_lens=context_tensors['text_lens'], out_lens=audio_codes_lens_input, ) else: with torch.no_grad(): # Just get the attention matrix without computing the loss or gradients aligner_attn_soft, aligner_attn_logprobs = self.alignment_encoder( queries=audio_codes_embedded_all[:, 1:, :].permute(0, 2, 1), # B, E, T' keys=context_tensors['text_encoder_out'].permute(0, 2, 1), # B, E, T mask=~context_tensors['text_mask'].unsqueeze(-1), attn_prior=aligner_prior, ) with torch.no_grad(): aligner_attn_hard = self.get_binarized_prior_matrix( aligner_attn_soft, audio_codes_lens_input, context_tensors['text_lens'] ) if (self.global_step > self.binarize_prior_after_step) and context_tensors['prior_used']: attn_prior = self.replace_beta_binomial_prior_with_binarized(attn_prior, aligner_attn_hard) logits, attn_info, dec_out = self.forward( dec_input_embedded=dec_input_embedded, dec_input_mask=dec_input_mask, cond=cond, cond_mask=cond_mask, attn_prior=attn_prior, multi_encoder_mapping=context_tensors['multi_encoder_mapping'], ) # logits: (B, T', num_codebooks * num_tokens_per_codebook) # dec_out: (B, T', E) dec_context_size = context_tensors['dec_context_size'] logits = logits[:, dec_context_size:, :] # Remove the context audio embeddings from the logits # Codebook loss (parallel) codebook_loss, loss_mask = self.compute_loss( logits, audio_codes_target_unstacked, audio_codes_lens_target_unstacked, frame_stacking_factor=self.frame_stacking_factor, ) # Alignment loss alignment_loss = None if self.alignment_loss_scale > 0.0 and not disable_alignment_loss: text_lens = context_tensors['text_lens'] cross_attention_scores = [ attn['cross_attn_probabilities'][1] for layer_idx, attn in enumerate(attn_info) if layer_idx in self.ctc_prior_layer_ids ] alignment_loss = self.compute_alignment_loss( cross_attention_scores, text_lens, audio_codes_lens_input, dec_context_size ) 
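            # The total loss below combines the parallel codebook prediction loss (scaled by codebook_loss_scale)
            # with the forward-sum alignment loss computed from the cross-attention scores of the decoder layers
            # listed in ctc_prior_layer_ids.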
loss = self.codebook_loss_scale * codebook_loss + alignment_loss else: loss = self.codebook_loss_scale * codebook_loss # Local Transformer loss local_transformer_loss = None local_transformer_logits = None if self.local_transformer_type != LocalTransformerType.NO_LT: if self.local_transformer_type == LocalTransformerType.MASKGIT: # Maskgit # randomly replace some positions with MASK_TOKEN audio_codes_masked, mask_tokens_mask = self.maskgit_apply_random_mask(audio_codes_target_unstacked) # TODO @rfejgin: the very last position might be padding but the local transformer might look at it as part of # of a pair where the first position is valid. Is this an issue? local_transformer_logits = self.compute_local_transformer_logits( dec_out[:, dec_context_size:, :], audio_codes_masked, targets_offset_by_one=True ) local_transformer_loss, _ = self.compute_loss( local_transformer_logits, audio_codes_target_unstacked, audio_codes_lens_target_unstacked, mask_tokens_mask, frame_stacking_factor=self.frame_stacking_factor, ) else: # Autoregressive assert self.local_transformer_type == LocalTransformerType.AR, "Unexpected local transformer type" local_transformer_logits = self.compute_local_transformer_logits( dec_out[:, dec_context_size:, :], audio_codes_target_unstacked, targets_offset_by_one=False ) local_transformer_loss, _ = self.compute_loss( local_transformer_logits, audio_codes_target_unstacked, audio_codes_lens_target_unstacked, None, frame_stacking_factor=self.frame_stacking_factor, ) loss = loss + self.local_transformer_loss_scale * local_transformer_loss if aligner_encoder_loss is not None: loss = loss + aligner_encoder_loss return { 'logits': logits, 'attn_info': attn_info, 'loss': loss, 'codebook_loss': codebook_loss, 'local_transformer_loss': local_transformer_loss, 'local_transformer_logits': local_transformer_logits, 'loss_mask': loss_mask, 'alignment_loss': alignment_loss, 'aligner_encoder_loss': aligner_encoder_loss, 'audio_codes_target': audio_codes_target_unstacked, 'audio_codes_lens_target': audio_codes_lens_target_unstacked, 'text': context_tensors['text'], 'text_lens': context_tensors['text_lens'], 'context_audio_codes': context_tensors['context_audio_codes'], 'context_audio_codes_lens': context_tensors['context_audio_codes_lens'], 'dec_context_size': dec_context_size, 'aligner_attn_soft': aligner_attn_soft, 'aligner_attn_hard': aligner_attn_hard, } def training_step(self, batch, batch_idx): batch_output = self.process_batch(batch) loss = batch_output['loss'] codebook_loss = batch_output['codebook_loss'] self.log('train/codebook_loss', codebook_loss, prog_bar=True, sync_dist=True) if self.cfg_unconditional_prob == 0.0: # Only log alignment loss when not using cfg to avoid sync issues when # alignment loss is None on some ranks alignment_loss = batch_output['alignment_loss'] if alignment_loss is not None: self.log('train/alignment_loss', alignment_loss, prog_bar=True, sync_dist=True) self.log('train/loss', loss, prog_bar=True, sync_dist=True) local_transformer_loss = batch_output['local_transformer_loss'] if local_transformer_loss is not None: self.log('train/local_transformer_loss', local_transformer_loss, prog_bar=True, sync_dist=True) # Log batch info batch_size, text_token_max_len = batch["text"].shape text_token_total_num = batch["text_lens"].sum() batch_info_dict = { "train/batch_size": batch_size, "train/text_token_max_len": text_token_max_len, "train/text_token_total_num_in_batch": text_token_total_num.item(), "train/text_token_pad_ratio_percent_in_batch": 100 * (1 - 
text_token_total_num / (batch_size * text_token_max_len)), } if "audio_codes" in batch: audio_codes_max_len = batch["audio_codes"].shape[-1] audio_codes_total_num = batch["audio_codes_lens"].sum() batch_info_dict.update( { "train/audio_codes_max_len": audio_codes_max_len, "train/audio_codes_total_num_in_batch": audio_codes_total_num.item(), "train/audio_codes_pad_ratio_percent_in_batch": 100 * (1 - audio_codes_total_num / (batch_size * audio_codes_max_len)), } ) else: audio_samples_max_len = batch["audio"].shape[-1] audio_samples_total_num = batch["audio_lens"].sum() batch_info_dict.update( { "train/audio_samples_max_len": audio_samples_max_len, "train/audio_samples_total_num_in_batch": audio_samples_total_num.item(), "train/audio_samples_pad_ratio_percent_in_batch": 100 * (1 - audio_samples_total_num / (batch_size * audio_samples_max_len)), } ) self.log_dict(batch_info_dict, on_step=True) return loss def validation_step(self, batch, batch_idx): batch_output = self.process_batch(batch, mode="val") # self.process_batch returns a dict. We currently only log "logits" which come from the parallel prediction # head. If we use local_transformer, then the local_transformer returns "local_transformer_logits" loss = batch_output['loss'] codebook_loss = batch_output['codebook_loss'] alignment_loss = batch_output['alignment_loss'] aligner_encoder_loss = batch_output['aligner_encoder_loss'] logits = batch_output['logits'] audio_codes_target = batch_output['audio_codes_target'] audio_codes_lens_target = batch_output['audio_codes_lens_target'] context_audio_codes = batch_output['context_audio_codes'] context_audio_codes_lens = batch_output['context_audio_codes_lens'] attn_info = batch_output['attn_info'] text_lens = batch_output['text_lens'] dec_context_size = batch_output['dec_context_size'] if alignment_loss is None: alignment_loss = torch.tensor(0.0, device=loss.device) if aligner_encoder_loss is None: aligner_encoder_loss = torch.tensor(0.0, device=loss.device) if batch_idx == 0 and self.global_rank == 0: # Prepare dictionary for aggregated wandb logging wandb_log_dict = {} # Get audio data for logging wandb_log_dict.update( self.log_val_audio_example( logits, audio_codes_target, audio_codes_lens_target, context_audio_codes, context_audio_codes_lens ) ) # Get attention image data for logging if len(attn_info[self.transcript_decoder_layers[0]]['cross_attn_probabilities']) > 1: # cross_attn_probabilities only returned when not using flash attention cross_attention_probs = [ attn['cross_attn_probabilities'][0] for layer_idx, attn in enumerate(attn_info) if layer_idx in self.ctc_prior_layer_ids ] wandb_log_dict.update( self.log_attention_probs( cross_attention_probs, audio_codes_lens_target, text_lens, prefix="val", dec_context_size=dec_context_size, ) ) for layer_idx in self.transcript_decoder_layers: cross_attention_probs = [attn_info[layer_idx]['cross_attn_probabilities'][0]] wandb_log_dict.update( self.log_attention_probs( cross_attention_probs, audio_codes_lens_target, text_lens, prefix=f"val/layer_{layer_idx}", dec_context_size=dec_context_size, ) ) if batch_output['aligner_attn_soft'] is not None: wandb_log_dict.update( self.log_attention_probs( [batch_output['aligner_attn_soft']], audio_codes_lens_target, text_lens, prefix="val/aligner_encoder_attn", ) ) if batch_output['aligner_attn_hard'] is not None: wandb_log_dict.update( self.log_attention_probs( [batch_output['aligner_attn_hard'].unsqueeze(1)], audio_codes_lens_target, text_lens, prefix="val/aligner_encoder_attn_hard", ) ) # Perform single 
wandb log call if wandb is active and there is data for logger in self.loggers: if isinstance(logger, WandbLogger) and wandb_log_dict: logger.experiment.log(wandb_log_dict) local_transformer_loss = batch_output['local_transformer_loss'] val_output = { 'val_loss': loss, 'val_codebook_loss': codebook_loss, 'val_alignment_loss': alignment_loss, 'val_local_transformer_loss': local_transformer_loss, 'val_aligner_encoder_loss': aligner_encoder_loss, } self.validation_step_outputs.append(val_output) return val_output def get_cross_attention_scores(self, attn_probs, filter_layers=None): """ Returns the cross attention probabilities for the last audio timestep """ mean_cross_attn_scores = [] all_heads_cross_attn_scores = [] for lidx, layerwise_attn_prob in enumerate(attn_probs): if (filter_layers is not None and lidx not in filter_layers) or ( lidx not in self.transcript_decoder_layers ): continue cross_attn_prob = layerwise_attn_prob['cross_attn_probabilities'][ 0 ] # B, H, audio_timesteps, text_timesteps mean_cross_attn_scores.append(cross_attn_prob.mean(dim=1)) # B, audio_timesteps, text_timesteps for head_idx in range(cross_attn_prob.size(1)): all_heads_cross_attn_scores.append(cross_attn_prob[:, head_idx, -1, :]) # B, text_timesteps mean_cross_attn_scores = torch.stack(mean_cross_attn_scores, dim=1) # B, L, audio_timesteps, text_timesteps mean_cross_attn_scores = mean_cross_attn_scores.mean(dim=1) # B, audio_timesteps, text_timesteps last_audio_timestep_scores = mean_cross_attn_scores[:, -1, :] # B, text_timesteps return last_audio_timestep_scores, all_heads_cross_attn_scores def get_most_attended_text_timestep( self, alignment_attention_scores, last_attended_timesteps, text_lens, lookahead_window_size, attended_timestep_counter, batch_size, ): """ Returns the most attended timestep for each batch item """ text_time_step_attended = [] for bidx in range(batch_size): last_attended_timestep = last_attended_timesteps[-1][bidx] if attended_timestep_counter[bidx].get(last_attended_timestep, 0) >= 8: # This is probably an attention sink! Move to the next timestep last_attended_timestep += 1 window_size = lookahead_window_size window_end = min(last_attended_timestep + window_size, text_lens[bidx] - 3) # Ignore the last 3 timesteps item_attention_scores = alignment_attention_scores[bidx, last_attended_timestep:window_end] if item_attention_scores.size(0) == 0: # This means the sentence has ended attended_timestep = text_lens[bidx].item() - 1 else: attended_timestep = item_attention_scores.argmax().item() + last_attended_timestep text_time_step_attended.append(attended_timestep) attended_timestep_counter[bidx][attended_timestep] = ( attended_timestep_counter[bidx].get(attended_timestep, 0) + 1 ) return text_time_step_attended, attended_timestep_counter def construct_inference_prior( self, prior_epsilon, cross_attention_scores, text_lens, text_time_step_attended, attended_timestep_counter, unfinished_texts, finished_texts_counter, end_indices, lookahead_window_size, batch_size, ): # Attn prior for the next timestep _attn_prior = torch.zeros(cross_attention_scores.shape[0], 1, cross_attention_scores.shape[1]) + prior_epsilon _attn_prior = _attn_prior.to(cross_attention_scores.device) for bidx in range(cross_attention_scores.shape[0]): if bidx < batch_size: _text_len = text_lens[bidx] if text_lens[bidx] <= 5: # Very short sentences, No Prior _attn_prior[bidx, 0, :] = 1.0 else: _attn_prior[bidx, 0, max(1, text_time_step_attended[bidx] - 1)] = ( 1.0 # Slight exposure to history for better pronunciation.
Not very important. ) _attn_prior[bidx, 0, text_time_step_attended[bidx]] = ( 1.0 # Slightly bias to continue moving forward. Not very important. ) for ind in range(1, lookahead_window_size + 1): _attn_prior[bidx, 0, min(text_time_step_attended[bidx] + ind, _text_len - 1)] = 1.0 # Penalize timesteps that have been attended to more than 10 times for _timestep in attended_timestep_counter[bidx]: if attended_timestep_counter[bidx][_timestep] >= 10: # This means the timestep has been attended to more than 10 times (To avoid getting stuck) _attn_prior[bidx, 0, : _timestep + 1] = prior_epsilon unfinished_texts[bidx] = False if text_time_step_attended[bidx] < text_lens[bidx] - 3: # This means the sentence has not ended if bidx not in end_indices: unfinished_texts[bidx] = True if text_time_step_attended[bidx] >= text_lens[bidx] - 2 or bidx in end_indices: if bidx not in finished_texts_counter: finished_texts_counter[bidx] = 0 for bidx in finished_texts_counter: finished_texts_counter[bidx] += 1 if finished_texts_counter[bidx] > 5: # This means we have been within the text EOS window for at least 5 timesteps # We should allow EOS to be predicted now. unfinished_texts[bidx] = False return _attn_prior, unfinished_texts, finished_texts_counter def get_inference_attention_plots( self, cross_attention_scores_all_timesteps, all_heads_cross_attn_scores_all_timesteps, text_lens, predicted_codes_lens, batch_size, compute_all_heads_attn_maps, last_attended_timestep, ): last_attended_timestep = np.array(last_attended_timestep).T cross_attention_scores_all_timesteps = torch.stack( cross_attention_scores_all_timesteps, dim=2 ) # B, text_timesteps, T' headwise_cross_attention_scores_all_timesteps = [] for hidx in range(len(all_heads_cross_attn_scores_all_timesteps[0])): head_cross_attention_all_timesteps = torch.stack( [x[hidx] for x in all_heads_cross_attn_scores_all_timesteps], dim=2 ) # B, text_timesteps, T' headwise_cross_attention_scores_all_timesteps.append(head_cross_attention_all_timesteps) cross_attention_maps = [] headwise_cross_attention_maps = [] for bidx in range(batch_size): item_cross_attention_scores = cross_attention_scores_all_timesteps[ bidx, : text_lens[bidx], : predicted_codes_lens[bidx] ] cross_attn_np = plot_alignment_to_numpy( item_cross_attention_scores.cpu().numpy(), attended=last_attended_timestep[bidx, : predicted_codes_lens[bidx]], ) cross_attention_maps.append(cross_attn_np) item_all_head_cross_attn_maps = [] if compute_all_heads_attn_maps: for hidx in range(len(all_heads_cross_attn_scores_all_timesteps[0])): item_headwise_cross_attention_scores = headwise_cross_attention_scores_all_timesteps[hidx][ bidx, : text_lens[bidx], : predicted_codes_lens[bidx] ] headwise_cross_attn_np = plot_alignment_to_numpy( item_headwise_cross_attention_scores.cpu().numpy(), attended=last_attended_timestep[bidx, : predicted_codes_lens[bidx]], ) item_all_head_cross_attn_maps.append(headwise_cross_attn_np) headwise_cross_attention_maps.append(item_all_head_cross_attn_maps) return cross_attention_maps, headwise_cross_attention_maps def find_eos_frame_index(self, codes, eos_detection_method) -> Union[int, float]: """ Checks for EOS in the predicted codes. Returns the index of the first frame within the frame stack that contains an EOS token (according to `eos_detection_method`), or `float('inf')` if no EOS is found.
Args: codes: (num_codebooks, frame_stacking_factor) Returns: index (within the frame stack) of the first frame with EOS, or `float('inf')` if no EOS is found """ eos_mask = codes == self.audio_eos_id # (codebooks, frame_stacking_factor) detection_type = EOSDetectionMethod.detection_type(eos_detection_method) if detection_type == "any": eos_per_frame = eos_mask.any( dim=0 ) # (frame_stacking_factor,) - True if any codebook has EOS in this frame elif detection_type == "all": eos_per_frame = eos_mask.all( dim=0 ) # (frame_stacking_factor,) - True if all codebooks have EOS in this frame elif detection_type == "zero_cb": eos_per_frame = eos_mask[:1, :].any( dim=0 ) # (frame_stacking_factor,) - True if zeroth codebook has EOS in this frame else: raise ValueError(f"Invalid EOS detection method: {eos_detection_method}") # find first frame with EOS if eos_per_frame.any(): # return index of the first frame with EOS return eos_per_frame.nonzero()[0].item() return float('inf') def detect_eos(self, audio_codes_multinomial, audio_codes_argmax, eos_detection_method) -> Union[int, float]: """ Detects EOS in the predicted codes. Returns the index of the first frame within the frame stack that triggers EOS detection, or `float('inf')` if no EOS is found. Args: audio_codes_multinomial: (num_codebooks, frame_stacking_factor) - Multinomial samples audio_codes_argmax: (num_codebooks, frame_stacking_factor) - Argmax samples eos_detection_method: EOS detection method Returns: index (within the frame stack) of the first frame with EOS, or `float('inf')` if no EOS is found """ sampling_type = EOSDetectionMethod.sampling_type(eos_detection_method) if sampling_type == "argmax": return self.find_eos_frame_index(audio_codes_argmax, eos_detection_method) elif sampling_type == "argmax_or_multinomial": argmax_eos_frame = self.find_eos_frame_index(audio_codes_argmax, eos_detection_method) multinomial_eos_frame = self.find_eos_frame_index(audio_codes_multinomial, eos_detection_method) return min(argmax_eos_frame, multinomial_eos_frame) else: raise ValueError(f"Invalid EOS detection method: {eos_detection_method}") def infer_batch( self, batch, max_decoder_steps=500, temperature=0.7, topk=80, use_cfg=False, cfg_scale=1.0, return_cross_attn_probs=False, apply_attention_prior=False, prior_epsilon=1e-5, lookahead_window_size=10, estimate_alignment_from_layers=None, apply_prior_to_layers=None, start_prior_after_n_audio_steps=10, compute_all_heads_attn_maps=False, use_local_transformer_for_inference=False, use_LT_kv_cache=True, maskgit_n_steps=3, maskgit_noise_scale=0.0, maskgit_fixed_schedule=None, maskgit_dynamic_cfg_scale=False, maskgit_sampling_type=None, ignore_finished_sentence_tracking=False, eos_detection_method="argmax_or_multinomial_any", # Setting this greater than 0 prevents rare cases of first-frame termination. Any number between 1 and 4 should work, but 4 # lines up with the codec's minimum frame requirement.
min_generated_frames=4, ): eos_detection_method = EOSDetectionMethod(eos_detection_method) with torch.no_grad(): start_time = time.time() self.decoder.reset_cache(use_cache=self.use_kv_cache_for_inference) context_tensors = self.prepare_context_tensors(batch) text = context_tensors['text'] audio_codes_bos = torch.full( (text.size(0), self.num_audio_codebooks, self.frame_stacking_factor), self.audio_bos_id, device=text.device, ).long() audio_codes_lens = torch.full( (text.size(0),), 1, device=text.device ).long() # intetionally 1 rather than self.frame_stacking_factor since this is in stacked form audio_codes_input = audio_codes_bos audio_codes_mask = get_mask_from_lengths(audio_codes_lens) all_predictions = [] end_indices = {} if use_cfg: dummy_cond, dummy_cond_mask, dummy_additional_decoder_input, dummy_addition_dec_mask, _ = ( self.prepare_dummy_cond_for_cfg( context_tensors['cond'], context_tensors['cond_mask'], context_tensors['additional_decoder_input'], context_tensors['additional_decoder_mask'], ) ) cross_attention_scores_all_timesteps = [] all_heads_cross_attn_scores_all_timesteps = [] _attn_prior = None unfinished_texts = {} finished_texts_counter = {} attended_timestep_counter = [{} for _ in range(text.size(0))] last_attended_timesteps = [ [1 for _ in range(text.size(0))] ] # Maintain a list of attended timesteps as we predict audio for each batch item time_to_first_prediction = 0.0 for idx in range(max_decoder_steps // self.frame_stacking_factor): if idx == 1: time_to_first_prediction = time.time() - start_time if idx % 20 == 0: print(f"Decoding timestep {idx}") audio_codes_embedded = self.embed_audio_tokens(audio_codes_input) if context_tensors['additional_decoder_input'] is not None: _audio_codes_embedded = torch.cat( [context_tensors['additional_decoder_input'], audio_codes_embedded], dim=1 ) _audio_codes_mask = torch.cat( [context_tensors['additional_decoder_mask'], audio_codes_mask], dim=1 ) else: _audio_codes_embedded = audio_codes_embedded _audio_codes_mask = audio_codes_mask if apply_prior_to_layers is not None: attn_prior = [None for _ in range(self.decoder.n_layers)] for layer_idx in apply_prior_to_layers: attn_prior[layer_idx] = _attn_prior else: attn_prior = _attn_prior if self.model_type == 'multi_encoder_context_tts': attn_prior = [attn_prior, None] if use_cfg: batch_size = audio_codes_embedded.size(0) if isinstance(context_tensors['cond'], list): cfg_cond = [ torch.cat([cond_item, dummy_cond_item], dim=0) for cond_item, dummy_cond_item in zip(context_tensors['cond'], dummy_cond) ] cfg_cond_mask = [ torch.cat([cond_mask_item, dummy_cond_mask_item], dim=0) for cond_mask_item, dummy_cond_mask_item in zip( context_tensors['cond_mask'], dummy_cond_mask ) ] else: cfg_cond = torch.cat([context_tensors['cond'], dummy_cond], dim=0) cfg_cond_mask = torch.cat([context_tensors['cond_mask'], dummy_cond_mask], dim=0) cfg_audio_codes_embedded = torch.cat([_audio_codes_embedded, _audio_codes_embedded], dim=0) cfg_audio_codes_mask = torch.cat([_audio_codes_mask, _audio_codes_mask], dim=0) if dummy_additional_decoder_input is not None: cfg_audio_codes_embedded[batch_size:, : dummy_additional_decoder_input.size(1)] = ( dummy_additional_decoder_input ) cfg_audio_codes_mask[batch_size:, : dummy_additional_decoder_input.size(1)] = ( dummy_addition_dec_mask ) # print(f"step {idx}") # print(f"use_cfg {use_cfg}") # print(f"shape {cfg_audio_codes_embedded.shape}") # print(f"use kv cahce? 
{self.use_kv_cache_for_inference}") combined_logits, attn_probs, dec_out = self.forward( dec_input_embedded=cfg_audio_codes_embedded, dec_input_mask=cfg_audio_codes_mask, cond=cfg_cond, cond_mask=cfg_cond_mask, attn_prior=attn_prior, multi_encoder_mapping=context_tensors['multi_encoder_mapping'], ) cond_logits = combined_logits[:batch_size] uncond_logits = combined_logits[batch_size:] all_code_logits = (1 - cfg_scale) * uncond_logits + cfg_scale * cond_logits else: batch_size = audio_codes_embedded.size(0) all_code_logits, attn_probs, dec_out = self.forward( dec_input_embedded=_audio_codes_embedded, dec_input_mask=_audio_codes_mask, cond=context_tensors['cond'], cond_mask=context_tensors['cond_mask'], attn_prior=attn_prior, multi_encoder_mapping=context_tensors['multi_encoder_mapping'], ) if return_cross_attn_probs or apply_attention_prior: cross_attention_scores, all_heads_cross_attn_scores = self.get_cross_attention_scores( attn_probs ) # B, text_timesteps alignment_attention_scores = cross_attention_scores if estimate_alignment_from_layers is not None: alignment_attention_scores, _ = self.get_cross_attention_scores( attn_probs, filter_layers=estimate_alignment_from_layers ) # B, text_timesteps cross_attention_scores_all_timesteps.append(cross_attention_scores) all_heads_cross_attn_scores_all_timesteps.append(all_heads_cross_attn_scores) if apply_attention_prior and idx >= start_prior_after_n_audio_steps: text_time_step_attended, attended_timestep_counter = self.get_most_attended_text_timestep( alignment_attention_scores=alignment_attention_scores, last_attended_timesteps=last_attended_timesteps, text_lens=context_tensors['text_lens'], lookahead_window_size=lookahead_window_size, attended_timestep_counter=attended_timestep_counter, batch_size=batch_size, ) last_attended_timesteps.append(text_time_step_attended) _attn_prior, unfinished_texts, finished_texts_counter = self.construct_inference_prior( prior_epsilon=prior_epsilon, cross_attention_scores=cross_attention_scores, text_lens=context_tensors['text_lens'], text_time_step_attended=text_time_step_attended, attended_timestep_counter=attended_timestep_counter, unfinished_texts=unfinished_texts, finished_texts_counter=finished_texts_counter, end_indices=end_indices, lookahead_window_size=lookahead_window_size, batch_size=batch_size, ) if ignore_finished_sentence_tracking: finished_items = {} unfinished_items = {} else: finished_items = { k: v for k, v in finished_texts_counter.items() if v >= 20 } # Items that have been close to the end for atleast 20 timesteps unfinished_items = {k: v for k, v in unfinished_texts.items() if v} # Don't allow termination until we have generated at least `min_generated_frames` frames (rounded up to the nearest multiple of frame_stacking_factor) # This guards against rare cases of termination right at the start of generation. 
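                # For example, with frame_stacking_factor=2 and min_generated_frames=4, EOS is forbidden at
                # decoder steps idx=0 and idx=1, i.e. for the first 4 unstacked codec frames.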
forbid_audio_eos = idx * self.frame_stacking_factor < min_generated_frames all_code_logits_t = all_code_logits[:, -1, :] # (B, num_codebooks * num_tokens_per_codebook) if use_local_transformer_for_inference: if self.local_transformer_type == LocalTransformerType.AR: # Autoregressive sampling with local transformer audio_codes_next = self.local_transformer_sample_autoregressive( dec_output=dec_out[:, -1, :], temperature=temperature, topk=topk, unfinished_items=unfinished_items, finished_items=finished_items, use_cfg=use_cfg, cfg_scale=cfg_scale, use_kv_cache=use_LT_kv_cache, forbid_audio_eos=forbid_audio_eos, ) elif self.local_transformer_type == LocalTransformerType.MASKGIT: audio_codes_next = self.local_transformer_sample_maskgit( dec_output=dec_out[:, -1, :], temperature=temperature, topk=topk, unfinished_items=unfinished_items, finished_items=finished_items, use_cfg=use_cfg, cfg_scale=cfg_scale, n_steps=maskgit_n_steps, noise_scale=maskgit_noise_scale, fixed_schedule=maskgit_fixed_schedule, dynamic_cfg_scale=maskgit_dynamic_cfg_scale, sampling_type=maskgit_sampling_type, forbid_audio_eos=forbid_audio_eos, ) else: raise ValueError( f"Local transformer inference requested, but the local transformer type is {self.local_transformer_type}" ) else: # Parallel sampling from all codebooks audio_codes_next = self.sample_codes_from_logits( all_code_logits_t, temperature=temperature, topk=topk, unfinished_items=unfinished_items, finished_items=finished_items, forbid_audio_eos=forbid_audio_eos, ) # (B, num_codebooks, frame_stacking_factor) all_codes_next_argmax = self.sample_codes_from_logits( all_code_logits_t, temperature=0.01, topk=1, unfinished_items=unfinished_items, finished_items=finished_items, forbid_audio_eos=forbid_audio_eos, ) # (B, num_codebooks, frame_stacking_factor) for item_idx in range(all_codes_next_argmax.size(0)): if item_idx not in end_indices: end_frame_index = self.detect_eos( audio_codes_next[item_idx], all_codes_next_argmax[item_idx], eos_detection_method ) if end_frame_index != float('inf'): global_index = idx * self.frame_stacking_factor + end_frame_index end_indices[item_idx] = global_index print(f"End detected for item {item_idx} at decoder timestep: {idx}") all_predictions.append(audio_codes_next) audio_codes_input = torch.cat([audio_codes_input, audio_codes_next], dim=-1) # (B, C, T') audio_codes_lens = audio_codes_lens + 1 # already in stacked form audio_codes_mask = get_mask_from_lengths(audio_codes_lens) if len(end_indices) == text.size(0) and len(all_predictions) >= 4: # Codec must be at least 4 timesteps to be decoded properly print("All ends reached") break tts_generation_time = time.time() - start_time tts_generation_time_per_frame = tts_generation_time / (len(all_predictions) * self.frame_stacking_factor) # Concatenate the list of predictions along the time dimension. Note that when frame stacking is on, # this also undoes the stacking.
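            # Shape sketch: each element of all_predictions is (B, num_codebooks, frame_stacking_factor), so
            # concatenating N decoder steps along the last dimension yields (B, num_codebooks, N * frame_stacking_factor)
            # individual codec frames.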
predicted_codes = torch.cat(all_predictions, dim=-1) # (B, num_codebooks, T') predicted_lens = [ end_indices.get(idx, max_decoder_steps) for idx in range(text.size(0)) ] # Ensure that the codec is at least of length 4 predicted_codes_lens = torch.tensor(predicted_lens, device=text.device).long() predicted_audio, predicted_audio_lens = self.codes_to_audio(predicted_codes, predicted_codes_lens) end_time = time.time() total_audio_duration_generated = ( predicted_audio_lens.max().item() * predicted_audio_lens.shape[0] ) / self.sample_rate rtf = total_audio_duration_generated / (end_time - start_time) rtf_metrics = { 'rtf': rtf, 'time_to_first_prediction': time_to_first_prediction, 'tts_generation_time': tts_generation_time, 'max_frames_generated': len(all_predictions), 'tts_generation_time_per_frame': tts_generation_time_per_frame, 'batch_size': text.size(0), } torch.cuda.empty_cache() cross_attention_maps = None headwise_cross_attention_maps = None if return_cross_attn_probs: cross_attention_maps, headwise_cross_attention_maps = self.get_inference_attention_plots( cross_attention_scores_all_timesteps, all_heads_cross_attn_scores_all_timesteps, context_tensors['text_lens'], predicted_codes_lens, text.size(0), compute_all_heads_attn_maps, last_attended_timesteps, ) return InferBatchOutput( predicted_audio=predicted_audio, predicted_audio_lens=predicted_audio_lens, predicted_codes=predicted_codes, predicted_codes_lens=predicted_codes_lens, rtf_metrics=rtf_metrics, cross_attention_maps=cross_attention_maps, headwise_cross_attention_maps=headwise_cross_attention_maps, ) def test_step(self, batch, batch_idx): with torch.no_grad(): test_dl_batch_size = self._test_dl.batch_size temperature = self.cfg.get('inference_temperature', 0.7) topk = self.cfg.get('inference_topk', 80) use_cfg = self.cfg.get('inference_use_cfg', False) cfg_scale = self.cfg.get('inference_cfg_scale', 1.0) output = self.infer_batch( batch, max_decoder_steps=self.cfg.get('max_decoder_steps', 500), temperature=temperature, topk=topk, use_cfg=use_cfg, cfg_scale=cfg_scale, ) predicted_audio = output.predicted_audio predicted_audio_lens = output.predicted_audio_lens for logger in self.loggers: is_wandb = isinstance(logger, WandbLogger) is_tb = isinstance(logger, TensorBoardLogger) if not is_wandb and not is_tb: raise ValueError( f"Invalid logger type for audio logging: {type(logger)}. Only `WandbLogger` and `TensorBoardLogger` are supported."
) for idx in range(predicted_audio.size(0)): predicted_audio_np = predicted_audio[idx].float().detach().cpu().numpy() predicted_audio_np = predicted_audio_np[: predicted_audio_lens[idx]] item_idx = batch_idx * test_dl_batch_size + idx if is_wandb: log_dict = { "test/predicted_audio": wandb.Audio( predicted_audio_np, sample_rate=self.sample_rate, caption="Predicted Audio" ), } logger.experiment.log(log_dict, step=item_idx) if is_tb: logger.experiment.add_audio( 'test/predicted_audio', predicted_audio_np, global_step=item_idx, sample_rate=self.sample_rate, ) # Save the predicted audio log_dir = logger.log_dir audio_dir = os.path.join(log_dir, 'audios') if not os.path.exists(audio_dir): os.makedirs(audio_dir) audio_path = os.path.join(audio_dir, f'predicted_audioRank{self.global_rank}_{item_idx}.wav') sf.write(audio_path, predicted_audio_np, self.sample_rate) def on_validation_epoch_end(self): collect = lambda key: torch.stack([x[key] for x in self.validation_step_outputs]).mean() val_loss = collect("val_loss") val_codebook_loss = collect("val_codebook_loss") val_alignment_loss = collect("val_alignment_loss") val_aligner_encoder_loss = collect("val_aligner_encoder_loss") # log val_loss in the same group as the other val metrics. self.log("val/loss", val_loss, prog_bar=True, sync_dist=True) # ensure val_loss is available for epoch-level checkpointing and filename generation without cluttering wandb logs. self.log( "val_loss", val_loss, prog_bar=False, sync_dist=True, on_step=False, on_epoch=True, logger=False, enable_graph=False, ) self.log("val/codebook_loss", val_codebook_loss, prog_bar=True, sync_dist=True) self.log("val/alignment_loss", val_alignment_loss, prog_bar=True, sync_dist=True) self.log("val/aligner_encoder_loss", val_aligner_encoder_loss, prog_bar=True, sync_dist=True) if self.local_transformer_type != LocalTransformerType.NO_LT: val_local_transformer_loss = collect("val_local_transformer_loss") self.log("val/local_transformer_loss", val_local_transformer_loss, prog_bar=True, sync_dist=True) self.validation_step_outputs.clear() # free memory def get_dataset(self, dataset_cfg, dataset_type): dataset = instantiate( dataset_cfg.dataset, sample_rate=self.sample_rate, bos_id=self.bos_id, eos_id=self.eos_id, audio_bos_id=self.audio_bos_id, audio_eos_id=self.audio_eos_id, context_audio_bos_id=self.context_audio_bos_id, context_audio_eos_id=self.context_audio_eos_id, num_audio_codebooks=self.data_num_audio_codebooks, codec_model_samples_per_frame=self.codec_model_samples_per_frame, prior_scaling_factor=self.cfg.prior_scaling_factor, load_cached_codes_if_available=self.cfg.load_cached_codes_if_available, dataset_type=dataset_type, # train or test used for setting phone prob to 1.0 in test dataset (worker_init_fn) use_text_conditioning_tokenizer=self.cfg.use_text_conditioning_encoder, text_conditioning_tokenizer_name=self.text_conditioning_tokenizer_name, pad_context_text_to_max_duration=self.pad_context_text_to_max_duration, context_duration_min=self.cfg.context_duration_min, context_duration_max=self.cfg.context_duration_max, text_context_remapping=self.text_context_remapping, text_context_remapping_prob=self.text_context_remapping_prob, ) dataset.load_16khz_audio = False dataset.tokenizer_config = ( self.cfg.text_tokenizers ) # This will be used in worker_init_fn for instantiating tokenizer return dataset def get_lhotse_dataloader(self, dataset_cfg, mode='train') -> torch.utils.data.DataLoader: # TODO @xueyang: better to distinguish cfg. 
self.cfg is the model cfg, while cfg here is train_ds cfg. Also # cfg is a classifier-free guidance. dataset = MagpieTTSLhotseDataset( sample_rate=self.sample_rate, volume_norm=dataset_cfg.volume_norm, codec_model_samples_per_frame=self.codec_model_samples_per_frame, audio_bos_id=self.audio_bos_id, audio_eos_id=self.audio_eos_id, context_audio_bos_id=self.context_audio_bos_id, context_audio_eos_id=self.context_audio_eos_id, num_audio_codebooks=self.data_num_audio_codebooks, prior_scaling_factor=self.cfg.prior_scaling_factor, load_cached_codes_if_available=self.cfg.load_cached_codes_if_available, dataset_type=mode, # train or test used for setting phone prob to 1.0 in test dataset (worker_init_fn) load_16khz_audio=False, pad_context_text_to_max_duration=self.pad_context_text_to_max_duration, context_duration_min=self.cfg.context_duration_min, context_duration_max=self.cfg.context_duration_max, use_text_conditioning_tokenizer=self.cfg.use_text_conditioning_encoder, text_conditioning_tokenizer_name=self.text_conditioning_tokenizer_name, tokenizer_config=self.cfg.text_tokenizers, text_context_remapping=self.text_context_remapping, text_context_remapping_prob=self.text_context_remapping_prob, ) data_loader = get_lhotse_dataloader_from_config( config=dataset_cfg.dataset, global_rank=self.global_rank, world_size=self.world_size, dataset=dataset, ) return data_loader def setup_training_data(self, dataset_cfg): if dataset_cfg.get("use_lhotse", False): # TODO @xueyang: better to distinguish cfg. self.cfg is the model cfg, while cfg here is train_ds cfg. Also # cfg is a classifier-free guidance. # specify target sampling rate the same as codec model's because lhotse config defaults 16_000. if not isinstance(dataset_cfg, DictConfig): dataset_cfg = OmegaConf.create(dataset_cfg) OmegaConf.set_struct(dataset_cfg.dataset, False) dataset_cfg.dataset.update({"sample_rate": self.sample_rate}) OmegaConf.set_struct(dataset_cfg.dataset, True) self._train_dl = self.get_lhotse_dataloader(dataset_cfg, mode='train') else: dataset = self.get_dataset(dataset_cfg, dataset_type='train') sampler = dataset.get_sampler(dataset_cfg.dataloader_params.batch_size, world_size=self.trainer.world_size) persistent_workers = True if dataset_cfg.dataloader_params.num_workers == 0: persistent_workers = False # For num workers > 0 tokenizer will be assigned in worker_init_fn (since it is not picklable) dataset.text_tokenizer = setup_tokenizers( all_tokenizers_config=self.cfg.text_tokenizers, mode='train', ) self._train_dl = torch.utils.data.DataLoader( dataset, collate_fn=dataset.collate_fn, sampler=sampler, **dataset_cfg.dataloader_params, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers, ) def _setup_test_dataloader(self, dataset_cfg) -> torch.utils.data.DataLoader: if dataset_cfg.get("use_lhotse", False): # specify target sampling rate the same as codec model's because lhotse config defaults 16_000. 
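            # OmegaConf struct-mode configs reject unknown keys, so struct mode is temporarily disabled around
            # the update below in order to inject the codec model's sample_rate into the lhotse dataset config.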
            if not isinstance(dataset_cfg, DictConfig):
                dataset_cfg = OmegaConf.create(dataset_cfg)
            OmegaConf.set_struct(dataset_cfg.dataset, False)
            dataset_cfg.dataset.update({"sample_rate": self.sample_rate})
            OmegaConf.set_struct(dataset_cfg.dataset, True)
            data_loader = self.get_lhotse_dataloader(dataset_cfg, mode='test')
        else:
            dataset = self.get_dataset(dataset_cfg, dataset_type='test')
            persistent_workers = True
            if dataset_cfg.dataloader_params.num_workers == 0:
                persistent_workers = False
                # For num_workers > 0 the tokenizer is assigned in worker_init_fn (it is not picklable);
                # with 0 workers, assign it here in the main process.
                dataset.text_tokenizer = setup_tokenizers(all_tokenizers_config=self.cfg.text_tokenizers, mode='test')
            data_loader = torch.utils.data.DataLoader(
                dataset,
                collate_fn=dataset.collate_fn,
                **dataset_cfg.dataloader_params,
                worker_init_fn=worker_init_fn,
                persistent_workers=persistent_workers,
            )
        return data_loader

    def setup_validation_data(self, dataset_cfg):
        self._validation_dl = self._setup_test_dataloader(dataset_cfg)

    def setup_test_data(self, dataset_cfg):
        self._test_dl = self._setup_test_dataloader(dataset_cfg)

    def setup_dummy_text_context_in_batch(
        self,
        batch: Dict[str, torch.Tensor],
    ) -> None:
        """Set up dummy text context tensors in the batch dictionary (modified in place)."""
        # No text context provided - set up the dummy text conditioning tensors the model expects
        dummy_context_text = "[NO TEXT CONTEXT]"
        dummy_tokens = self.tokenizer.encode(
            text=dummy_context_text, tokenizer_name=self.text_conditioning_tokenizer_name
        )
        batch['context_text_tokens'] = torch.tensor([dummy_tokens], device=self.device, dtype=torch.long)
        batch['context_text_tokens_lens'] = torch.tensor([len(dummy_tokens)], device=self.device, dtype=torch.long)
        batch['has_text_context'] = torch.tensor([False], device=self.device, dtype=torch.bool)

    def setup_dummy_audio_context_in_batch(
        self,
        batch: Dict[str, torch.Tensor],
        context_audio: Optional[torch.Tensor] = None,
        context_audio_lens: Optional[torch.Tensor] = None,
    ) -> None:
        """Set up dummy audio context tensors in the batch dictionary (modified in place)."""
        # Model has a baked context - create minimal dummy context tensors.
        # These are ignored in prepare_context_tensors when the baked embedding is used.
        dummy_context_codes = torch.zeros(1, self.num_audio_codebooks, 2, device=self.device, dtype=torch.long)
        dummy_context_codes[:, :, 0] = self.context_audio_bos_id
        dummy_context_codes[:, :, 1] = self.context_audio_eos_id
        batch['context_audio_codes'] = dummy_context_codes
        batch['context_audio_codes_lens'] = torch.tensor([2], device=self.device, dtype=torch.long)

    def do_tts(
        self,
        transcript: str,
        language: str = "en",
        apply_TN: bool = False,
        temperature: float = 0.7,
        topk: int = 80,
        max_decoder_steps: int = 500,
        use_cfg: bool = True,
        cfg_scale: float = 2.5,
    ) -> tuple:
        """
        Generate speech from a raw text transcript.

        This is a convenience method for single-utterance text-to-speech synthesis.
        For batch processing, use `infer_batch` directly.

        Context injection is limited to the baked context embedding; audio conditioning and text
        conditioning are NOT supported, so custom voice generation is not possible with this method.

        Args:
            transcript: Raw text to synthesize.
            language: Language code for text normalization and tokenization.
                Supported values depend on the model's tokenizer configuration.
                Common: "en" (English), "de" (German), "es" (Spanish), etc.
            apply_TN: Whether to apply text normalization to the transcript.
                If True, uses nemo_text_processing for normalization; if that package is not
                installed, a warning is logged and normalization is skipped.
            temperature: Sampling temperature for token generation.
            topk: Top-k sampling parameter.
            max_decoder_steps: Maximum number of decoder steps.
            use_cfg: Whether to use classifier-free guidance.
            cfg_scale: Scale factor for classifier-free guidance.

        Returns:
            Tuple of (audio, audio_len) where:
                audio: Generated audio waveform. Shape: (1, T_audio).
                audio_len: Length of the generated audio in samples. Shape: (1,).

        Raises:
            ValueError: If the model does not have a baked context embedding.

        Example:
            >>> # If the text does not need to be normalized
            >>> audio, audio_len = model.do_tts("Hello, how are you today?")
            >>>
            >>> # If the text needs to be normalized
            >>> audio, audio_len = model.do_tts(
            ...     "Hello, how are you today?",
            ...     apply_TN=True,
            ... )
        """
        if not self.has_baked_context_embedding:
            raise ValueError(
                "Model does not have a baked context embedding. "
                "Please use a checkpoint with a baked context embedding."
            )

        # Apply text normalization if requested
        normalized_text = transcript
        if apply_TN:
            try:
                from nemo_text_processing.text_normalization.normalize import Normalizer

                normalizer = Normalizer(input_case='cased', lang=language)
                normalized_text = normalizer.normalize(transcript, verbose=False)
                logging.debug(f"Text normalization: '{transcript}' -> '{normalized_text}'")
            except ImportError:
                logging.warning(
                    "nemo_text_processing not installed. Skipping text normalization. "
                    "Install with: pip install nemo_text_processing"
                )

        # Determine the tokenizer name based on language:
        # try to find a matching tokenizer, fall back to the first available one.
        tokenizer_name = None
        available_tokenizers = list(self.tokenizer.tokenizers.keys())
        logging.debug(f"Available tokenizers: {available_tokenizers}")

        # Common mappings for tokenizer names
        language_tokenizer_map = {
            "en": ["english_phoneme", "english"],
            "de": ["german_phoneme", "german"],
            "es": ["spanish_phoneme", "spanish"],
            "fr": ["french_phoneme", "french"],
            "it": ["italian_phoneme", "italian"],
            "vi": ["vietnamese_phoneme", "vietnamese"],
            "zh": ["mandarin_phoneme", "mandarin", "chinese"],
        }

        # Find a matching tokenizer
        if language in language_tokenizer_map:
            for candidate in language_tokenizer_map[language]:
                if candidate in available_tokenizers:
                    tokenizer_name = candidate
                    break

        # Fall back to the first available tokenizer
        if tokenizer_name is None:
            tokenizer_name = available_tokenizers[0]
            logging.info(
                f"No tokenizer found for language '{language}'. "
                f"Using '{tokenizer_name}'. Available: {available_tokenizers}"
            )

        # Tokenize the transcript
        tokens = self.tokenizer.encode(text=normalized_text, tokenizer_name=tokenizer_name)
        tokens = tokens + [self.eos_id]  # Add EOS token (BOS is not used, per the dataset convention)
        text_tensor = torch.tensor([tokens], device=self.device, dtype=torch.long)
        text_lens = torch.tensor([len(tokens)], device=self.device, dtype=torch.long)

        # Create the batch dictionary
        batch = {
            'text': text_tensor,
            'text_lens': text_lens,
        }

        # Set up context in the batch
        if self.use_text_conditioning_encoder:
            self.setup_dummy_text_context_in_batch(batch)
        self.setup_dummy_audio_context_in_batch(batch)

        # Run inference
        with torch.no_grad():
            output = self.infer_batch(
                batch,
                max_decoder_steps=max_decoder_steps,
                temperature=temperature,
                topk=topk,
                use_cfg=use_cfg,
                cfg_scale=cfg_scale,
            )

        return output.predicted_audio, output.predicted_audio_lens

    @classmethod
    def list_available_models(cls) -> List[PretrainedModelInfo]:
        return []
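

# Minimal usage sketch for `do_tts`, illustrative only and not part of the model API.
# It assumes a local MagpieTTS .nemo checkpoint saved with a baked context embedding;
# the checkpoint path and output filename below are placeholders.
if __name__ == "__main__":
    magpie_model = MagpieTTSModel.restore_from("magpietts_baked_context.nemo").eval()
    if torch.cuda.is_available():
        magpie_model = magpie_model.cuda()
    # do_tts returns (audio, audio_len) with shapes (1, T_audio) and (1,)
    audio, audio_len = magpie_model.do_tts("Hello, how are you today?", apply_TN=False)
    audio_np = audio[0, : audio_len[0]].float().cpu().numpy()
    sf.write("do_tts_example.wav", audio_np, magpie_model.sample_rate)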