# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Optional, Tuple
import torch
from nemo.collections.asr.modules.conformer_encoder import ConformerEncoder
from nemo.collections.asr.parts.preprocessing.features import make_seq_mask_like
from nemo.collections.audio.modules.features import SpectrogramToMultichannelFeatures
from nemo.collections.audio.parts.submodules.multichannel import (
ChannelAttentionPool,
ChannelAveragePool,
ParametricMultichannelWienerFilter,
TransformAttendConcatenate,
TransformAverageConcatenate,
WPEFilter,
)
from nemo.collections.audio.parts.utils.audio import db2mag
from nemo.core.classes import NeuralModule, typecheck
from nemo.core.neural_types import FloatType, LengthsType, NeuralType, SpectrogramType
from nemo.utils import logging
class MaskEstimatorRNN(NeuralModule):
"""Estimate `num_outputs` masks from the input spectrogram
using stacked RNNs and projections.
The module is structured as follows:
input --> spatial features --> input projection -->
--> stacked RNNs --> output projection for each output --> sigmoid
Reference:
Multi-microphone neural speech separation for far-field multi-talker
speech recognition (https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=8462081)
Args:
num_outputs: Number of output masks to estimate
num_subbands: Number of subbands of the input spectrogram
num_features: Number of features after the input projections
num_layers: Number of RNN layers
num_hidden_features: Number of hidden features in RNN layers
num_input_channels: Number of input channels
dropout: If non-zero, introduces dropout on the outputs of each RNN layer except the last layer, with dropout
probability equal to `dropout`. Default: 0
bidirectional: If `True`, use bidirectional RNN.
rnn_type: Type of RNN, either `lstm` or `gru`. Default: `lstm`
mag_reduction: Channel-wise reduction for magnitude features
use_ipd: Use inter-channel phase difference (IPD) features
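    Example:
        A minimal sketch of the intended usage, assuming a 4-channel complex-valued
        spectrogram with 257 subbands (all shapes and values are illustrative only):

            import torch

            estimator = MaskEstimatorRNN(num_outputs=2, num_subbands=257, num_input_channels=4)
            # complex-valued input spectrogram, shape (B, C, F, N)
            spec = torch.randn(2, 4, 257, 50, dtype=torch.cfloat)
            spec_len = torch.tensor([50, 42])
            # masks: shape (B, num_outputs, F, N), values in [0, 1]
            masks, masks_len = estimator(input=spec, input_length=spec_len)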
"""
def __init__(
self,
num_outputs: int,
num_subbands: int,
num_features: int = 1024,
num_layers: int = 3,
num_hidden_features: Optional[int] = None,
num_input_channels: Optional[int] = None,
dropout: float = 0,
        bidirectional: bool = True,
rnn_type: str = 'lstm',
mag_reduction: str = 'rms',
        use_ipd: Optional[bool] = None,
):
super().__init__()
if num_hidden_features is None:
num_hidden_features = num_features
self.features = SpectrogramToMultichannelFeatures(
num_subbands=num_subbands,
num_input_channels=num_input_channels,
mag_reduction=mag_reduction,
use_ipd=use_ipd,
)
self.input_projection = torch.nn.Linear(
in_features=self.features.num_features * self.features.num_channels, out_features=num_features
)
if rnn_type == 'lstm':
self.rnn = torch.nn.LSTM(
input_size=num_features,
hidden_size=num_hidden_features,
num_layers=num_layers,
batch_first=True,
dropout=dropout,
bidirectional=bidirectional,
)
elif rnn_type == 'gru':
self.rnn = torch.nn.GRU(
input_size=num_features,
hidden_size=num_hidden_features,
num_layers=num_layers,
batch_first=True,
dropout=dropout,
bidirectional=bidirectional,
)
else:
raise ValueError(f'Unknown rnn_type: {rnn_type}')
self.fc = torch.nn.Linear(
in_features=2 * num_features if bidirectional else num_features, out_features=num_features
)
self.norm = torch.nn.LayerNorm(num_features)
# Each output shares the RNN and has a separate projection
self.output_projections = torch.nn.ModuleList(
[torch.nn.Linear(in_features=num_features, out_features=num_subbands) for _ in range(num_outputs)]
)
self.output_nonlinearity = torch.nn.Sigmoid()
@property
def input_types(self) -> Dict[str, NeuralType]:
"""Returns definitions of module output ports."""
return {
"input": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
"input_length": NeuralType(('B',), LengthsType()),
}
@property
def output_types(self) -> Dict[str, NeuralType]:
"""Returns definitions of module output ports."""
return {
"output": NeuralType(('B', 'C', 'D', 'T'), FloatType()),
"output_length": NeuralType(('B',), LengthsType()),
}
@typecheck()
def forward(self, input: torch.Tensor, input_length: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Estimate `num_outputs` masks from the input spectrogram.
Args:
input: C-channel input, shape (B, C, F, N)
input_length: Length of valid entries along the time dimension, shape (B,)
Returns:
Returns `num_outputs` masks in a tensor, shape (B, num_outputs, F, N),
and output length with shape (B,)
"""
input, _ = self.features(input=input, input_length=input_length)
B, num_feature_channels, num_features, N = input.shape
# (B, num_feat_channels, num_feat, N) -> (B, N, num_feat_channels, num_feat)
input = input.permute(0, 3, 1, 2)
# (B, N, num_feat_channels, num_feat) -> (B, N, num_feat_channels * num_features)
input = input.view(B, N, -1)
# Apply projection on num_feat
input = self.input_projection(input)
# Apply RNN on the input sequence
input_packed = torch.nn.utils.rnn.pack_padded_sequence(
input, input_length.cpu(), batch_first=True, enforce_sorted=False
).to(input.device)
self.rnn.flatten_parameters()
input_packed, _ = self.rnn(input_packed)
output, output_length = torch.nn.utils.rnn.pad_packed_sequence(input_packed, batch_first=True)
output_length = output_length.to(input.device)
# Layer normalization and skip connection
output = self.norm(self.fc(output)) + input
# Create `num_outputs` masks
masks = []
for output_projection in self.output_projections:
# Output projection
mask = output_projection(output)
mask = self.output_nonlinearity(mask)
# Back to the original format
# (B, N, F) -> (B, F, N)
mask = mask.transpose(2, 1)
# Append to the output
masks.append(mask)
# Stack along channel dimension to get (B, M, F, N)
        masks = torch.stack(masks, dim=1)
# Mask frames beyond output length
length_mask: torch.Tensor = make_seq_mask_like(
lengths=output_length, like=masks, time_dim=-1, valid_ones=False
)
masks = masks.masked_fill(length_mask, 0.0)
return masks, output_length
class MaskEstimatorFlexChannels(NeuralModule):
"""Estimate `num_outputs` masks from the input spectrogram
using stacked channel-wise and temporal layers.
    This model uses interleaved channel blocks and temporal blocks, and
    it can process an arbitrary number of input channels.
Default channel block is the transform-average-concatenate layer.
Default temporal block is the Conformer encoder.
Reduction from multichannel signal to single-channel signal is performed
after `channel_reduction_position` blocks. Only temporal blocks are used afterwards.
After the sequence of blocks, the output mask is computed using an additional
output temporal layer and a nonlinearity.
References:
- Yoshioka et al, VarArray: Array-Geometry-Agnostic Continuous Speech Separation, 2022
- Jukić et al, Flexible multichannel speech enhancement for noise-robust frontend, 2023
Args:
num_outputs: Number of output masks.
num_subbands: Number of subbands on the input spectrogram.
num_blocks: Number of blocks in the model.
channel_reduction_position: After this block, the signal will be reduced across channels.
channel_reduction_type: Reduction across channels: 'average' or 'attention'
channel_block_type: Block for channel processing: 'transform_average_concatenate' or 'transform_attend_concatenate'
temporal_block_type: Block for temporal processing: 'conformer_encoder'
temporal_block_num_layers: Number of layers for the temporal block
temporal_block_num_heads: Number of heads for the temporal block
temporal_block_dimension: The hidden size of the model
temporal_block_self_attention_model: Self attention model for the temporal block
        temporal_block_att_context_size: Attention context size for the temporal block
        num_input_channels: Optional, number of input channels
        mag_reduction: Channel-wise reduction for magnitude features
mag_power: Power to apply on magnitude features
use_ipd: Use inter-channel phase difference (IPD) features
mag_normalization: Normalize using mean ('mean') or mean and variance ('mean_var')
ipd_normalization: Normalize using mean ('mean') or mean and variance ('mean_var')
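    Example:
        A minimal sketch of the intended usage, assuming a 257-subband input, three
        blocks, and channel reduction before the last block (all shapes and values
        are illustrative only):

            import torch

            estimator = MaskEstimatorFlexChannels(
                num_outputs=2, num_subbands=257, num_blocks=3, channel_reduction_position=2
            )
            # complex-valued input spectrogram, shape (B, C, F, T)
            spec = torch.randn(2, 4, 257, 50, dtype=torch.cfloat)
            # masks: shape (B, num_outputs, F, T)
            masks, masks_len = estimator(input=spec, input_length=torch.tensor([50, 42]))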
"""
def __init__(
self,
num_outputs: int,
num_subbands: int,
num_blocks: int,
channel_reduction_position: int = -1, # if 0, apply before block 0, if -1 apply at the end
channel_reduction_type: str = 'attention',
channel_block_type: str = 'transform_attend_concatenate',
temporal_block_type: str = 'conformer_encoder',
temporal_block_num_layers: int = 5,
temporal_block_num_heads: int = 4,
temporal_block_dimension: int = 128,
temporal_block_self_attention_model: str = 'rel_pos',
temporal_block_att_context_size: Optional[List[int]] = None,
num_input_channels: Optional[int] = None,
mag_reduction: str = 'abs_mean',
mag_power: Optional[float] = None,
use_ipd: bool = True,
mag_normalization: Optional[str] = None,
ipd_normalization: Optional[str] = None,
):
super().__init__()
self.features = SpectrogramToMultichannelFeatures(
num_subbands=num_subbands,
num_input_channels=num_input_channels,
mag_reduction=mag_reduction,
mag_power=mag_power,
use_ipd=use_ipd,
mag_normalization=mag_normalization,
ipd_normalization=ipd_normalization,
)
self.num_blocks = num_blocks
logging.debug('Total number of blocks: %d', self.num_blocks)
# Channel reduction
if channel_reduction_position == -1:
# Apply reduction after the last layer
channel_reduction_position = num_blocks
if channel_reduction_position > num_blocks:
raise ValueError(
f'Channel reduction position {channel_reduction_position} exceeds the number of blocks {num_blocks}'
)
self.channel_reduction_position = channel_reduction_position
logging.debug('Channel reduction will be applied before block %d', self.channel_reduction_position)
# Prepare processing blocks
self.channel_blocks = torch.nn.ModuleList()
self.temporal_blocks = torch.nn.ModuleList()
for n in range(num_blocks):
logging.debug('Prepare block %d', n)
# Setup channel block
if n < channel_reduction_position:
# Number of input features is either the number of input channels or the number of temporal block features
channel_in_features = self.features.num_features if n == 0 else temporal_block_dimension
logging.debug(
'Setup channel block %s with %d input features and %d output features',
channel_block_type,
channel_in_features,
temporal_block_dimension,
)
                # Instantiate the channel block
if channel_block_type == 'transform_average_concatenate':
channel_block = TransformAverageConcatenate(
in_features=channel_in_features, out_features=temporal_block_dimension
)
elif channel_block_type == 'transform_attend_concatenate':
channel_block = TransformAttendConcatenate(
in_features=channel_in_features, out_features=temporal_block_dimension
)
else:
raise ValueError(f'Unknown channel layer type: {channel_block_type}')
self.channel_blocks.append(channel_block)
# Setup temporal block
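            # Raw spectral features reach the temporal block only when channel reduction is applied
            # before block 0; otherwise, the preceding channel block has already projected the
            # features to temporal_block_dimension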
temporal_in_features = (
self.features.num_features if n == self.channel_reduction_position == 0 else temporal_block_dimension
)
logging.debug('Setup temporal block %s', temporal_block_type)
if temporal_block_type == 'conformer_encoder':
temporal_block = ConformerEncoder(
feat_in=temporal_in_features,
n_layers=temporal_block_num_layers,
d_model=temporal_block_dimension,
subsampling_factor=1,
self_attention_model=temporal_block_self_attention_model,
att_context_size=temporal_block_att_context_size,
n_heads=temporal_block_num_heads,
)
else:
                raise ValueError(f'Unknown temporal block: {temporal_block_type}')
self.temporal_blocks.append(temporal_block)
logging.debug('Setup channel reduction %s', channel_reduction_type)
if channel_reduction_type == 'average':
# Mean across channel dimension
self.channel_reduction = ChannelAveragePool()
elif channel_reduction_type == 'attention':
# Number of input features is either the number of input channels or the number of temporal block features
channel_reduction_in_features = (
self.features.num_features if self.channel_reduction_position == 0 else temporal_block_dimension
)
# Attention across channel dimension
self.channel_reduction = ChannelAttentionPool(in_features=channel_reduction_in_features)
else:
raise ValueError(f'Unknown channel reduction type: {channel_reduction_type}')
logging.debug('Setup %d output layers', num_outputs)
self.output_layers = torch.nn.ModuleList(
[
ConformerEncoder(
feat_in=temporal_block_dimension,
n_layers=1,
d_model=temporal_block_dimension,
feat_out=num_subbands,
subsampling_factor=1,
self_attention_model=temporal_block_self_attention_model,
att_context_size=temporal_block_att_context_size,
n_heads=temporal_block_num_heads,
)
for _ in range(num_outputs)
]
)
# Output nonlinearity
self.output_nonlinearity = torch.nn.Sigmoid()
@property
def input_types(self) -> Dict[str, NeuralType]:
"""Returns definitions of module output ports."""
return {
"input": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
"input_length": NeuralType(('B',), LengthsType()),
}
@property
def output_types(self) -> Dict[str, NeuralType]:
"""Returns definitions of module output ports."""
return {
"output": NeuralType(('B', 'C', 'D', 'T'), FloatType()),
"output_length": NeuralType(('B',), LengthsType()),
}
@typecheck()
def forward(self, input: torch.Tensor, input_length: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Estimate `num_outputs` masks from the input spectrogram."""
# get input features from a complex-valued spectrogram, (B, C, F, T)
output, output_length = self.features(input=input, input_length=input_length)
# batch and num channels
B, M = input.size(0), input.size(1)
# process all blocks
for n in range(self.num_blocks):
if n < self.channel_reduction_position:
# apply multichannel block
output = self.channel_blocks[n](input=output)
# change to a single-stream format
F, T = output.size(-2), output.size(-1)
# (B, M, F, T) -> (B * M, F, T)
output = output.reshape(-1, F, T)
if M > 1:
# adjust the lengths accordingly
output_length = output_length.repeat_interleave(M)
elif n == self.channel_reduction_position:
# apply channel reduction
# (B, M, F, T) -> (B, F, T)
output = self.channel_reduction(input=output)
# apply temporal model on each channel independently
with typecheck.disable_checks():
# output is AcousticEncodedRepresentation, conformer encoder requires SpectrogramType
output, output_length = self.temporal_blocks[n](audio_signal=output, length=output_length)
# if channel reduction has not been applied yet, go back to multichannel layout
if n < self.channel_reduction_position:
# back to multi-channel format with possibly a different number of features
T = output.size(-1)
# (B * M, F, T) -> (B, M, F, T)
output = output.reshape(B, M, -1, T)
                if M > 1:
                    # convert lengths from single-stream format back to one entry per batch example
                    output_length = output_length[::M]
if self.channel_reduction_position == self.num_blocks:
# apply channel reduction after the last layer
# (B, M, F, T) -> (B, F, T)
output = self.channel_reduction(input=output)
# final mask for each output
masks = []
for output_layer in self.output_layers:
# calculate mask
with typecheck.disable_checks():
# output is AcousticEncodedRepresentation, conformer encoder requires SpectrogramType
mask, mask_length = output_layer(audio_signal=output, length=output_length)
mask = self.output_nonlinearity(mask)
# append to all masks
masks.append(mask)
        # stack masks along the channel dimension
masks = torch.stack(masks, dim=1)
return masks, mask_length
class MaskEstimatorGSS(NeuralModule):
"""Estimate masks using guided source separation with a complex
angular Central Gaussian Mixture Model (cACGMM) [1].
This module corresponds to `GSS` in Fig. 2 in [2].
Notation is approximately following [1], where `gamma` denotes
the time-frequency mask, `alpha` denotes the mixture weights,
and `BM` denotes the shape matrix. Additionally, the provided
source activity is denoted as `activity`.
Args:
num_iterations: Number of iterations for the EM algorithm
eps: Small value for regularization
dtype: Data type for internal computations (default `torch.cdouble`)
References:
[1] Ito et al., Complex Angular Central Gaussian Mixture Model for Directional Statistics in Mask-Based Microphone Array Signal Processing, 2016
[2] Boeddeker et al., Front-End Processing for the CHiME-5 Dinner Party Scenario, 2018
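    Example:
        A minimal sketch of the intended usage, assuming two sources active in a
        4-channel mixture (all shapes and values are illustrative only):

            import torch

            gss = MaskEstimatorGSS(num_iterations=3)
            # complex-valued input spectrogram, shape (B, num_inputs, F, T)
            spec = torch.randn(1, 4, 257, 50, dtype=torch.cdouble)
            # frame-wise activity for each output source, shape (B, num_outputs, T)
            activity = torch.ones(1, 2, 50)
            # gamma: time-frequency masks, shape (B, num_outputs, F, T)
            gamma = gss(input=spec, activity=activity)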
"""
def __init__(self, num_iterations: int = 3, eps: float = 1e-8, dtype: torch.dtype = torch.cdouble):
super().__init__()
if num_iterations <= 0:
raise ValueError(f'Number of iterations must be positive, got {num_iterations}')
# number of iterations for the EM algorithm
self.num_iterations = num_iterations
if eps <= 0:
raise ValueError(f'eps must be positive, got {eps}')
# small regularization constant
self.eps = eps
# internal calculations
if dtype not in [torch.cfloat, torch.cdouble]:
raise ValueError(f'Unsupported dtype {dtype}, expecting cfloat or cdouble')
self.dtype = dtype
logging.debug('Initialized %s', self.__class__.__name__)
logging.debug('\tnum_iterations: %s', self.num_iterations)
logging.debug('\teps: %g', self.eps)
logging.debug('\tdtype: %s', self.dtype)
def normalize(self, x: torch.Tensor, dim: int = 1) -> torch.Tensor:
"""Normalize input to have a unit L2-norm across `dim`.
By default, normalizes across the input channels.
Args:
x: C-channel input signal, shape (B, C, F, T)
            dim: Dimension along which to normalize; defaults to 1, the channel dimension for a (B, C, F, T) input
Returns:
Normalized signal, shape (B, C, F, T)
"""
norm_x = torch.linalg.vector_norm(x, ord=2, dim=dim, keepdim=True)
x = x / (norm_x + self.eps)
return x
@typecheck(
input_types={
'alpha': NeuralType(('B', 'C', 'D')),
'activity': NeuralType(('B', 'C', 'T')),
'log_pdf': NeuralType(('B', 'C', 'D', 'T')),
},
output_types={
'gamma': NeuralType(('B', 'C', 'D', 'T')),
},
)
def update_masks(self, alpha: torch.Tensor, activity: torch.Tensor, log_pdf: torch.Tensor) -> torch.Tensor:
"""Update masks for the cACGMM.
Args:
alpha: component weights, shape (B, num_outputs, F)
activity: temporal activity for the components, shape (B, num_outputs, T)
log_pdf: logarithm of the PDF, shape (B, num_outputs, F, T)
Returns:
Masks for the components of the model, shape (B, num_outputs, F, T)
"""
        # subtract the per-bin maximum across outputs in the log domain for numerical stability
        log_gamma = log_pdf - torch.max(log_pdf, dim=-3, keepdim=True)[0]
gamma = torch.exp(log_gamma)
# calculate the mask using weight, pdf and source activity
gamma = alpha[..., None] * gamma * activity[..., None, :]
# normalize across components/output channels
gamma = gamma / (torch.sum(gamma, dim=-3, keepdim=True) + self.eps)
return gamma
@typecheck(
input_types={
'gamma': NeuralType(('B', 'C', 'D', 'T')),
},
output_types={
'alpha': NeuralType(('B', 'C', 'D')),
},
)
def update_weights(self, gamma: torch.Tensor) -> torch.Tensor:
"""Update weights for the individual components
in the mixture model.
Args:
gamma: masks, shape (B, num_outputs, F, T)
Returns:
Component weights, shape (B, num_outputs, F)
"""
alpha = torch.mean(gamma, dim=-1)
return alpha
@typecheck(
input_types={
'z': NeuralType(('B', 'C', 'D', 'T')),
'gamma': NeuralType(('B', 'C', 'D', 'T')),
'zH_invBM_z': NeuralType(('B', 'C', 'D', 'T')),
},
output_types={
'log_pdf': NeuralType(('B', 'C', 'D', 'T')),
'zH_invBM_z': NeuralType(('B', 'C', 'D', 'T')),
},
)
def update_pdf(
self, z: torch.Tensor, gamma: torch.Tensor, zH_invBM_z: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Update PDF of the cACGMM.
Args:
z: directional statistics, shape (B, num_inputs, F, T)
gamma: masks, shape (B, num_outputs, F, T)
zH_invBM_z: energy weighted by shape matrices, shape (B, num_outputs, F, T)
Returns:
Logarithm of the PDF, shape (B, num_outputs, F, T), the energy term, shape (B, num_outputs, F, T)
"""
num_inputs = z.size(-3)
# shape (B, num_outputs, F, T)
scale = gamma / (zH_invBM_z + self.eps)
# scale outer product and sum over time
# shape (B, num_outputs, F, num_inputs, num_inputs)
BM = num_inputs * torch.einsum('bmft,bift,bjft->bmfij', scale.to(z.dtype), z, z.conj())
# normalize across time
denom = torch.sum(gamma, dim=-1)
BM = BM / (denom[..., None, None] + self.eps)
# make sure the matrix is Hermitian
BM = (BM + BM.conj().transpose(-1, -2)) / 2
# use eigenvalue decomposition to calculate the log determinant
# and the inverse-weighted energy term
L, Q = torch.linalg.eigh(BM)
# BM is positive definite, so all eigenvalues should be positive
# However, small negative values may occur due to a limited precision
L = torch.clamp(L.real, min=self.eps)
        # PDF is invariant to scaling of the shape matrix [1], so
        # eigenvalues can be normalized (across num_inputs)
        L = L / (torch.max(L, dim=-1, keepdim=True)[0] + self.eps)
# small regularization to avoid numerical issues
L = L + self.eps
# calculate the log determinant using the eigenvalues
log_detBM = torch.sum(torch.log(L), dim=-1)
# calculate the energy term using the inverse eigenvalues
# NOTE: keeping an alternative implementation for reference (slower)
# zH_invBM_z = torch.einsum('bift,bmfij,bmfj,bmfkj,bkft->bmft', z.conj(), Q, (1 / L).to(Q.dtype), Q.conj(), z)
# zH_invBM_z = zH_invBM_z.abs() + self.eps # small regularization
        # calc L^(-1/2) * Q^H * z; its squared norm below gives z^H * BM^(-1) * z
zH_invBM_z = torch.einsum('bmfj,bmfkj,bkft->bmftj', (1 / L.sqrt()).to(Q.dtype), Q.conj(), z)
# calc squared norm
zH_invBM_z = zH_invBM_z.abs().pow(2).sum(-1)
# small regularization
zH_invBM_z = zH_invBM_z + self.eps
# final log PDF
log_pdf = -num_inputs * torch.log(zH_invBM_z) - log_detBM[..., None]
return log_pdf, zH_invBM_z
@property
def input_types(self) -> Dict[str, NeuralType]:
"""Returns definitions of module output ports."""
return {
"input": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
"activity": NeuralType(('B', 'C', 'T')),
}
@property
def output_types(self) -> Dict[str, NeuralType]:
"""Returns definitions of module output ports."""
return {
"gamma": NeuralType(('B', 'C', 'D', 'T')),
}
@typecheck()
def forward(self, input: torch.Tensor, activity: torch.Tensor) -> torch.Tensor:
"""Apply GSS to estimate the time-frequency masks for each output source.
Args:
input: batched C-channel input signal, shape (B, num_inputs, F, T)
activity: batched frame-wise activity for each output source, shape (B, num_outputs, T)
Returns:
Masks for the components of the model, shape (B, num_outputs, F, T)
"""
B, num_inputs, F, T = input.shape
num_outputs = activity.size(1)
device = input.device.type
if activity.size(0) != B:
raise ValueError(f'Batch dimension mismatch: activity {activity.shape} vs input {input.shape}')
if activity.size(-1) != T:
raise ValueError(f'Time dimension mismatch: activity {activity.shape} vs input {input.shape}')
if num_outputs == 1:
raise ValueError(f'Expecting multiple outputs, got {num_outputs}')
with torch.amp.autocast(device, enabled=False):
input = input.to(dtype=self.dtype)
            if not input.is_complex():
                raise RuntimeError(f'Expecting complex input, got {input.dtype}')
# convert input to directional statistics by normalizing across channels
z = self.normalize(input, dim=-3)
# initialize masks
gamma = torch.clamp(activity, min=self.eps)
            # normalize across outputs
            gamma = gamma / torch.sum(gamma, dim=-2, keepdim=True)
# expand to input shape
gamma = gamma.unsqueeze(2).expand(-1, -1, F, -1)
# initialize the energy term
zH_invBM_z = torch.ones(B, num_outputs, F, T, dtype=input.dtype, device=input.device)
# EM iterations
for it in range(self.num_iterations):
alpha = self.update_weights(gamma=gamma)
log_pdf, zH_invBM_z = self.update_pdf(z=z, gamma=gamma, zH_invBM_z=zH_invBM_z)
gamma = self.update_masks(alpha=alpha, activity=activity, log_pdf=log_pdf)
if torch.any(torch.isnan(gamma)):
raise RuntimeError(f'gamma contains NaNs: {gamma}')
return gamma
class MaskReferenceChannel(NeuralModule):
"""A simple mask processor which applies mask
on ref_channel of the input signal.
Args:
ref_channel: Index of the reference channel.
mask_min_db: Threshold mask to a minimal value before applying it, defaults to -200dB
mask_max_db: Threshold mask to a maximal value before applying it, defaults to 0dB
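    Example:
        A minimal sketch of the intended usage (all shapes and values are
        illustrative only):

            import torch

            processor = MaskReferenceChannel(ref_channel=0)
            # complex-valued input spectrogram, shape (B, C, F, N)
            spec = torch.randn(1, 4, 257, 50, dtype=torch.cfloat)
            # masks, each applied to the reference channel, shape (B, M, F, N)
            mask = torch.rand(1, 2, 257, 50)
            output, output_length = processor(input=spec, input_length=torch.tensor([50]), mask=mask)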
"""
def __init__(self, ref_channel: int = 0, mask_min_db: float = -200, mask_max_db: float = 0):
super().__init__()
self.ref_channel = ref_channel
# Mask thresholding
self.mask_min = db2mag(mask_min_db)
self.mask_max = db2mag(mask_max_db)
logging.debug('Initialized %s with', self.__class__.__name__)
logging.debug('\tref_channel: %d', self.ref_channel)
logging.debug('\tmask_min: %f', self.mask_min)
logging.debug('\tmask_max: %f', self.mask_max)
@property
def input_types(self) -> Dict[str, NeuralType]:
"""Returns definitions of module output ports."""
return {
"input": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
"input_length": NeuralType(('B',), LengthsType()),
"mask": NeuralType(('B', 'C', 'D', 'T'), FloatType()),
}
@property
def output_types(self) -> Dict[str, NeuralType]:
"""Returns definitions of module output ports."""
return {
"output": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
"output_length": NeuralType(('B',), LengthsType()),
}
@typecheck()
def forward(
self,
input: torch.Tensor,
input_length: torch.Tensor,
mask: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Apply mask on `ref_channel` of the input signal.
This can be used to generate multi-channel output.
If `mask` has `M` channels, the output will have `M` channels as well.
Args:
input: Input signal complex-valued spectrogram, shape (B, C, F, N)
input_length: Length of valid entries along the time dimension, shape (B,)
mask: Mask for M outputs, shape (B, M, F, N)
Returns:
            M-channel output complex-valued spectrogram with shape (B, M, F, N)
"""
# Apply thresholds
mask = torch.clamp(mask, min=self.mask_min, max=self.mask_max)
# Apply each output mask on the ref channel
output = mask * input[:, self.ref_channel : self.ref_channel + 1, ...]
return output, input_length
class MaskBasedBeamformer(NeuralModule):
"""Multi-channel processor using masks to estimate signal statistics.
Args:
        filter_type: string denoting the type of the filter. Defaults to `mvdr_souden`
        filter_beta: Parameter of the parametric multichannel Wiener filter
filter_rank: Parameter of the parametric multichannel Wiener filter
filter_postfilter: Optional, postprocessing of the filter
ref_channel: Optional, reference channel. If None, it will be estimated automatically
ref_hard: If true, hard (one-hot) reference. If false, a soft reference
ref_hard_use_grad: If true, use straight-through gradient when using the hard reference
ref_subband_weighting: If true, use subband weighting when estimating reference channel
num_subbands: Optional, used to determine the parameter size for reference estimation
        mask_min_db: Threshold mask to a minimal value before applying it, defaults to -200dB
        mask_max_db: Threshold mask to a maximal value before applying it, defaults to 0dB
        postmask_min_db: Lower bound for the optional postmask applied at the filter output, defaults to 0dB
        postmask_max_db: Upper bound for the optional postmask applied at the filter output, defaults to 0dB (the postmask is disabled when the bounds are equal)
diag_reg: Optional, diagonal regularization for the multichannel filter
eps: Small regularization constant to avoid division by zero
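    Example:
        A minimal sketch of the intended usage with a single desired-source mask,
        in which case the complement `1 - mask` is used for the undesired signal
        (all shapes and values are illustrative only):

            import torch

            beamformer = MaskBasedBeamformer(filter_type='mvdr_souden', ref_channel=0)
            # complex-valued input spectrogram, shape (B, C, F, N)
            spec = torch.randn(1, 4, 257, 50, dtype=torch.cfloat)
            # mask for the desired signal, shape (B, num_masks, F, N)
            mask = torch.rand(1, 1, 257, 50)
            output, output_length = beamformer(input=spec, mask=mask, input_length=torch.tensor([50]))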
"""
def __init__(
self,
filter_type: str = 'mvdr_souden',
filter_beta: float = 0.0,
filter_rank: str = 'one',
filter_postfilter: Optional[str] = None,
ref_channel: Optional[int] = 0,
ref_hard: bool = True,
ref_hard_use_grad: bool = False,
ref_subband_weighting: bool = False,
num_subbands: Optional[int] = None,
mask_min_db: float = -200,
mask_max_db: float = 0,
postmask_min_db: float = 0,
postmask_max_db: float = 0,
diag_reg: Optional[float] = 1e-6,
eps: float = 1e-8,
):
super().__init__()
if filter_type not in ['pmwf', 'mvdr_souden']:
raise ValueError(f'Unknown filter type {filter_type}')
self.filter_type = filter_type
if self.filter_type == 'mvdr_souden' and filter_beta != 0:
logging.warning(
'Using filter type %s: beta will be automatically set to zero (current beta %f) and rank to one (current rank %s).',
self.filter_type,
filter_beta,
filter_rank,
)
filter_beta = 0.0
filter_rank = 'one'
# Prepare filter
self.filter = ParametricMultichannelWienerFilter(
beta=filter_beta,
rank=filter_rank,
postfilter=filter_postfilter,
ref_channel=ref_channel,
ref_hard=ref_hard,
ref_hard_use_grad=ref_hard_use_grad,
ref_subband_weighting=ref_subband_weighting,
num_subbands=num_subbands,
diag_reg=diag_reg,
eps=eps,
)
# Mask thresholding
if mask_min_db >= mask_max_db:
raise ValueError(
f'Lower bound for the mask {mask_min_db}dB must be smaller than the upper bound {mask_max_db}dB'
)
self.mask_min = db2mag(mask_min_db)
self.mask_max = db2mag(mask_max_db)
# Postmask thresholding
if postmask_min_db > postmask_max_db:
raise ValueError(
                f'Lower bound for the postmask {postmask_min_db}dB must be smaller than or equal to the upper bound {postmask_max_db}dB'
)
self.postmask_min = db2mag(postmask_min_db)
self.postmask_max = db2mag(postmask_max_db)
logging.debug('Initialized %s', self.__class__.__name__)
logging.debug('\tfilter_type: %s', self.filter_type)
logging.debug('\tmask_min: %e', self.mask_min)
logging.debug('\tmask_max: %e', self.mask_max)
logging.debug('\tpostmask_min: %e', self.postmask_min)
logging.debug('\tpostmask_max: %e', self.postmask_max)
@property
def input_types(self) -> Dict[str, NeuralType]:
"""Returns definitions of module output ports."""
return {
"input": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
"mask": NeuralType(('B', 'C', 'D', 'T'), FloatType()),
"mask_undesired": NeuralType(('B', 'C', 'D', 'T'), FloatType(), optional=True),
"input_length": NeuralType(('B',), LengthsType(), optional=True),
}
@property
def output_types(self) -> Dict[str, NeuralType]:
"""Returns definitions of module output ports."""
return {
"output": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
"output_length": NeuralType(('B',), LengthsType(), optional=True),
}
@typecheck()
def forward(
self,
input: torch.Tensor,
mask: torch.Tensor,
mask_undesired: Optional[torch.Tensor] = None,
input_length: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
"""Apply a mask-based beamformer to the input spectrogram.
This can be used to generate multi-channel output.
If `mask` has multiple channels, a multichannel filter is created for each mask,
        and the output is the concatenation of the individual outputs along the channel dimension.
The total number of outputs is `num_masks * M`, where `M` is the number of channels
at the filter output.
Args:
input: Input signal complex-valued spectrogram, shape (B, C, F, N)
mask: Mask for M output signals, shape (B, num_masks, F, N)
input_length: Length of valid entries along the time dimension, shape (B,)
Returns:
Multichannel output signal complex-valued spectrogram, shape (B, num_masks * M, F, N)
"""
# Length mask
if input_length is not None:
length_mask: torch.Tensor = make_seq_mask_like(
lengths=input_length, like=mask[:, 0, ...], time_dim=-1, valid_ones=False
)
# Use each mask to generate an output
output, num_masks = [], mask.size(1)
for m in range(num_masks):
# Desired signal mask
mask_d = mask[:, m, ...]
# Undesired signal mask
if mask_undesired is not None:
mask_u = mask_undesired[:, m, ...]
elif num_masks == 1:
# If a single mask is estimated, use the complement
mask_u = 1 - mask_d
else:
# Use sum of all other sources
mask_u = torch.sum(mask, dim=1) - mask_d
# Threshold masks
mask_d = torch.clamp(mask_d, min=self.mask_min, max=self.mask_max)
mask_u = torch.clamp(mask_u, min=self.mask_min, max=self.mask_max)
if input_length is not None:
mask_d = mask_d.masked_fill(length_mask, 0.0)
mask_u = mask_u.masked_fill(length_mask, 0.0)
# Apply filter
output_m = self.filter(input=input, mask_s=mask_d, mask_n=mask_u)
# Optional: apply a postmask with min and max thresholds
if self.postmask_min < self.postmask_max:
postmask_m = torch.clamp(mask[:, m, ...], min=self.postmask_min, max=self.postmask_max)
output_m = output_m * postmask_m.unsqueeze(1)
# Save the current output (B, M, F, T)
output.append(output_m)
# Combine outputs along the channel dimension
# Each output is (B, M, F, T)
        output = torch.cat(output, dim=1)
# Apply masking
if input_length is not None:
output = output.masked_fill(length_mask[:, None, ...], 0.0)
return output, input_length
class MaskBasedDereverbWPE(NeuralModule):
"""Multi-channel linear prediction-based dereverberation using
weighted prediction error for filter estimation.
An optional mask to estimate the signal power can be provided.
If a time-frequency mask is not provided, the algorithm corresponds
to the conventional WPE algorithm.
Args:
filter_length: Length of the convolutional filter for each channel in frames.
prediction_delay: Delay of the input signal for multi-channel linear prediction in frames.
num_iterations: Number of iterations for reweighting
mask_min_db: Threshold mask to a minimal value before applying it, defaults to -200dB
        mask_max_db: Threshold mask to a maximal value before applying it, defaults to 0dB
diag_reg: Diagonal regularization for WPE
eps: Small regularization constant
dtype: Data type for internal computations
References:
- Kinoshita et al, Neural network-based spectrum estimation for online WPE dereverberation, 2017
- Yoshioka and Nakatani, Generalization of Multi-Channel Linear Prediction Methods for Blind MIMO Impulse Response Shortening, 2012
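    Example:
        A minimal sketch of the intended usage, assuming a 10-frame filter with a
        2-frame prediction delay and no mask, i.e., conventional WPE (all shapes
        and values are illustrative only):

            import torch

            dereverb = MaskBasedDereverbWPE(filter_length=10, prediction_delay=2)
            # complex-valued input spectrogram, shape (B, C, F, T)
            spec = torch.randn(1, 4, 257, 50, dtype=torch.cfloat)
            # output: dereverberated spectrogram with the same shape as the input
            output, output_length = dereverb(input=spec, input_length=torch.tensor([50]))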
"""
def __init__(
self,
filter_length: int,
prediction_delay: int,
num_iterations: int = 1,
mask_min_db: float = -200,
mask_max_db: float = 0,
diag_reg: Optional[float] = 1e-6,
eps: float = 1e-8,
dtype: torch.dtype = torch.cdouble,
):
super().__init__()
# Filter setup
self.filter = WPEFilter(
filter_length=filter_length, prediction_delay=prediction_delay, diag_reg=diag_reg, eps=eps
)
self.num_iterations = num_iterations
# Mask thresholding
self.mask_min = db2mag(mask_min_db)
self.mask_max = db2mag(mask_max_db)
# Internal calculations
if dtype not in [torch.cfloat, torch.cdouble]:
raise ValueError(f'Unsupported dtype {dtype}, expecting torch.cfloat or torch.cdouble')
self.dtype = dtype
logging.debug('Initialized %s', self.__class__.__name__)
logging.debug('\tnum_iterations: %s', self.num_iterations)
logging.debug('\tmask_min: %g', self.mask_min)
logging.debug('\tmask_max: %g', self.mask_max)
logging.debug('\tdtype: %s', self.dtype)
@property
def input_types(self) -> Dict[str, NeuralType]:
"""Returns definitions of module output ports."""
return {
"input": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
"input_length": NeuralType(('B',), LengthsType(), optional=True),
"mask": NeuralType(('B', 'C', 'D', 'T'), FloatType(), optional=True),
}
@property
def output_types(self) -> Dict[str, NeuralType]:
"""Returns definitions of module output ports."""
return {
"output": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
"output_length": NeuralType(('B',), LengthsType(), optional=True),
}
@typecheck()
def forward(
self, input: torch.Tensor, input_length: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
"""Given an input signal `input`, apply the WPE dereverberation algoritm.
Args:
input: C-channel complex-valued spectrogram, shape (B, C, F, T)
input_length: Optional length for each signal in the batch, shape (B,)
            mask: Optional mask, shape (B, 1, F, T) or (B, C, F, T)
Returns:
Processed tensor with the same number of channels as the input,
shape (B, C, F, T).
"""
io_dtype = input.dtype
device = input.device.type
with torch.amp.autocast(device, enabled=False):
output = input.to(dtype=self.dtype)
if not output.is_complex():
raise RuntimeError(f'Expecting complex input, got {output.dtype}')
for i in range(self.num_iterations):
magnitude = torch.abs(output)
if i == 0 and mask is not None:
# Apply thresholds
mask = torch.clamp(mask, min=self.mask_min, max=self.mask_max)
# Mask magnitude
magnitude = mask * magnitude
# Calculate power
power = magnitude**2
# Apply filter
output, output_length = self.filter(input=output, input_length=input_length, power=power)
return output.to(io_dtype), output_length