# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
from typing import Iterable, List, Optional

import editdistance
import librosa
import torch
from hydra.utils import instantiate
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.utilities.combined_loader import CombinedLoader
from omegaconf import DictConfig, OmegaConf

from nemo.collections.asr.losses.angularloss import AngularSoftmaxLoss
from nemo.collections.tts.data.dataset import TTSDataset
from nemo.collections.tts.modules.ssl_tts import GreedyCTCDecoder
from nemo.collections.tts.torch.tts_tokenizers import BaseTokenizer, EnglishCharsTokenizer
from nemo.core.classes import ModelPT
from nemo.core.classes.common import PretrainedModelInfo
from nemo.core.optim.lr_scheduler import WarmupPolicy
from nemo.utils import logging

class SSLDisentangler(ModelPT):
    """
    SSLDisentangler is a Conformer-based model for extracting disentangled content and speaker embeddings
    from an audio waveform. It builds on a pre-trained Conformer SSL model. To extract the linguistic
    content and speaker representations, two randomly initialized downstream heads are added on top of
    the pre-trained Conformer, and the entire setup is fine-tuned in a multi-task manner for speech
    recognition and speaker verification. These representations can be used by FastPitchModel_SSL for
    voice conversion by swapping the speaker embedding of a given source utterance with the speaker
    embedding of a target speaker.
    """
    def __init__(self, cfg: DictConfig, trainer: Trainer = None):
        super().__init__(cfg=cfg, trainer=trainer)

        self.preprocessor_disentangler = SSLDisentangler.from_config_dict(self._cfg.preprocessor)
        self.encoder = SSLDisentangler.from_config_dict(self._cfg.encoder)
        self._text_tokenizer = EnglishCharsTokenizer(add_blank_at="last")
        self._tb_logger = None

        self.downstream_nets = torch.nn.ModuleDict()
        for task in self._cfg.downstream_heads.task_names:
            if task == 'speaker_verification':
                # setting up the downstream head and loss function for the speaker verification task
                in_dim = self._cfg.encoder.d_model
                out_dim = self._cfg.downstream_heads.speaker_embed_size
                num_speakers = self._cfg.downstream_heads.num_speakers
                self.downstream_nets[task] = torch.nn.Linear(in_dim, out_dim)
                self.sv_linear = torch.nn.Linear(out_dim, num_speakers)
                self.sv_loss = AngularSoftmaxLoss(scale=30, margin=0.4)

            elif task == 'content':
                # setting up the downstream head and loss functions for the text/content recognition task
                in_dim = self._cfg.encoder.d_model
                out_dim = self._cfg.downstream_heads.content_embed_size
                num_chars = len(self._text_tokenizer.tokens)  # list of English tokens
                self.downstream_nets[task] = torch.nn.Linear(in_dim, out_dim)
                self.content_linear = torch.nn.Linear(out_dim, num_chars)
                self.ctc_loss = torch.nn.CTCLoss(blank=self._text_tokenizer.blank, zero_infinity=True)
                self.pitch_augment = self._cfg.get('pitch_augment', False)
                self.augment_ctc = self._cfg.get('augment_ctc', False)
                self.aug_loss_type = self._cfg.get('aug_loss_type', 'mse')
                self.stop_gradient = self._cfg.get('stop_gradient', False)
                assert not (
                    self.stop_gradient and self.augment_ctc
                ), "stop_gradient and augment_ctc cannot be true at the same time"
                self.mse_loss = torch.nn.MSELoss()
                self.ctc_decoder = GreedyCTCDecoder(self._text_tokenizer.tokens, self._text_tokenizer.blank)

            else:
                raise ValueError(f"{task} is not a valid task. Task must be speaker_verification or content.")

        self.automatic_optimization = False
        stft_cfg = self._cfg.preprocessor
        librosa_mel_filter = librosa.filters.mel(
            sr=stft_cfg.sample_rate, n_fft=stft_cfg.n_fft, n_mels=stft_cfg.features, fmin=0, fmax=8000
        )
        fb = torch.tensor(librosa_mel_filter, dtype=torch.float).unsqueeze(0)

        self.register_buffer("fb", fb)
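        # Sketch of how the registered filterbank can be applied (illustrative only; ``linear_spec``
        # below is a hypothetical magnitude spectrogram of shape (B, n_fft // 2 + 1, T_spec) and is
        # not computed in this class):
        #
        #   mel_spec = torch.matmul(self.fb, linear_spec)  # broadcasts to (B, n_mels, T_spec)
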
    @classmethod
    def list_available_models(cls) -> List[PretrainedModelInfo]:
        """
        This method returns a list of pre-trained models which can be instantiated directly from NVIDIA's NGC cloud.

        Returns:
            List of available pre-trained models.
        """
        results = []

        model = PretrainedModelInfo(
            pretrained_model_name="ssl_en_conformer_large",
            description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:ssl_en_conformer_large",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/ssl_en_conformer_large/versions/1.10.1/files/ssl_en_conformer_large.nemo",
        )
        results.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="ssl_en_conformer_xlarge",
            description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:ssl_en_conformer_xlarge",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/ssl_en_conformer_xlarge/versions/1.10.0/files/ssl_en_conformer_xlarge.nemo",
        )
        results.append(model)

        return results
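    # The checkpoints above are Conformer SSL encoders used to initialize this model. A sketch of
    # fetching one through the standard NeMo API (assuming the weights are held by the ASR class
    # SpeechEncDecSelfSupervisedModel, as for other NeMo SSL Conformer checkpoints):
    #
    #   from nemo.collections.asr.models import SpeechEncDecSelfSupervisedModel
    #   ssl_model = SpeechEncDecSelfSupervisedModel.from_pretrained("ssl_en_conformer_large")
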
    @property
    def tb_logger(self):
        if self._tb_logger is None:
            if self.logger is None or self.logger.experiment is None:
                return None
            tb_logger = self.logger.experiment
            if isinstance(self.logger, Iterable):
                for logger in self.logger:
                    if isinstance(logger, TensorBoardLogger):
                        tb_logger = logger.experiment
                        break
            self._tb_logger = tb_logger
        return self._tb_logger
    def __setup_dataloader_from_config(self, data_config):
        if hasattr(self, '_text_tokenizer') and isinstance(self._text_tokenizer, BaseTokenizer):
            _text_tokenizer = self._text_tokenizer
        else:
            if hasattr(self, '_text_tokenizer') and not isinstance(self._text_tokenizer, BaseTokenizer):
                logging.warning("_text_tokenizer is set but is not a BaseTokenizer. It will be reset to EnglishCharsTokenizer.")
            _text_tokenizer = self._text_tokenizer = EnglishCharsTokenizer(add_blank_at="last")
        for task in self._cfg.downstream_heads.task_names:
            if task == 'speaker_verification':
                sv_dataset = TTSDataset(
                    manifest_filepath=data_config['manifest_speaker_verification_fp'],
                    sample_rate=self._cfg.sample_rate,
                    text_tokenizer=_text_tokenizer,
                    segment_max_duration=data_config['segment_max_duration'],
                    sup_data_types=['speaker_id'],
                    sup_data_path=data_config['sup_data_path'],
                    pad_multiple=data_config.get('pad_multiple', 1),
                )
                sv_loader = torch.utils.data.DataLoader(
                    sv_dataset,
                    batch_size=data_config['batch_size_sv'],
                    collate_fn=sv_dataset.general_collate_fn,
                    shuffle=data_config['shuffle'],
                    num_workers=data_config.get('num_workers_sv', 0),
                    pin_memory=data_config.get('pin_memory', False),
                )
            elif task == 'content':
                content_dataset = TTSDataset(
                    manifest_filepath=data_config['manifest_content_fp'],
                    sample_rate=self._cfg.sample_rate,
                    text_tokenizer=_text_tokenizer,
                    min_duration=data_config['min_duration_content'],
                    max_duration=data_config['max_duration_content'],
                    pitch_augment=data_config.get('pitch_augment', False),
                    cache_pitch_augment=data_config.get('cache_pitch_augment', True),
                    sup_data_path=data_config['sup_data_path'],
                    pad_multiple=data_config.get('pad_multiple', 1),
                )
                content_loader = torch.utils.data.DataLoader(
                    content_dataset,
                    batch_size=data_config['batch_size_content'],
                    collate_fn=content_dataset.general_collate_fn,
                    shuffle=data_config['shuffle'],
                    num_workers=data_config.get('num_workers_content', 0),
                    pin_memory=data_config.get('pin_memory', False),
                )
            else:
                raise ValueError(f"{task} is not a valid task. Task must be speaker_verification or content.")

        loaders = {"sv": sv_loader, "content": content_loader}
        return loaders
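    # Sketch of the data config consumed above (key names are the ones read in this method;
    # values are illustrative):
    #
    #   train_ds:
    #     manifest_speaker_verification_fp: sv_train_manifest.json
    #     manifest_content_fp: content_train_manifest.json
    #     sup_data_path: /data/sup_data
    #     segment_max_duration: 2.0
    #     min_duration_content: 0.5
    #     max_duration_content: 16.0
    #     batch_size_sv: 32
    #     batch_size_content: 16
    #     shuffle: true
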
    def setup_training_data(self, cfg):
        self._train_dl = self.__setup_dataloader_from_config(self._cfg.train_ds)

    def setup_validation_data(self, cfg):
        self._validation_dl = CombinedLoader(self.__setup_dataloader_from_config(self._cfg.validation_ds))
    def configure_optimizers(self):
        optim_backbone_config = self._cfg.optim_backbone.copy()
        optim_downstream_config = self._cfg.optim_downstream.copy()

        OmegaConf.set_struct(optim_backbone_config, False)
        sched_backbone_config = optim_backbone_config.pop("sched", None)
        OmegaConf.set_struct(optim_backbone_config, True)

        OmegaConf.set_struct(optim_downstream_config, False)
        sched_downstream_config = optim_downstream_config.pop("sched", None)
        OmegaConf.set_struct(optim_downstream_config, True)

        optim_backbone = instantiate(optim_backbone_config, params=self.encoder.parameters())
        optim_downstream = instantiate(
            optim_downstream_config,
            params=itertools.chain(
                self.downstream_nets.parameters(),
                self.sv_linear.parameters(),
                self.content_linear.parameters(),
                self.sv_loss.parameters(),
            ),
        )

        if sched_backbone_config is not None and sched_downstream_config is not None:
            scheduler_backbone = WarmupPolicy(
                optimizer=optim_backbone,
                max_steps=None,
                min_lr=sched_backbone_config.min_lr,
                warmup_steps=sched_backbone_config.warmup_steps,
            )  # use warmup to delay the start of backbone fine-tuning
            sch1_dict = {
                'scheduler': scheduler_backbone,
                'interval': 'step',
            }
            scheduler_downstream = WarmupPolicy(
                optimizer=optim_downstream,
                max_steps=None,
                min_lr=sched_downstream_config.min_lr,
                warmup_steps=sched_downstream_config.warmup_steps,
            )
            sch2_dict = {
                'scheduler': scheduler_downstream,
                'interval': 'step',
            }
            return [optim_backbone, optim_downstream], [sch1_dict, sch2_dict]

        return [optim_backbone, optim_downstream]
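    # Example optimizer config consumed above (a sketch; actual values live in the experiment YAML).
    # ``hydra.utils.instantiate`` requires a ``_target_`` entry:
    #
    #   optim_backbone:
    #     _target_: torch.optim.Adam
    #     lr: 1.0e-5
    #     sched:
    #       min_lr: 1.0e-6
    #       warmup_steps: 1000
    #   optim_downstream:
    #     _target_: torch.optim.Adam
    #     lr: 1.0e-4
    #     sched:
    #       min_lr: 1.0e-5
    #       warmup_steps: 500
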
    def forward(self, input_signal=None, input_signal_length=None, normalize_content=True):
        processed_signal, processed_signal_length = self.preprocessor_disentangler(
            input_signal=input_signal,
            length=input_signal_length,
        )
        encoded, encoded_len = self.encoder(audio_signal=processed_signal, length=processed_signal_length)  # b,c,t

        for task in self._cfg.downstream_heads.task_names:
            if task == "speaker_verification":
                speaker_embedding = self.downstream_nets['speaker_verification'](encoded[:, :, 0])
                l2_norm = torch.norm(speaker_embedding, p=2, dim=-1, keepdim=True)
                speaker_embedding_normalized = speaker_embedding / l2_norm
                speaker_logits = self.sv_linear(speaker_embedding_normalized)

            elif task == "content":
                encoded_btc = encoded.permute(0, 2, 1)
                content_embedding = self.downstream_nets['content'](encoded_btc)
                if normalize_content:
                    l2_norm_content = torch.norm(content_embedding, p=2, dim=-1, keepdim=True)
                    content_embedding = content_embedding / l2_norm_content

                content_logits = self.content_linear(content_embedding)
                content_log_probs = content_logits.log_softmax(dim=2)
                content_log_probs = content_log_probs.permute(1, 0, 2)  # t,b,c for ctc

            else:
                raise ValueError(f"{task} is not a valid task. Task must be speaker_verification or content.")

        return (
            speaker_logits,
            speaker_embedding_normalized,
            content_embedding,
            content_log_probs,
            encoded_len,
        )
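    # Shape walkthrough for ``forward`` above (C = encoder d_model):
    #   encoded: (B, C, T). The speaker head reads only the first frame, encoded[:, :, 0] -> (B, C),
    #   and projects it to a unit-norm (B, speaker_embed_size) embedding. The content head consumes
    #   all frames as (B, T, C) -> (B, T, content_embed_size), and content_log_probs is permuted to
    #   (T, B, num_chars), the layout torch.nn.CTCLoss expects.
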
    def forward_for_export(self, input_signal=None, input_signal_length=None, normalize_content=True):
        # Same as forward right now. An earlier version of the encoder had a different forward for export.
        # This method is kept for compatibility with older evaluation/inference scripts.
        return self.forward(
            input_signal=input_signal,
            input_signal_length=input_signal_length,
            normalize_content=normalize_content,
        )
    def training_step(self, batch, batch_idx):
        loss = 0.0
        optim_backbone, optim_downstream = self.optimizers()
        schedulers = self.lr_schedulers()

        for key in batch.keys():
            if key == 'sv':
                signal = batch[key]['audio']
                signal_len = batch[key]['audio_lens']
                speaker_id = batch[key]['speaker_id']
                sv_logits, sv_emb, _, _, _ = self.forward(input_signal=signal, input_signal_length=signal_len)
                pred_speaker = torch.argmax(sv_logits, dim=1)
                sv_loss = self.sv_loss(logits=sv_logits, labels=speaker_id)
                loss += sv_loss
                if not self._cfg.combined_loss:
                    optim_backbone.zero_grad()
                    optim_downstream.zero_grad()
                    self.manual_backward(sv_loss)
                    optim_backbone.step()
                    optim_downstream.step()

                correct = pred_speaker.eq(speaker_id.data.view_as(pred_speaker)).sum().item()
                acc = (correct / len(speaker_id)) * 100
                self.log("t_sv_loss", sv_loss.item())
                self.log("t_sv_accuracy", acc)

            elif key == "content":
                content_loss = 0
                signal = batch[key]['audio']
                signal_len = batch[key]['audio_lens']
                target = batch[key]['text']  # (B, T)
                target_len = batch[key]['text_lens']
                _, _, content_embedding, content_log_probs, encoded_len = self.forward(
                    input_signal=signal, input_signal_length=signal_len
                )
                ctc_loss = self.ctc_loss(content_log_probs, target, encoded_len, target_len)
                # only add the CTC loss if it is finite (it can be inf/nan for very short utterances)
                if torch.isfinite(ctc_loss):
                    self.log("t_ctc_loss", ctc_loss.item())
                    content_loss += ctc_loss
                else:
                    logging.warning("ctc_loss is not finite")

                if self.pitch_augment:
                    augmented_signal = batch[key]['audio_shifted']
                    if self.stop_gradient:
                        with torch.no_grad():
                            _, _, content_embedding_aug, content_log_probs_aug, _ = self.forward(
                                input_signal=augmented_signal, input_signal_length=signal_len
                            )
                    else:
                        _, _, content_embedding_aug, content_log_probs_aug, _ = self.forward(
                            input_signal=augmented_signal, input_signal_length=signal_len
                        )

                    if self.aug_loss_type == "mse":
                        sim_loss = self.mse_loss(content_embedding, content_embedding_aug)
                    elif self.aug_loss_type == "cosine":
                        cosine_similarity = torch.nn.functional.cosine_similarity(
                            content_embedding, content_embedding_aug, dim=-1
                        ).mean()
                        sim_loss = 1.0 - cosine_similarity

                    content_loss += self._cfg.augment_sim_alpha * sim_loss
                    self.log("t_sim_loss", sim_loss.item())

                    if self.augment_ctc:
                        ctc_loss_aug = self.ctc_loss(content_log_probs_aug, target, encoded_len, target_len)
                        if torch.isfinite(ctc_loss_aug):
                            content_loss += ctc_loss_aug
                            self.log("t_ctc_loss_aug", ctc_loss_aug.item())
                        else:
                            logging.warning("ctc_loss_aug is not finite. Add a min duration to avoid getting here.")

                loss += content_loss
                if not self._cfg.combined_loss:
                    optim_backbone.zero_grad()
                    optim_downstream.zero_grad()
                    self.manual_backward(content_loss)
                    optim_backbone.step()
                    optim_downstream.step()

                if isinstance(content_loss, torch.Tensor):
                    self.log("t_content_loss", content_loss.item())

        if self._cfg.combined_loss:
            optim_backbone.zero_grad()
            optim_downstream.zero_grad()
            self.manual_backward(loss)
            optim_backbone.step()
            optim_downstream.step()

        if schedulers is not None:
            sch1, sch2 = schedulers
            sch1.step()
            sch2.step()

        if self.trainer.global_step % 10 == 0:
            self.log("lr_backbone", optim_backbone.param_groups[0]['lr'])
            self.log("lr_downstream", optim_downstream.param_groups[0]['lr'])
        self.log("t_loss", loss)
    def validation_step(self, batch, batch_idx):
        loss_total = 0
        for key in batch.keys():
            if key == 'sv':
                signal = batch[key]['audio']
                signal_len = batch[key]['audio_lens']
                speaker_id = batch[key]['speaker_id']
                sv_logits, sv_emb, _, _, _ = self.forward(input_signal=signal, input_signal_length=signal_len)
                pred_speaker = torch.argmax(sv_logits, dim=1)
                sv_loss = self.sv_loss(logits=sv_logits, labels=speaker_id)
                loss_total += sv_loss
                correct = pred_speaker.eq(speaker_id.data.view_as(pred_speaker)).sum().item()
                acc = (correct / len(speaker_id)) * 100
                acc_val = torch.as_tensor(acc)

            if key == 'content':
                content_loss = 0
                signal = batch[key]['audio']
                signal_len = batch[key]['audio_lens']
                target = batch[key]['text']  # (B, T)
                target_len = batch[key]['text_lens']
                _, _, content_embedding, content_log_probs, encoded_len = self.forward(
                    input_signal=signal, input_signal_length=signal_len
                )
                ctc_loss = self.ctc_loss(content_log_probs, target, encoded_len, target_len)
                # only add the CTC loss if it is finite
                if torch.isfinite(ctc_loss):
                    content_loss += ctc_loss
                else:
                    logging.warning("ctc_loss is not finite. Add a min duration to avoid getting here.")

                if self.pitch_augment:
                    augmented_signal = batch[key]['audio_shifted']
                    _, _, content_embedding_aug, content_log_probs_aug, _ = self.forward(
                        input_signal=augmented_signal, input_signal_length=signal_len
                    )
                    if self.aug_loss_type == "mse":
                        sim_loss = self.mse_loss(content_embedding, content_embedding_aug)
                    elif self.aug_loss_type == "cosine":
                        cosine_similarity = torch.nn.functional.cosine_similarity(
                            content_embedding, content_embedding_aug, dim=-1
                        ).mean()
                        sim_loss = 1.0 - cosine_similarity
                    content_loss += self._cfg.augment_sim_alpha * sim_loss

                loss_total += content_loss

                cers = []
                for _idx in range(target.shape[0]):
                    item_log_prob = content_log_probs[:, _idx, :][: encoded_len[_idx]].cpu()
                    item_target = target[_idx][: target_len[_idx]].cpu()
                    _, predicted_str = self.ctc_decoder(item_log_prob)
                    tokenizer = self._text_tokenizer
                    target_str = tokenizer.sep.join(tokenizer._id2token[t] for t in item_target.tolist())
                    ed = editdistance.eval(predicted_str, target_str)
                    if max(len(predicted_str), len(target_str)) > 0:
                        normalized_ed = (1.0 * ed) / max(len(predicted_str), len(target_str))
                    else:
                        normalized_ed = 1.0
                    cers.append(normalized_ed)

        metrics = {
            'val_loss': loss_total.cpu(),
            'sv_loss': sv_loss.cpu(),
            'ctc_loss': ctc_loss.cpu(),
            'content_loss': content_loss.cpu(),
            'accuracy_sv': acc_val.cpu(),
            'cer': torch.tensor(cers).mean().cpu(),
        }
        # Lightning 2.x no longer passes step outputs to on_validation_epoch_end, so collect them
        # in the buffer provided by ModelPT.
        self.validation_step_outputs.append(metrics)
        return metrics
    def on_validation_epoch_end(self):
        outputs = self.validation_step_outputs
        collect = lambda key: torch.stack([x[key] for x in outputs if torch.isfinite(x[key])]).mean()
        val_loss = collect("val_loss")
        val_sv_loss = collect("sv_loss")
        val_ctc_loss = collect("ctc_loss")
        val_content_loss = collect("content_loss")
        accuracy_sv = collect("accuracy_sv")
        cer = collect("cer")
        self.log("val_loss", val_loss)
        self.log("sv_loss", val_sv_loss)
        self.log("val_ctc_loss", val_ctc_loss)
        self.log("val_content_loss", val_content_loss)
        self.log("accuracy_sv", accuracy_sv)
        self.log("cer", cer)
        self.validation_step_outputs.clear()