# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
from typing import List, Optional

import numpy as np
import omegaconf
import torch
import transformers
import wandb
from hydra.utils import instantiate
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import WandbLogger
from omegaconf import DictConfig
from torch import nn
from torch.nn import functional as F
from transformers import AlbertTokenizer

from nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers import (
    EnglishCharsTokenizer,
    EnglishPhonemesTokenizer,
)
from nemo.collections.tts.losses.aligner_loss import BinLoss, ForwardSumLoss
from nemo.collections.tts.models.base import SpectrogramGenerator
from nemo.collections.tts.modules.fastpitch import average_features, regulate_len
from nemo.collections.tts.parts.utils.helpers import (
    binarize_attention_parallel,
    g2p_backward_compatible_support,
    get_mask_from_lengths,
    plot_pitch_to_numpy,
    plot_spectrogram_to_numpy,
)
from nemo.core import Exportable
from nemo.core.classes.common import PretrainedModelInfo, typecheck
from nemo.core.neural_types.elements import (
    LengthsType,
    LogprobsType,
    MelSpectrogramType,
    ProbsType,
    RegressionValuesType,
    TokenDurationType,
    TokenIndex,
    TokenLogDurationType,
)
from nemo.core.neural_types.neural_type import NeuralType
from nemo.utils import logging, model_utils

class MixerTTSModel(SpectrogramGenerator, Exportable):
    """Mixer-TTS and Mixer-TTS-X models (https://arxiv.org/abs/2110.03584), used to generate a mel spectrogram from text."""

    def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):
        # Convert to Hydra 1.0 compatible DictConfig
        cfg = model_utils.convert_model_config_to_dict_config(cfg)
        cfg = model_utils.maybe_update_config_version(cfg)
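        # Note: the normalizer and tokenizer are built before super().__init__() below,
        # since the base ModelPT constructor may set up dataloaders that need them.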
        # Setup normalizer
        self.normalizer = None
        self.text_normalizer_call = None
        self.text_normalizer_call_kwargs = {}
        self._setup_normalizer(cfg)

        # Setup tokenizer
        self.tokenizer = None
        self._setup_tokenizer(cfg)
        assert self.tokenizer is not None

        num_tokens = len(self.tokenizer.tokens)
        self.tokenizer_pad = self.tokenizer.pad
        self.tokenizer_unk = self.tokenizer.oov

        super().__init__(cfg=cfg, trainer=trainer)

        self.pitch_loss_scale = cfg.pitch_loss_scale
        self.durs_loss_scale = cfg.durs_loss_scale
        self.mel_loss_scale = cfg.mel_loss_scale

        self.aligner = instantiate(cfg.alignment_module)
        self.forward_sum_loss = ForwardSumLoss()
        self.bin_loss = BinLoss()
        self.add_bin_loss = False
        self.bin_loss_scale = 0.0
        self.bin_loss_start_ratio = cfg.bin_loss_start_ratio
        self.bin_loss_warmup_epochs = cfg.bin_loss_warmup_epochs

        self.cond_on_lm_embeddings = cfg.get("cond_on_lm_embeddings", False)

        if self.cond_on_lm_embeddings:
            self.lm_padding_value = (
                self._train_dl.dataset.lm_padding_value
                if self._train_dl is not None
                else self._get_lm_padding_value(cfg.lm_model)
            )
            self.lm_embeddings = self._get_lm_embeddings(cfg.lm_model)
            self.lm_embeddings.weight.requires_grad = False

            self.self_attention_module = instantiate(
                cfg.self_attention_module, n_lm_tokens_channels=self.lm_embeddings.weight.shape[1]
            )

        self.encoder = instantiate(cfg.encoder, num_tokens=num_tokens, padding_idx=self.tokenizer_pad)
        self.symbol_emb = self.encoder.to_embed

        self.duration_predictor = instantiate(cfg.duration_predictor)

        self.pitch_mean, self.pitch_std = float(cfg.pitch_mean), float(cfg.pitch_std)
        self.pitch_predictor = instantiate(cfg.pitch_predictor)
        self.pitch_emb = instantiate(cfg.pitch_emb)

        self.preprocessor = instantiate(cfg.preprocessor)
        self.decoder = instantiate(cfg.decoder)
        self.proj = nn.Linear(self.decoder.d_model, cfg.n_mel_channels)

    def _setup_tokenizer(self, cfg):
        text_tokenizer_kwargs = {}
        if "g2p" in cfg.text_tokenizer:
            # for backward compatibility
            if (
                self._is_model_being_restored()
                and (cfg.text_tokenizer.g2p.get('_target_', None) is not None)
                and cfg.text_tokenizer.g2p["_target_"].startswith("nemo_text_processing.g2p")
            ):
                cfg.text_tokenizer.g2p["_target_"] = g2p_backward_compatible_support(
                    cfg.text_tokenizer.g2p["_target_"]
                )

            g2p_kwargs = {}

            if "phoneme_dict" in cfg.text_tokenizer.g2p:
                g2p_kwargs["phoneme_dict"] = self.register_artifact(
                    'text_tokenizer.g2p.phoneme_dict',
                    cfg.text_tokenizer.g2p.phoneme_dict,
                )

            if "heteronyms" in cfg.text_tokenizer.g2p:
                g2p_kwargs["heteronyms"] = self.register_artifact(
                    'text_tokenizer.g2p.heteronyms',
                    cfg.text_tokenizer.g2p.heteronyms,
                )

            text_tokenizer_kwargs["g2p"] = instantiate(cfg.text_tokenizer.g2p, **g2p_kwargs)

        self.tokenizer = instantiate(cfg.text_tokenizer, **text_tokenizer_kwargs)

    def _get_lm_model_tokenizer(self, lm_model="albert"):
        if getattr(self, "_lm_model_tokenizer", None) is not None:
            return self._lm_model_tokenizer

        # Prefer the tokenizer already attached to the training dataset, if any;
        # otherwise fall back to loading it from the pretrained LM.
        if self._train_dl is not None and self._train_dl.dataset is not None:
            self._lm_model_tokenizer = self._train_dl.dataset.lm_model_tokenizer
        elif lm_model == "albert":
            self._lm_model_tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
        else:
            raise NotImplementedError(
                f"{lm_model} lm model is not supported. Only albert is supported at this moment."
            )
        return self._lm_model_tokenizer

    def _get_lm_embeddings(self, lm_model="albert"):
        if lm_model == "albert":
            return transformers.AlbertModel.from_pretrained('albert-base-v2').embeddings.word_embeddings
        else:
            raise NotImplementedError(
                f"{lm_model} lm model is not supported. Only albert is supported at this moment."
            )

    def _get_lm_padding_value(self, lm_model="albert"):
        if lm_model == "albert":
            return transformers.AlbertTokenizer.from_pretrained('albert-base-v2')._convert_token_to_id('<pad>')
        else:
            raise NotImplementedError(
                f"{lm_model} lm model is not supported. Only albert is supported at this moment."
            )

    def _metrics(
        self,
        true_durs,
        true_text_len,
        pred_durs,
        true_pitch,
        pred_pitch,
        true_spect=None,
        pred_spect=None,
        true_spect_len=None,
        attn_logprob=None,
        attn_soft=None,
        attn_hard=None,
        attn_hard_dur=None,
    ):
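        # The total loss assembled below is:
        #   loss = durs_loss_scale * durs_loss + mel_loss_scale * mel_loss
        #          + ctc_loss + bin_loss_scale * bin_loss (once enabled)
        #          + pitch_loss_scale * pitch_loss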
        text_mask = get_mask_from_lengths(true_text_len)
        mel_mask = get_mask_from_lengths(true_spect_len)
        loss = 0.0

        # Dur loss and metrics
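        # Durations are regressed in log space as log(d + 1); the +1 keeps the target
        # finite for zero-duration tokens and is undone below via exp() - 1.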
        durs_loss = F.mse_loss(pred_durs, (true_durs + 1).float().log(), reduction='none')
        durs_loss = durs_loss * text_mask.float()
        durs_loss = durs_loss.sum() / text_mask.sum()

        durs_pred = pred_durs.exp() - 1
        durs_pred = torch.clamp_min(durs_pred, min=0)
        durs_pred = durs_pred.round().long()

        acc = ((true_durs == durs_pred) * text_mask).sum().float() / text_mask.sum() * 100
        acc_dist_1 = (((true_durs - durs_pred).abs() <= 1) * text_mask).sum().float() / text_mask.sum() * 100
        acc_dist_3 = (((true_durs - durs_pred).abs() <= 3) * text_mask).sum().float() / text_mask.sum() * 100

        pred_spect = pred_spect.transpose(1, 2)

        # Mel loss
        mel_loss = F.mse_loss(pred_spect, true_spect, reduction='none').mean(dim=-2)
        mel_loss = mel_loss * mel_mask.float()
        mel_loss = mel_loss.sum() / mel_mask.sum()

        loss = loss + self.durs_loss_scale * durs_loss + self.mel_loss_scale * mel_loss

        # Aligner loss
        bin_loss, ctc_loss = None, None
        ctc_loss = self.forward_sum_loss(attn_logprob=attn_logprob, in_lens=true_text_len, out_lens=true_spect_len)
        loss = loss + ctc_loss
        if self.add_bin_loss:
            bin_loss = self.bin_loss(hard_attention=attn_hard, soft_attention=attn_soft)
            loss = loss + self.bin_loss_scale * bin_loss
        true_avg_pitch = average_features(true_pitch.unsqueeze(1), attn_hard_dur).squeeze(1)

        # Pitch loss
        pitch_loss = F.mse_loss(pred_pitch, true_avg_pitch, reduction='none')  # noqa
        pitch_loss = (pitch_loss * text_mask).sum() / text_mask.sum()

        loss = loss + self.pitch_loss_scale * pitch_loss

        return loss, durs_loss, acc, acc_dist_1, acc_dist_3, pitch_loss, mel_loss, ctc_loss, bin_loss

    def run_aligner(self, text, text_len, text_mask, spect, spect_len, attn_prior):
        text_emb = self.symbol_emb(text)
        attn_soft, attn_logprob = self.aligner(
            spect,
            text_emb.permute(0, 2, 1),
            mask=text_mask == 0,
            attn_prior=attn_prior,
        )
        attn_hard = binarize_attention_parallel(attn_soft, text_len, spect_len)
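        # Collapse the spectrogram axis of the hard (0/1) attention to obtain an
        # integer duration per input token; these must sum to the spectrogram length.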
        attn_hard_dur = attn_hard.sum(2)[:, 0, :]
        assert torch.all(torch.eq(attn_hard_dur.sum(dim=1), spect_len))
        return attn_soft, attn_logprob, attn_hard, attn_hard_dur

    def forward(self, text, text_len, pitch=None, spect=None, spect_len=None, attn_prior=None, lm_tokens=None):
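        # Training-time forward pass: encode text, align it to the mel frames, predict
        # log-durations and token-averaged pitch, then decode a mel spectrogram using
        # the aligner's (ground-truth) durations for length regulation.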
        if self.training:
            assert pitch is not None

        text_mask = get_mask_from_lengths(text_len).unsqueeze(2)

        enc_out, enc_mask = self.encoder(text, text_mask)

        # Aligner
        attn_soft, attn_logprob, attn_hard, attn_hard_dur = None, None, None, None
        if spect is not None:
            attn_soft, attn_logprob, attn_hard, attn_hard_dur = self.run_aligner(
                text, text_len, text_mask, spect, spect_len, attn_prior
            )

        if self.cond_on_lm_embeddings:
            lm_emb = self.lm_embeddings(lm_tokens)
            lm_features = self.self_attention_module(
                enc_out, lm_emb, lm_emb, q_mask=enc_mask.squeeze(2), kv_mask=lm_tokens != self.lm_padding_value
            )

        # Duration predictor
        log_durs_predicted = self.duration_predictor(enc_out, enc_mask)
        durs_predicted = torch.clamp(log_durs_predicted.exp() - 1, 0)

        # Pitch predictor
        pitch_predicted = self.pitch_predictor(enc_out, enc_mask)

        # Avg pitch, add pitch_emb
        if not self.training:
            if pitch is not None:
                pitch = average_features(pitch.unsqueeze(1), attn_hard_dur).squeeze(1)
                pitch_emb = self.pitch_emb(pitch.unsqueeze(1))
            else:
                pitch_emb = self.pitch_emb(pitch_predicted.unsqueeze(1))
        else:
            pitch = average_features(pitch.unsqueeze(1), attn_hard_dur).squeeze(1)
            pitch_emb = self.pitch_emb(pitch.unsqueeze(1))

        enc_out = enc_out + pitch_emb.transpose(1, 2)

        if self.cond_on_lm_embeddings:
            enc_out = enc_out + lm_features

        # Regulate length
        len_regulated_enc_out, dec_lens = regulate_len(attn_hard_dur, enc_out)

        dec_out, dec_lens = self.decoder(len_regulated_enc_out, get_mask_from_lengths(dec_lens).unsqueeze(2))
        pred_spect = self.proj(dec_out)

        return (
            pred_spect,
            durs_predicted,
            log_durs_predicted,
            pitch_predicted,
            attn_soft,
            attn_logprob,
            attn_hard,
            attn_hard_dur,
        )

    def infer(
        self,
        text,
        text_len=None,
        text_mask=None,
        spect=None,
        spect_len=None,
        attn_prior=None,
        use_gt_durs=False,
        lm_tokens=None,
        pitch=None,
    ):
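        # Inference-time forward pass: same pipeline as forward(), but length
        # regulation uses predicted durations unless use_gt_durs=True, in which case
        # the aligner is run against the provided ground-truth spectrogram.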
        if text_mask is None:
            text_mask = get_mask_from_lengths(text_len).unsqueeze(2)

        enc_out, enc_mask = self.encoder(text, text_mask)

        # Aligner
        attn_hard_dur = None
        if use_gt_durs:
            attn_soft, attn_logprob, attn_hard, attn_hard_dur = self.run_aligner(
                text, text_len, text_mask, spect, spect_len, attn_prior
            )

        if self.cond_on_lm_embeddings:
            lm_emb = self.lm_embeddings(lm_tokens)
            lm_features = self.self_attention_module(
                enc_out, lm_emb, lm_emb, q_mask=enc_mask.squeeze(2), kv_mask=lm_tokens != self.lm_padding_value
            )

        # Duration predictor
        log_durs_predicted = self.duration_predictor(enc_out, enc_mask)
        durs_predicted = torch.clamp(log_durs_predicted.exp() - 1, 0)

        # Avg pitch, pitch predictor
        if use_gt_durs and pitch is not None:
            pitch = average_features(pitch.unsqueeze(1), attn_hard_dur).squeeze(1)
            pitch_emb = self.pitch_emb(pitch.unsqueeze(1))
        else:
            pitch_predicted = self.pitch_predictor(enc_out, enc_mask)
            pitch_emb = self.pitch_emb(pitch_predicted.unsqueeze(1))

        # Add pitch emb
        enc_out = enc_out + pitch_emb.transpose(1, 2)

        if self.cond_on_lm_embeddings:
            enc_out = enc_out + lm_features

        if use_gt_durs:
            if attn_hard_dur is not None:
                len_regulated_enc_out, dec_lens = regulate_len(attn_hard_dur, enc_out)
            else:
                raise NotImplementedError
        else:
            len_regulated_enc_out, dec_lens = regulate_len(durs_predicted, enc_out)

        dec_out, _ = self.decoder(len_regulated_enc_out, get_mask_from_lengths(dec_lens).unsqueeze(2))
        pred_spect = self.proj(dec_out)

        return pred_spect

    def on_train_epoch_start(self):
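        # Binarization-loss schedule: switch the loss on once
        # current_epoch >= bin_loss_start_ratio * max_epochs, then ramp its weight
        # linearly from 0 to 1 over bin_loss_warmup_epochs.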
        bin_loss_start_epoch = np.ceil(self.bin_loss_start_ratio * self._trainer.max_epochs)

        # Add bin loss when current_epoch >= bin_start_epoch
        if not self.add_bin_loss and self.current_epoch >= bin_loss_start_epoch:
            logging.info(f"Using hard attentions after epoch: {self.current_epoch}")
            self.add_bin_loss = True

        if self.add_bin_loss:
            self.bin_loss_scale = min((self.current_epoch - bin_loss_start_epoch) / self.bin_loss_warmup_epochs, 1.0)

    def training_step(self, batch, batch_idx):
        attn_prior, lm_tokens = None, None
        if self.cond_on_lm_embeddings:
            audio, audio_len, text, text_len, attn_prior, pitch, _, lm_tokens = batch
        else:
            audio, audio_len, text, text_len, attn_prior, pitch, _ = batch

        spect, spect_len = self.preprocessor(input_signal=audio, length=audio_len)
        # Pitch normalization: zero mean / unit variance, but keep unvoiced frames
        # (pitch == 0) at exactly zero so they remain distinguishable.
        zero_pitch_idx = pitch == 0
        pitch = (pitch - self.pitch_mean) / self.pitch_std
        pitch[zero_pitch_idx] = 0.0
        (
            pred_spect,
            _,
            pred_log_durs,
            pred_pitch,
            attn_soft,
            attn_logprob,
            attn_hard,
            attn_hard_dur,
        ) = self(
            text=text,
            text_len=text_len,
            pitch=pitch,
            spect=spect,
            spect_len=spect_len,
            attn_prior=attn_prior,
            lm_tokens=lm_tokens,
        )

        (
            loss,
            durs_loss,
            acc,
            acc_dist_1,
            acc_dist_3,
            pitch_loss,
            mel_loss,
            ctc_loss,
            bin_loss,
        ) = self._metrics(
            pred_durs=pred_log_durs,
            pred_pitch=pred_pitch,
            true_durs=attn_hard_dur,
            true_text_len=text_len,
            true_pitch=pitch,
            true_spect=spect,
            pred_spect=pred_spect,
            true_spect_len=spect_len,
            attn_logprob=attn_logprob,
            attn_soft=attn_soft,
            attn_hard=attn_hard,
            attn_hard_dur=attn_hard_dur,
        )

        train_log = {
            'train_loss': loss,
            'train_durs_loss': durs_loss,
            'train_pitch_loss': torch.tensor(1.0).to(durs_loss.device) if pitch_loss is None else pitch_loss,
            'train_mel_loss': mel_loss,
            'train_durs_acc': acc,
            'train_durs_acc_dist_3': acc_dist_3,
            'train_ctc_loss': torch.tensor(1.0).to(durs_loss.device) if ctc_loss is None else ctc_loss,
            'train_bin_loss': torch.tensor(1.0).to(durs_loss.device) if bin_loss is None else bin_loss,
        }
        return {'loss': loss, 'progress_bar': {k: v.detach() for k, v in train_log.items()}, 'log': train_log}

    def validation_step(self, batch, batch_idx):
        attn_prior, lm_tokens = None, None
        if self.cond_on_lm_embeddings:
            audio, audio_len, text, text_len, attn_prior, pitch, _, lm_tokens = batch
        else:
            audio, audio_len, text, text_len, attn_prior, pitch, _ = batch

        spect, spect_len = self.preprocessor(input_signal=audio, length=audio_len)
        # Pitch normalization: zero mean / unit variance, but keep unvoiced frames
        # (pitch == 0) at exactly zero so they remain distinguishable.
        zero_pitch_idx = pitch == 0
        pitch = (pitch - self.pitch_mean) / self.pitch_std
        pitch[zero_pitch_idx] = 0.0
        (
            pred_spect,
            _,
            pred_log_durs,
            pred_pitch,
            attn_soft,
            attn_logprob,
            attn_hard,
            attn_hard_dur,
        ) = self(
            text=text,
            text_len=text_len,
            pitch=pitch,
            spect=spect,
            spect_len=spect_len,
            attn_prior=attn_prior,
            lm_tokens=lm_tokens,
        )

        (
            loss,
            durs_loss,
            acc,
            acc_dist_1,
            acc_dist_3,
            pitch_loss,
            mel_loss,
            ctc_loss,
            bin_loss,
        ) = self._metrics(
            pred_durs=pred_log_durs,
            pred_pitch=pred_pitch,
            true_durs=attn_hard_dur,
            true_text_len=text_len,
            true_pitch=pitch,
            true_spect=spect,
            pred_spect=pred_spect,
            true_spect_len=spect_len,
            attn_logprob=attn_logprob,
            attn_soft=attn_soft,
            attn_hard=attn_hard,
            attn_hard_dur=attn_hard_dur,
        )
        # Second pass without ground-truth internal features (durations still come
        # from the aligner): measures the mel loss obtained with predicted pitch.
        pred_spect, _, pred_log_durs, pred_pitch, attn_soft, attn_logprob, attn_hard, attn_hard_dur = self(
            text=text,
            text_len=text_len,
            pitch=None,
            spect=spect,
            spect_len=spect_len,
            attn_prior=attn_prior,
            lm_tokens=lm_tokens,
        )

        *_, with_pred_features_mel_loss, _, _ = self._metrics(
            pred_durs=pred_log_durs,
            pred_pitch=pred_pitch,
            true_durs=attn_hard_dur,
            true_text_len=text_len,
            true_pitch=pitch,
            true_spect=spect,
            pred_spect=pred_spect,
            true_spect_len=spect_len,
            attn_logprob=attn_logprob,
            attn_soft=attn_soft,
            attn_hard=attn_hard,
            attn_hard_dur=attn_hard_dur,
        )
        val_log = {
            'val_loss': loss,
            'val_durs_loss': durs_loss,
            'val_pitch_loss': torch.tensor(1.0).to(durs_loss.device) if pitch_loss is None else pitch_loss,
            'val_mel_loss': mel_loss,
            'val_with_pred_features_mel_loss': with_pred_features_mel_loss,
            'val_durs_acc': acc,
            'val_durs_acc_dist_3': acc_dist_3,
            'val_ctc_loss': torch.tensor(1.0).to(durs_loss.device) if ctc_loss is None else ctc_loss,
            'val_bin_loss': torch.tensor(1.0).to(durs_loss.device) if bin_loss is None else bin_loss,
        }
        self.log_dict(val_log, prog_bar=False, on_epoch=True, logger=True, sync_dist=True)
        # Log sample spectrograms and pitch contours to W&B every 5 epochs
        if batch_idx == 0 and self.current_epoch % 5 == 0 and isinstance(self.logger, WandbLogger):
            specs = []
            pitches = []
            for i in range(min(3, spect.shape[0])):
                specs += [
                    wandb.Image(
                        plot_spectrogram_to_numpy(spect[i, :, : spect_len[i]].data.cpu().numpy()),
                        caption=f"gt mel {i}",
                    ),
                    wandb.Image(
                        plot_spectrogram_to_numpy(pred_spect.transpose(1, 2)[i, :, : spect_len[i]].data.cpu().numpy()),
                        caption=f"pred mel {i}",
                    ),
                ]

                pitches += [
                    wandb.Image(
                        plot_pitch_to_numpy(
                            average_features(pitch.unsqueeze(1), attn_hard_dur)
                            .squeeze(1)[i, : text_len[i]]
                            .data.cpu()
                            .numpy(),
                            ylim_range=[-2.5, 2.5],
                        ),
                        caption=f"gt pitch {i}",
                    ),
                ]

                pitches += [
                    wandb.Image(
                        plot_pitch_to_numpy(pred_pitch[i, : text_len[i]].data.cpu().numpy(), ylim_range=[-2.5, 2.5]),
                        caption=f"pred pitch {i}",
                    ),
                ]

            self.logger.experiment.log({"specs": specs, "pitches": pitches})

    def generate_spectrogram(
        self,
        tokens: Optional[torch.Tensor] = None,
        tokens_len: Optional[torch.Tensor] = None,
        lm_tokens: Optional[torch.Tensor] = None,
        raw_texts: Optional[List[str]] = None,
        norm_text_for_lm_model: bool = True,
        lm_model: str = "albert",
    ):
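        """
        Generate a mel spectrogram from pre-tokenized input or raw text strings.

        A minimal usage sketch (the checkpoint name comes from list_available_models;
        the input text is illustrative):

            model = MixerTTSModel.from_pretrained("tts_en_lj_mixertts")
            model.eval()
            spect = model.generate_spectrogram(raw_texts=["Hello world."])
        """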
        if tokens is not None:
            if tokens_len is None:
                # It is assumed that padding is consecutive and only at the end
                tokens_len = (tokens != self.tokenizer.pad).sum(dim=-1)
        else:
            if raw_texts is None:
                raise ValueError("raw_texts must be specified if tokens is None")

            t_seqs = [self.tokenizer(t) for t in raw_texts]
            tokens = torch.nn.utils.rnn.pad_sequence(
                sequences=[torch.tensor(t, dtype=torch.long, device=self.device) for t in t_seqs],
                batch_first=True,
                padding_value=self.tokenizer.pad,
            )
            tokens_len = torch.tensor([len(t) for t in t_seqs], dtype=torch.long, device=tokens.device)

        if self.cond_on_lm_embeddings and lm_tokens is None:
            if raw_texts is None:
                raise ValueError("raw_texts must be specified if lm_tokens is None")

            lm_model_tokenizer = self._get_lm_model_tokenizer(lm_model)
            lm_padding_value = lm_model_tokenizer._convert_token_to_id('<pad>')
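            # '▁' is the SentencePiece word-boundary marker in the ALBERT vocabulary;
            # its id stands in for the spaces added when the TTS tokenizer pads with space.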
            lm_space_value = lm_model_tokenizer._convert_token_to_id('▁')

            assert isinstance(self.tokenizer, EnglishCharsTokenizer) or isinstance(
                self.tokenizer, EnglishPhonemesTokenizer
            )

            if norm_text_for_lm_model and self.text_normalizer_call is not None:
                raw_texts = [self.text_normalizer_call(t, **self.text_normalizer_call_kwargs) for t in raw_texts]

            preprocess_texts_as_tts_input = [self.tokenizer.text_preprocessing_func(t) for t in raw_texts]
            lm_tokens_as_ids_list = [
                lm_model_tokenizer.encode(t, add_special_tokens=False) for t in preprocess_texts_as_tts_input
            ]

            if self.tokenizer.pad_with_space:
                lm_tokens_as_ids_list = [[lm_space_value] + t + [lm_space_value] for t in lm_tokens_as_ids_list]

            lm_tokens = torch.full(
                (len(lm_tokens_as_ids_list), max([len(t) for t in lm_tokens_as_ids_list])),
                fill_value=lm_padding_value,
                device=tokens.device,
            )
            for i, lm_tokens_i in enumerate(lm_tokens_as_ids_list):
                lm_tokens[i, : len(lm_tokens_i)] = torch.tensor(lm_tokens_i, device=tokens.device)

        pred_spect = self.infer(tokens, tokens_len, lm_tokens=lm_tokens).transpose(1, 2)
        return pred_spect

    def parse(self, text: str, normalize=True) -> torch.Tensor:
        if self.training:
            logging.warning("parse() is meant to be called in eval mode.")
        if normalize and self.text_normalizer_call is not None:
            text = self.text_normalizer_call(text, **self.text_normalizer_call_kwargs)
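        # When the tokenizer supports mixed grapheme/phoneme sampling, force pure
        # phoneme mode (phoneme probability 1.0) for deterministic evaluation.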
        eval_phon_mode = contextlib.nullcontext()
        if hasattr(self.tokenizer, "set_phone_prob"):
            eval_phon_mode = self.tokenizer.set_phone_prob(prob=1.0)

        with eval_phon_mode:
            tokens = self.tokenizer.encode(text)

        return torch.tensor(tokens).long().unsqueeze(0).to(self.device)

    def _loader(self, cfg):
        try:
            _ = cfg.dataset.manifest_filepath
        except omegaconf.errors.MissingMandatoryValue:
            logging.warning("manifest_filepath was skipped. No dataset for this model.")
            return None

        dataset = instantiate(
            cfg.dataset,
            text_normalizer=self.normalizer,
            text_normalizer_call_kwargs=self.text_normalizer_call_kwargs,
            text_tokenizer=self.tokenizer,
        )
        return torch.utils.data.DataLoader(  # noqa
            dataset=dataset,
            collate_fn=dataset.collate_fn,
            **cfg.dataloader_params,
        )

    def setup_training_data(self, cfg):
        self._train_dl = self._loader(cfg)

    def setup_validation_data(self, cfg):
        self._validation_dl = self._loader(cfg)

    def setup_test_data(self, cfg):
        """Omitted."""
        pass

    @classmethod
    def list_available_models(cls) -> 'List[PretrainedModelInfo]':
        """
        This method returns a list of pre-trained models which can be instantiated directly from NVIDIA's NGC cloud.

        Returns:
            List of available pre-trained models.
        """
        list_of_models = []
        model = PretrainedModelInfo(
            pretrained_model_name="tts_en_lj_mixertts",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/tts_en_lj_mixertts/versions/1.6.0/files/tts_en_lj_mixertts.nemo",
            description="This Mixer-TTS model is trained on LJSpeech sampled at 22050Hz and can be used to generate female English voices with an American accent.",
            class_=cls,  # noqa
        )
        list_of_models.append(model)

        model = PretrainedModelInfo(
            pretrained_model_name="tts_en_lj_mixerttsx",
            location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/tts_en_lj_mixerttsx/versions/1.6.0/files/tts_en_lj_mixerttsx.nemo",
            description="This Mixer-TTS-X model is trained on LJSpeech sampled at 22050Hz and can be used to generate female English voices with an American accent.",
            class_=cls,  # noqa
        )
        list_of_models.append(model)

        return list_of_models

    # Methods for model exportability
    @property
    def input_types(self):
        return {
            "text": NeuralType(('B', 'T_text'), TokenIndex()),
            "lm_tokens": NeuralType(('B', 'T_lm_tokens'), TokenIndex(), optional=True),
        }

    @property
    def output_types(self):
        return {
            "spect": NeuralType(('B', 'D', 'T_spec'), MelSpectrogramType()),
        }

    def input_example(self, max_text_len=10, max_lm_tokens_len=10):
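        # Build a batch of random token ids used to trace the model during export.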
        text = torch.randint(
            low=0,
            high=len(self.tokenizer.tokens),
            size=(1, max_text_len),
            device=self.device,
            dtype=torch.long,
        )

        inputs = {'text': text}

        if self.cond_on_lm_embeddings:
            inputs['lm_tokens'] = torch.randint(
                low=0,
                high=self.lm_embeddings.weight.shape[0],
                size=(1, max_lm_tokens_len),
                device=self.device,
                dtype=torch.long,
            )

        return (inputs,)

    def forward_for_export(self, text, lm_tokens=None):
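        # No explicit lengths are available at export time, so the text mask is
        # derived directly from the padding token id.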
        text_mask = (text != self.tokenizer_pad).unsqueeze(2)
        spect = self.infer(text=text, text_mask=text_mask, lm_tokens=lm_tokens).transpose(1, 2)
        return spect.to(torch.float)