Spaces:
Runtime error
Runtime error
| # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| from typing import Dict, Optional | |
| import torch | |
| from nemo.core.classes import NeuralModule, typecheck | |
| from nemo.core.neural_types import NeuralType, SpectrogramType | |
class MixtureConsistencyProjection(NeuralModule):
    """Ensure estimated sources are consistent with the input mixture.

    The residual between the mixture and the sum of the estimated sources is
    distributed back onto the sources:

        output = estimate + weight * (mixture - sum_over_sources(estimate))

    With uniform weighting the projected sources sum exactly to the mixture;
    with power weighting they sum to it up to the `eps` regularization.

    Note that the input mixture is assumed to be a single-channel signal.

    Args:
        weighting: Optional weighting mode for the consistency constraint.
            If `None`, use uniform weighting (each of the M sources receives
            1/M of the residual). If `power`, use the power of the estimated
            source, normalized across sources, as the weight.
        eps: Small positive value for regularization of the power-weight
            normalization.

    Raises:
        NotImplementedError: If `weighting` is neither `None` nor `'power'`.

    Reference:
        Wisdom et al, Differentiable consistency constraints for improved deep speech enhancement, 2018
    """

    def __init__(self, weighting: Optional[str] = None, eps: float = 1e-8):
        super().__init__()
        self.weighting = weighting
        self.eps = eps

        # Fail fast at construction rather than on the first forward pass.
        if self.weighting not in [None, 'power']:
            raise NotImplementedError(f'Weighting mode {self.weighting} not implemented')

    # Exposed as a property: NeMo's Typing interface reads `input_types` as an
    # attribute, so a plain method would hand it a bound method instead of a dict.
    @property
    def input_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module input ports."""
        return {
            "mixture": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
            "estimate": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
        }

    # Exposed as a property for the same reason as `input_types`.
    @property
    def output_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module output ports."""
        return {
            "output": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
        }

    # NOTE(review): `typecheck` is imported by this file but not applied here.
    # NeMo modules commonly decorate `forward` with `@typecheck()`, which also
    # makes the call keyword-only — confirm all callers pass kwargs before enabling.
    def forward(self, mixture: torch.Tensor, estimate: torch.Tensor) -> torch.Tensor:
        """Enforce mixture consistency on the estimated sources.

        Args:
            mixture: Single-channel mixture, shape (B, 1, F, N)
            estimate: M estimated sources, shape (B, M, F, N)

        Returns:
            Source estimates consistent with the mixture, shape (B, M, F, N)

        Raises:
            ValueError: If `mixture` does not have exactly one channel (dim -3).
            NotImplementedError: If `self.weighting` is not a supported mode.
        """
        if mixture.size(-3) != 1:
            raise ValueError(f'Mixture must have a single channel, got shape {mixture.shape}')

        # number of sources
        M = estimate.size(-3)

        # estimated mixture based on the estimated sources
        estimated_mixture = torch.sum(estimate, dim=-3, keepdim=True)

        # weighting
        if self.weighting is None:
            # uniform: every source absorbs an equal share of the residual
            weight = 1 / M
        elif self.weighting == 'power':
            # power: |estimate|^2 normalized across sources; eps guards against
            # division by zero when all sources are silent at a bin
            weight = estimate.abs().pow(2)
            weight = weight / (weight.sum(dim=-3, keepdim=True) + self.eps)
        else:
            # only reachable if `self.weighting` was modified after construction
            raise NotImplementedError(f'Weighting mode {self.weighting} not implemented')

        # consistent estimate: add the weighted residual to each source
        consistent_estimate = estimate + weight * (mixture - estimated_mixture)

        return consistent_estimate