# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
from typing import Iterable, List, Optional

import editdistance
import librosa
import torch
from hydra.utils import instantiate
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.utilities.combined_loader import CombinedLoader
from omegaconf import DictConfig, OmegaConf

from nemo.collections.asr.losses.angularloss import AngularSoftmaxLoss
from nemo.collections.tts.data.dataset import TTSDataset
from nemo.collections.tts.modules.ssl_tts import GreedyCTCDecoder
from nemo.collections.tts.torch.tts_tokenizers import BaseTokenizer, EnglishCharsTokenizer
from nemo.core.classes import ModelPT
from nemo.core.classes.common import PretrainedModelInfo
from nemo.core.optim.lr_scheduler import WarmupPolicy
from nemo.utils import logging
from nemo.utils.decorators import experimental


@experimental
class SSLDisentangler(ModelPT):
    """
    SSLDisentangler is a Conformer-based model for extracting disentangled content and speaker embeddings
    from an audio waveform. It builds on a pre-trained Conformer SSL model: two randomly initialized
    downstream heads are added on top of the encoder, and the entire setup is fine-tuned in a multi-task
    manner for speech recognition (content) and speaker verification. The resulting representations can be
    used by FastPitchModel_SSL for voice conversion, by swapping the speaker embedding of a given source
    utterance with the speaker embedding of a target speaker.
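
    A minimal inference sketch (the .nemo checkpoint path below is hypothetical; any trained
    SSLDisentangler checkpoint restorable via ModelPT.restore_from works the same way):

        import torch

        model = SSLDisentangler.restore_from("ssl_disentangler.nemo")  # hypothetical checkpoint
        model.eval()

        audio = torch.randn(1, 16000)  # one second of 16 kHz audio; replace with real samples
        audio_len = torch.tensor([audio.shape[1]])
        with torch.no_grad():
            _, speaker_emb, content_emb, _, _ = model(input_signal=audio, input_signal_length=audio_len)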
""" def __init__(self, cfg: DictConfig, trainer: Trainer = None): super().__init__(cfg=cfg, trainer=trainer) self.preprocessor_disentangler = SSLDisentangler.from_config_dict(self._cfg.preprocessor) self.encoder = SSLDisentangler.from_config_dict(self._cfg.encoder) self._text_tokenizer = EnglishCharsTokenizer(add_blank_at="last") self._tb_logger = None self.downstream_nets = torch.nn.ModuleDict() for task in self._cfg.downstream_heads.task_names: if task == 'speaker_verification': # setting up downstream heads and loss functions for speaker verification task in_dim = self._cfg.encoder.d_model out_dim = self._cfg.downstream_heads.speaker_embed_size num_speakers = self._cfg.downstream_heads.num_speakers self.downstream_nets[task] = torch.nn.Linear(in_dim, out_dim) self.sv_linear = torch.nn.Linear(out_dim, num_speakers) self.sv_loss = AngularSoftmaxLoss(scale=30, margin=0.4) elif task == 'content': # setting up downstream heads and loss functions for text/content recognition task in_dim = self._cfg.encoder.d_model out_dim = self._cfg.downstream_heads.content_embed_size num_chars = len(self._text_tokenizer.tokens) # list of english tokens self.downstream_nets[task] = torch.nn.Linear(in_dim, out_dim) self.content_linear = torch.nn.Linear(out_dim, num_chars) self.ctc_loss = torch.nn.CTCLoss(blank=self._text_tokenizer.blank, zero_infinity=True) self.pitch_augment = self._cfg.get('pitch_augment', False) self.augment_ctc = self._cfg.get('augment_ctc', False) self.aug_loss_type = self._cfg.get('aug_loss_type', 'mse') self.stop_gradient = self._cfg.get('stop_gradient', False) assert ( self.stop_gradient and self.augment_ctc ) == False, "stop_gradient and augment_ctc cannot be true at the same time" self.mse_loss = torch.nn.MSELoss() self.ctc_decoder = GreedyCTCDecoder(self._text_tokenizer.tokens, self._text_tokenizer.blank) else: raise ValueError(f"{task} is not a valid task. Task must be speaker_verification or content.") self.automatic_optimization = False stft_cfg = self._cfg.preprocessor librosa_mel_filter = librosa.filters.mel( sr=stft_cfg.sample_rate, n_fft=stft_cfg.n_fft, n_mels=stft_cfg.features, fmin=0, fmax=8000 ) fb = torch.tensor( librosa_mel_filter, dtype=torch.float, ).unsqueeze(0) self.register_buffer("fb", fb) @classmethod def list_available_models(cls) -> Optional[PretrainedModelInfo]: """ This method returns a list of pre-trained model which can be instantiated directly from NVIDIA's NGC cloud. Returns: List of available pre-trained models. 
""" results = [] model = PretrainedModelInfo( pretrained_model_name="ssl_en_conformer_large", description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:ssl_en_conformer_large", location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/ssl_en_conformer_large/versions/1.10.1/files/ssl_en_conformer_large.nemo", ) results.append(model) model = PretrainedModelInfo( pretrained_model_name="ssl_en_conformer_xlarge", description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:ssl_en_conformer_xlarge", location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/ssl_en_conformer_xlarge/versions/1.10.0/files/ssl_en_conformer_xlarge.nemo", ) results.append(model) return results @property def tb_logger(self): if self._tb_logger is None: if self.logger is None and self.logger.experiment is None: return None tb_logger = self.logger.experiment if isinstance(self.logger, Iterable): for logger in self.logger: if isinstance(logger, TensorBoardLogger): tb_logger = logger.experiment break self._tb_logger = tb_logger return self._tb_logger def __setup_dataloader_from_config(self, data_config): if hasattr(self, '_text_tokenizer') and isinstance(self._text_tokenizer, BaseTokenizer): _text_tokenizer = self._text_tokenizer else: if hasattr(self, '_text_tokenizer') and not isinstance(self._text_tokenizer, BaseTokenizer): logging.warning(f"test_tokenizer is set but not a BaseTokenizer. Will be set to EnglishCharsTokenizer") _text_tokenizer = self._text_tokenizer = EnglishCharsTokenizer(add_blank_at="last") for task in self._cfg.downstream_heads.task_names: if task == 'speaker_verification': sv_dataset = TTSDataset( manifest_filepath=data_config['manifest_speaker_verification_fp'], sample_rate=self._cfg.sample_rate, text_tokenizer=_text_tokenizer, segment_max_duration=data_config['segment_max_duration'], sup_data_types=['speaker_id'], sup_data_path=data_config['sup_data_path'], pad_multiple=data_config.get('pad_multiple', 1), ) sv_loader = torch.utils.data.DataLoader( sv_dataset, batch_size=data_config['batch_size_sv'], collate_fn=sv_dataset.general_collate_fn, shuffle=data_config['shuffle'], num_workers=data_config.get('num_workers_sv', 0), pin_memory=data_config.get('pin_memory', False), ) elif task == 'content': content_dataset = TTSDataset( manifest_filepath=data_config['manifest_content_fp'], sample_rate=self._cfg.sample_rate, text_tokenizer=_text_tokenizer, min_duration=data_config['min_duration_content'], max_duration=data_config['max_duration_content'], pitch_augment=data_config.get('pitch_augment', False), cache_pitch_augment=data_config.get('cache_pitch_augment', True), sup_data_path=data_config['sup_data_path'], pad_multiple=data_config.get('pad_multiple', 1), ) content_loader = torch.utils.data.DataLoader( content_dataset, batch_size=data_config['batch_size_content'], collate_fn=content_dataset.general_collate_fn, shuffle=data_config['shuffle'], num_workers=data_config.get('num_workers_content', 0), pin_memory=data_config.get('pin_memory', False), ) else: raise ValueError(f"{task} is not a valid task. 
        loaders = {"sv": sv_loader, "content": content_loader}

        return loaders

    def setup_training_data(self, cfg):
        self._train_dl = self.__setup_dataloader_from_config(self._cfg.train_ds)

    def setup_validation_data(self, cfg):
        self._validation_dl = CombinedLoader(self.__setup_dataloader_from_config(self._cfg.validation_ds))

    def configure_optimizers(self):
        optim_backbone_config = self._cfg.optim_backbone.copy()
        optim_downstream_config = self._cfg.optim_downstream.copy()

        OmegaConf.set_struct(optim_backbone_config, False)
        sched_backbone_config = optim_backbone_config.pop("sched", None)
        OmegaConf.set_struct(optim_backbone_config, True)

        OmegaConf.set_struct(optim_downstream_config, False)
        sched_downstream_config = optim_downstream_config.pop("sched", None)
        OmegaConf.set_struct(optim_downstream_config, True)

        optim_backbone = instantiate(
            optim_backbone_config,
            params=self.encoder.parameters(),
        )
        optim_downstream = instantiate(
            optim_downstream_config,
            params=itertools.chain(
                self.downstream_nets.parameters(),
                self.sv_linear.parameters(),
                self.content_linear.parameters(),
                self.sv_loss.parameters(),
            ),
        )

        if sched_backbone_config is not None and sched_downstream_config is not None:
            # use warmup to delay start
            scheduler_backbone = WarmupPolicy(
                optimizer=optim_backbone,
                max_steps=None,
                min_lr=sched_backbone_config.min_lr,
                warmup_steps=sched_backbone_config.warmup_steps,
            )
            sch1_dict = {
                'scheduler': scheduler_backbone,
                'interval': 'step',
            }

            scheduler_downstream = WarmupPolicy(
                optimizer=optim_downstream,
                max_steps=None,
                min_lr=sched_downstream_config.min_lr,
                warmup_steps=sched_downstream_config.warmup_steps,
            )
            sch2_dict = {
                'scheduler': scheduler_downstream,
                'interval': 'step',
            }

            return [optim_backbone, optim_downstream], [sch1_dict, sch2_dict]
        else:
            return [optim_backbone, optim_downstream]

    def forward(self, input_signal=None, input_signal_length=None, normalize_content=True):
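        """
        Runs the preprocessor and the Conformer encoder, then both downstream heads.

        Args:
            input_signal: batch of raw waveforms, shape (B, T_samples).
            input_signal_length: number of valid samples per waveform, shape (B,).
            normalize_content: if True, L2-normalize the content embeddings.

        Returns:
            Tuple of (speaker_logits, speaker_embedding_normalized, content_embedding,
            content_log_probs, encoded_len); content_log_probs is permuted to (T, B, C) for CTC.
        """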
        processed_signal, processed_signal_length = self.preprocessor_disentangler(
            input_signal=input_signal,
            length=input_signal_length,
        )
        encoded, encoded_len = self.encoder(audio_signal=processed_signal, length=processed_signal_length)  # b,c,t

        for task in self._cfg.downstream_heads.task_names:
            if task == "speaker_verification":
                # the first time step of the encoder output is used as the utterance-level representation
                speaker_embedding = self.downstream_nets['speaker_verification'](encoded[:, :, 0])
                l2_norm = torch.norm(speaker_embedding, p=2, dim=-1, keepdim=True)
                speaker_embedding_normalized = speaker_embedding / l2_norm
                speaker_logits = self.sv_linear(speaker_embedding_normalized)

            elif task == "content":
                encoded_btc = encoded.permute(0, 2, 1)
                content_embedding = self.downstream_nets['content'](encoded_btc)
                if normalize_content:
                    l2_norm_content = torch.norm(content_embedding, p=2, dim=-1, keepdim=True)
                    content_embedding = content_embedding / l2_norm_content

                content_logits = self.content_linear(content_embedding)
                content_log_probs = content_logits.log_softmax(dim=2)
                content_log_probs = content_log_probs.permute(1, 0, 2)  # t,b,c for ctc

            else:
                raise ValueError(f"{task} is not a valid task. Task must be speaker_verification or content.")

        return (
            speaker_logits,
            speaker_embedding_normalized,
            content_embedding,
            content_log_probs,
            encoded_len,
        )

    def forward_for_export(self, input_signal=None, input_signal_length=None, normalize_content=True):
        # Same as forward right now. An earlier version of the encoder had a different forward for export.
        # This function is kept for compatibility with older evaluation/inference scripts.
        return self.forward(
            input_signal=input_signal,
            input_signal_length=input_signal_length,
            normalize_content=normalize_content,
        )

    def training_step(self, batch, batch_idx):
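        """
        Manual optimization step (automatic_optimization is False) with one optimizer for the
        Conformer backbone and one for the downstream heads. When cfg.combined_loss is set, the
        speaker-verification and content losses are summed and backpropagated together; otherwise
        each task loss gets its own backward pass and optimizer step.
        """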
        loss = 0.0
        optim_backbone, optim_downstream = self.optimizers()
        schedulers = self.lr_schedulers()

        for key in batch.keys():
            if key == 'sv':
                signal = batch[key]['audio']
                signal_len = batch[key]['audio_lens']
                speaker_id = batch[key]['speaker_id']
                sv_logits, sv_emb, _, _, _ = self.forward(input_signal=signal, input_signal_length=signal_len)

                pred_speaker = torch.argmax(sv_logits, dim=1)
                sv_loss = self.sv_loss(logits=sv_logits, labels=speaker_id)
                loss += sv_loss

                if not self._cfg.combined_loss:
                    optim_backbone.zero_grad()
                    optim_downstream.zero_grad()
                    self.manual_backward(sv_loss)
                    optim_backbone.step()
                    optim_downstream.step()

                correct = pred_speaker.eq(speaker_id.data.view_as(pred_speaker)).sum().item()
                acc = (correct / len(speaker_id)) * 100

                self.log("t_sv_loss", sv_loss.item())
                self.log("t_sv_accuracy", acc)

            elif key == "content":
                content_loss = 0
                signal = batch[key]['audio']
                signal_len = batch[key]['audio_lens']
                target = batch[key]['text']  # (B, T)
                target_len = batch[key]['text_lens']

                _, _, content_embedding, content_log_probs, encoded_len = self.forward(
                    input_signal=signal, input_signal_length=signal_len
                )

                ctc_loss = self.ctc_loss(content_log_probs, target, encoded_len, target_len)

                # guard against a non-finite CTC loss
                if torch.isfinite(ctc_loss):
                    self.log("t_ctc_loss", ctc_loss.item())
                    content_loss += ctc_loss
                else:
                    logging.warning("ctc_loss is not finite")

                if self.pitch_augment:
                    augmented_signal = batch[key]['audio_shifted']
                    if self.stop_gradient:
                        # do not backpropagate the similarity loss through the augmented branch
                        with torch.no_grad():
                            _, _, content_embedding_aug, content_log_probs_aug, _ = self.forward(
                                input_signal=augmented_signal, input_signal_length=signal_len
                            )
                    else:
                        _, _, content_embedding_aug, content_log_probs_aug, _ = self.forward(
                            input_signal=augmented_signal, input_signal_length=signal_len
                        )

                    if self.aug_loss_type == "mse":
                        sim_loss = self.mse_loss(content_embedding, content_embedding_aug)
                    elif self.aug_loss_type == "cosine":
                        cosine_similarity = torch.nn.functional.cosine_similarity(
                            content_embedding, content_embedding_aug, dim=-1
                        ).mean()
                        sim_loss = 1.0 - cosine_similarity

                    content_loss += self._cfg.augment_sim_alpha * sim_loss
                    self.log("t_sim_loss", sim_loss.item())

                    if self.augment_ctc:
                        ctc_loss_aug = self.ctc_loss(content_log_probs_aug, target, encoded_len, target_len)
                        if torch.isfinite(ctc_loss_aug):
                            content_loss += ctc_loss_aug
                            self.log("t_ctc_loss_aug", ctc_loss_aug.item())
                        else:
                            logging.warning("ctc_loss_aug is not finite. Add a min duration to avoid getting here.")

                loss += content_loss

                if not self._cfg.combined_loss:
                    optim_backbone.zero_grad()
                    optim_downstream.zero_grad()
                    self.manual_backward(content_loss)
                    optim_backbone.step()
                    optim_downstream.step()

                if isinstance(content_loss, torch.Tensor):
                    self.log("t_content_loss", content_loss.item())

        if self._cfg.combined_loss:
            optim_backbone.zero_grad()
            optim_downstream.zero_grad()
            self.manual_backward(loss)
            optim_backbone.step()
            optim_downstream.step()

        if schedulers is not None:
            sch1, sch2 = schedulers
            sch1.step()
            sch2.step()

        if self.trainer.global_step % 10 == 0:
            self.log("lr_backbone", optim_backbone.param_groups[0]['lr'])
            self.log("lr_downstream", optim_downstream.param_groups[0]['lr'])

        self.log("t_loss", loss)

    def validation_step(self, batch, batch_idx):
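        """Runs both tasks on a validation batch and collects losses, speaker-verification accuracy, and CER."""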
        loss_total = 0
        for key in batch.keys():
            if key == 'sv':
                signal = batch[key]['audio']
                signal_len = batch[key]['audio_lens']
                speaker_id = batch[key]['speaker_id']
                sv_logits, sv_emb, _, _, _ = self.forward(input_signal=signal, input_signal_length=signal_len)

                pred_speaker = torch.argmax(sv_logits, dim=1)
                sv_loss = self.sv_loss(logits=sv_logits, labels=speaker_id)
                loss_total += sv_loss

                correct = pred_speaker.eq(speaker_id.data.view_as(pred_speaker)).sum().item()
                acc = (correct / len(speaker_id)) * 100
                acc_val = torch.as_tensor(acc)

            if key == 'content':
                content_loss = 0
                signal = batch[key]['audio']
                signal_len = batch[key]['audio_lens']
                target = batch[key]['text']  # (B, T)
                target_len = batch[key]['text_lens']

                _, _, content_embedding, content_log_probs, encoded_len = self.forward(
                    input_signal=signal, input_signal_length=signal_len
                )

                ctc_loss = self.ctc_loss(content_log_probs, target, encoded_len, target_len)

                # guard against a non-finite CTC loss
                if torch.isfinite(ctc_loss):
                    content_loss += ctc_loss
                else:
                    logging.warning("ctc_loss is not finite. Add a min duration to avoid getting here.")

                if self.pitch_augment:
                    augmented_signal = batch[key]['audio_shifted']
                    _, _, content_embedding_aug, content_log_probs_aug, _ = self.forward(
                        input_signal=augmented_signal, input_signal_length=signal_len
                    )
                    if self.aug_loss_type == "mse":
                        sim_loss = self.mse_loss(content_embedding, content_embedding_aug)
                    elif self.aug_loss_type == "cosine":
                        cosine_similarity = torch.nn.functional.cosine_similarity(
                            content_embedding, content_embedding_aug, dim=-1
                        ).mean()
                        sim_loss = 1.0 - cosine_similarity
                    content_loss += self._cfg.augment_sim_alpha * sim_loss

                loss_total += content_loss

                # character error rate (normalized edit distance) from greedy CTC decoding
                cers = []
                for _idx in range(target.shape[0]):
                    item_log_prob = content_log_probs[:, _idx, :][: encoded_len[_idx]].cpu()
                    item_target = target[_idx][: target_len[_idx]].cpu()
                    _, predicted_str = self.ctc_decoder(item_log_prob)
                    tokenizer = self._text_tokenizer
                    target_str = tokenizer.sep.join(tokenizer._id2token[t] for t in item_target.tolist())
                    ed = editdistance.eval(predicted_str, target_str)
                    if max(len(predicted_str), len(target_str)) > 0:
                        normalized_ed = (1.0 * ed) / max(len(predicted_str), len(target_str))
                    else:
                        normalized_ed = 1.0
                    cers.append(normalized_ed)

        metrics = {
            'val_loss': loss_total.cpu(),
            'sv_loss': sv_loss.cpu(),
            'ctc_loss': ctc_loss.cpu(),
            'content_loss': content_loss.cpu(),
            'accuracy_sv': acc_val.cpu(),
            'cer': torch.tensor(cers).mean().cpu(),
        }
        self.validation_step_outputs.append(metrics)
        return metrics

    def on_validation_epoch_end(self):
        # With lightning.pytorch 2.x this hook receives no outputs argument, so per-step
        # metrics are accumulated in self.validation_step_outputs by validation_step.
        outputs = self.validation_step_outputs
        collect = lambda key: torch.stack([x[key] for x in outputs if torch.isfinite(x[key])]).mean()
        val_loss = collect("val_loss")
        val_sv_loss = collect("sv_loss")
        val_ctc_loss = collect("ctc_loss")
        val_content_loss = collect("content_loss")
        accuracy_sv = collect("accuracy_sv")
        cer = collect("cer")
        self.log("val_loss", val_loss)
        self.log("sv_loss", val_sv_loss)
        self.log("val_ctc_loss", val_ctc_loss)
        self.log("val_content_loss", val_content_loss)
        self.log("accuracy_sv", accuracy_sv)
        self.log("cer", cer)
        self.validation_step_outputs.clear()