File size: 44,486 Bytes
0558aa4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Optional, Tuple

import torch

from nemo.collections.asr.modules.conformer_encoder import ConformerEncoder
from nemo.collections.asr.parts.preprocessing.features import make_seq_mask_like
from nemo.collections.audio.modules.features import SpectrogramToMultichannelFeatures
from nemo.collections.audio.parts.submodules.multichannel import (
    ChannelAttentionPool,
    ChannelAveragePool,
    ParametricMultichannelWienerFilter,
    TransformAttendConcatenate,
    TransformAverageConcatenate,
    WPEFilter,
)
from nemo.collections.audio.parts.utils.audio import db2mag
from nemo.core.classes import NeuralModule, typecheck
from nemo.core.neural_types import FloatType, LengthsType, NeuralType, SpectrogramType
from nemo.utils import logging


class MaskEstimatorRNN(NeuralModule):
    """Estimate `num_outputs` masks from the input spectrogram
    using stacked RNNs and projections.

    The module is structured as follows:
        input --> spatial features --> input projection -->
            --> stacked RNNs --> output projection for each output --> sigmoid

    Reference:
        Multi-microphone neural speech separation for far-field multi-talker
        speech recognition (https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=8462081)

    Args:
        num_outputs: Number of output masks to estimate
        num_subbands: Number of subbands of the input spectrogram
        num_features: Number of features after the input projections
        num_layers: Number of RNN layers
        num_hidden_features: Number of hidden features in RNN layers. If `None`, `num_features` is used.
        num_input_channels: Number of input channels
        dropout: If non-zero, introduces dropout on the outputs of each RNN layer except the last layer, with dropout
                 probability equal to `dropout`. Default: 0
        bidirectional: If `True`, use bidirectional RNN.
        rnn_type: Type of RNN, either `lstm` or `gru`. Default: `lstm`
        mag_reduction: Channel-wise reduction for magnitude features
        use_ipd: Use inter-channel phase difference (IPD) features

    Raises:
        ValueError: If `rnn_type` is neither `lstm` nor `gru`.
    """

    def __init__(
        self,
        num_outputs: int,
        num_subbands: int,
        num_features: int = 1024,
        num_layers: int = 3,
        num_hidden_features: Optional[int] = None,
        num_input_channels: Optional[int] = None,
        dropout: float = 0,
        bidirectional: bool = True,
        rnn_type: str = 'lstm',
        mag_reduction: str = 'rms',
        use_ipd: Optional[bool] = None,
    ):
        super().__init__()
        if num_hidden_features is None:
            num_hidden_features = num_features

        self.features = SpectrogramToMultichannelFeatures(
            num_subbands=num_subbands,
            num_input_channels=num_input_channels,
            mag_reduction=mag_reduction,
            use_ipd=use_ipd,
        )

        # Per frame, the flattened (feature channels x features) vector is projected to `num_features`
        self.input_projection = torch.nn.Linear(
            in_features=self.features.num_features * self.features.num_channels, out_features=num_features
        )

        if rnn_type == 'lstm':
            rnn_class = torch.nn.LSTM
        elif rnn_type == 'gru':
            rnn_class = torch.nn.GRU
        else:
            raise ValueError(f'Unknown rnn_type: {rnn_type}')

        self.rnn = rnn_class(
            input_size=num_features,
            hidden_size=num_hidden_features,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout,
            bidirectional=bidirectional,
        )

        # The RNN emits `num_hidden_features` per direction. The projection back to `num_features`
        # (required by the skip connection in `forward`) must therefore consume the RNN output size,
        # not `num_features`. Both are identical when `num_hidden_features` is left at its default.
        rnn_output_features = 2 * num_hidden_features if bidirectional else num_hidden_features
        self.fc = torch.nn.Linear(in_features=rnn_output_features, out_features=num_features)
        self.norm = torch.nn.LayerNorm(num_features)

        # Each output shares the RNN and has a separate projection
        self.output_projections = torch.nn.ModuleList(
            [torch.nn.Linear(in_features=num_features, out_features=num_subbands) for _ in range(num_outputs)]
        )
        self.output_nonlinearity = torch.nn.Sigmoid()

    @property
    def input_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module input ports."""
        return {
            "input": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
            "input_length": NeuralType(('B',), LengthsType()),
        }

    @property
    def output_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module output ports."""
        return {
            "output": NeuralType(('B', 'C', 'D', 'T'), FloatType()),
            "output_length": NeuralType(('B',), LengthsType()),
        }

    @typecheck()
    def forward(self, input: torch.Tensor, input_length: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Estimate `num_outputs` masks from the input spectrogram.

        Args:
            input: C-channel input, shape (B, C, F, N)
            input_length: Length of valid entries along the time dimension, shape (B,)

        Returns:
            Returns `num_outputs` masks in a tensor, shape (B, num_outputs, F, N),
            and output length with shape (B,). Mask values for frames beyond
            the valid length are set to zero.
        """
        input, _ = self.features(input=input, input_length=input_length)
        B, _, _, N = input.shape

        # (B, num_feat_channels, num_feat, N) -> (B, N, num_feat_channels, num_feat)
        input = input.permute(0, 3, 1, 2)

        # (B, N, num_feat_channels, num_feat) -> (B, N, num_feat_channels * num_feat)
        # `reshape` also works when the permuted tensor cannot be viewed without a copy
        input = input.reshape(B, N, -1)

        # Apply projection on the flattened features
        input = self.input_projection(input)

        # Apply RNN on the packed input sequence, so padded frames are skipped
        input_packed = torch.nn.utils.rnn.pack_padded_sequence(
            input, input_length.cpu(), batch_first=True, enforce_sorted=False
        ).to(input.device)
        self.rnn.flatten_parameters()
        input_packed, _ = self.rnn(input_packed)
        # Pad back to the full input length N. Without `total_length`, the output would only have
        # max(input_length) frames, and the skip connection below would fail if that is shorter than N.
        output, output_length = torch.nn.utils.rnn.pad_packed_sequence(
            input_packed, batch_first=True, total_length=N
        )
        output_length = output_length.to(input.device)

        # Layer normalization and skip connection
        output = self.norm(self.fc(output)) + input

        # Create `num_outputs` masks
        masks = []
        for output_projection in self.output_projections:
            # Output projection followed by the nonlinearity
            mask = self.output_nonlinearity(output_projection(output))

            # Back to the original format
            # (B, N, F) -> (B, F, N)
            masks.append(mask.transpose(2, 1))

        # Stack along channel dimension to get (B, M, F, N)
        masks = torch.stack(masks, dim=1)

        # Mask frames beyond output length
        length_mask: torch.Tensor = make_seq_mask_like(
            lengths=output_length, like=masks, time_dim=-1, valid_ones=False
        )
        masks = masks.masked_fill(length_mask, 0.0)

        return masks, output_length


class MaskEstimatorFlexChannels(NeuralModule):
    """Estimate `num_outputs` masks from the input spectrogram
    using stacked channel-wise and temporal layers.

    This model uses interleaved channel blocks and temporal blocks, and
    it can process an arbitrary number of input channels.
    Default channel block is the transform-attend-concatenate layer.
    Default temporal block is the Conformer encoder.
    Reduction from multichannel signal to single-channel signal is performed
    after `channel_reduction_position` blocks. Only temporal blocks are used afterwards.
    After the sequence of blocks, the output mask is computed using an additional
    output temporal layer and a nonlinearity.

    References:
        - Yoshioka et al, VarArray: Array-Geometry-Agnostic Continuous Speech Separation, 2022
        - Jukić et al, Flexible multichannel speech enhancement for noise-robust frontend, 2023

    Args:
        num_outputs: Number of output masks.
        num_subbands: Number of subbands on the input spectrogram.
        num_blocks: Number of blocks in the model.
        channel_reduction_position: Channel reduction is applied before this block.
            Use 0 to reduce before the first block, and -1 to reduce after the last block.
        channel_reduction_type: Reduction across channels: 'average' or 'attention'
        channel_block_type: Block for channel processing: 'transform_average_concatenate' or 'transform_attend_concatenate'
        temporal_block_type: Block for temporal processing: 'conformer_encoder'
        temporal_block_num_layers: Number of layers for the temporal block
        temporal_block_num_heads: Number of heads for the temporal block
        temporal_block_dimension: The hidden size of the model
        temporal_block_self_attention_model: Self attention model for the temporal block
        temporal_block_att_context_size: Attention context size for the temporal block
        num_input_channels: Number of input channels, passed to the feature extractor
        mag_reduction: Channel-wise reduction for magnitude features
        mag_power: Power to apply on magnitude features
        use_ipd: Use inter-channel phase difference (IPD) features
        mag_normalization: Normalize using mean ('mean') or mean and variance ('mean_var')
        ipd_normalization: Normalize using mean ('mean') or mean and variance ('mean_var')

    Raises:
        ValueError: For an invalid `channel_reduction_position` or an unknown block/reduction type.
    """

    def __init__(
        self,
        num_outputs: int,
        num_subbands: int,
        num_blocks: int,
        channel_reduction_position: int = -1,  # if 0, apply before block 0, if -1 apply at the end
        channel_reduction_type: str = 'attention',
        channel_block_type: str = 'transform_attend_concatenate',
        temporal_block_type: str = 'conformer_encoder',
        temporal_block_num_layers: int = 5,
        temporal_block_num_heads: int = 4,
        temporal_block_dimension: int = 128,
        temporal_block_self_attention_model: str = 'rel_pos',
        temporal_block_att_context_size: Optional[List[int]] = None,
        num_input_channels: Optional[int] = None,
        mag_reduction: str = 'abs_mean',
        mag_power: Optional[float] = None,
        use_ipd: bool = True,
        mag_normalization: Optional[str] = None,
        ipd_normalization: Optional[str] = None,
    ):
        super().__init__()

        self.features = SpectrogramToMultichannelFeatures(
            num_subbands=num_subbands,
            num_input_channels=num_input_channels,
            mag_reduction=mag_reduction,
            mag_power=mag_power,
            use_ipd=use_ipd,
            mag_normalization=mag_normalization,
            ipd_normalization=ipd_normalization,
        )
        self.num_blocks = num_blocks
        logging.debug('Total number of blocks: %d', self.num_blocks)

        # Channel reduction
        if channel_reduction_position == -1:
            # Apply reduction after the last layer
            channel_reduction_position = num_blocks

        if channel_reduction_position > num_blocks:
            raise ValueError(
                f'Channel reduction position {channel_reduction_position} exceeds the number of blocks {num_blocks}'
            )
        if channel_reduction_position < 0:
            # Any other negative value would make `forward` silently skip the channel reduction
            raise ValueError(
                f'Channel reduction position must be -1 or in [0, {num_blocks}], got {channel_reduction_position}'
            )
        self.channel_reduction_position = channel_reduction_position
        logging.debug('Channel reduction will be applied before block %d', self.channel_reduction_position)

        # Prepare processing blocks
        self.channel_blocks = torch.nn.ModuleList()
        self.temporal_blocks = torch.nn.ModuleList()

        for n in range(num_blocks):
            logging.debug('Prepare block %d', n)

            # Setup channel block (only for blocks before the channel reduction)
            if n < channel_reduction_position:
                # Number of input features is either the number of input features or the number of temporal block features
                channel_in_features = self.features.num_features if n == 0 else temporal_block_dimension
                logging.debug(
                    'Setup channel block %s with %d input features and %d output features',
                    channel_block_type,
                    channel_in_features,
                    temporal_block_dimension,
                )

                # Instantiate the channel block
                if channel_block_type == 'transform_average_concatenate':
                    channel_block = TransformAverageConcatenate(
                        in_features=channel_in_features, out_features=temporal_block_dimension
                    )
                elif channel_block_type == 'transform_attend_concatenate':
                    channel_block = TransformAttendConcatenate(
                        in_features=channel_in_features, out_features=temporal_block_dimension
                    )
                else:
                    raise ValueError(f'Unknown channel layer type: {channel_block_type}')
                self.channel_blocks.append(channel_block)

            # Setup temporal block.
            # Only when reducing before block 0 does the first temporal block see the raw features.
            temporal_in_features = (
                self.features.num_features if n == self.channel_reduction_position == 0 else temporal_block_dimension
            )
            logging.debug('Setup temporal block %s', temporal_block_type)
            if temporal_block_type == 'conformer_encoder':
                temporal_block = ConformerEncoder(
                    feat_in=temporal_in_features,
                    n_layers=temporal_block_num_layers,
                    d_model=temporal_block_dimension,
                    subsampling_factor=1,
                    self_attention_model=temporal_block_self_attention_model,
                    att_context_size=temporal_block_att_context_size,
                    n_heads=temporal_block_num_heads,
                )
            else:
                raise ValueError(f'Unknown temporal block {temporal_block_type}.')

            self.temporal_blocks.append(temporal_block)

        logging.debug('Setup channel reduction %s', channel_reduction_type)
        if channel_reduction_type == 'average':
            # Mean across channel dimension
            self.channel_reduction = ChannelAveragePool()
        elif channel_reduction_type == 'attention':
            # Number of input features is either the number of input features or the number of temporal block features
            channel_reduction_in_features = (
                self.features.num_features if self.channel_reduction_position == 0 else temporal_block_dimension
            )
            # Attention across channel dimension
            self.channel_reduction = ChannelAttentionPool(in_features=channel_reduction_in_features)
        else:
            raise ValueError(f'Unknown channel reduction type: {channel_reduction_type}')

        logging.debug('Setup %d output layers', num_outputs)
        # One single-layer Conformer per output, projecting to `num_subbands` features
        self.output_layers = torch.nn.ModuleList(
            [
                ConformerEncoder(
                    feat_in=temporal_block_dimension,
                    n_layers=1,
                    d_model=temporal_block_dimension,
                    feat_out=num_subbands,
                    subsampling_factor=1,
                    self_attention_model=temporal_block_self_attention_model,
                    att_context_size=temporal_block_att_context_size,
                    n_heads=temporal_block_num_heads,
                )
                for _ in range(num_outputs)
            ]
        )

        # Output nonlinearity
        self.output_nonlinearity = torch.nn.Sigmoid()

    @property
    def input_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module input ports."""
        return {
            "input": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
            "input_length": NeuralType(('B',), LengthsType()),
        }

    @property
    def output_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module output ports."""
        return {
            "output": NeuralType(('B', 'C', 'D', 'T'), FloatType()),
            "output_length": NeuralType(('B',), LengthsType()),
        }

    @typecheck()
    def forward(self, input: torch.Tensor, input_length: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Estimate `num_outputs` masks from the input spectrogram.

        Args:
            input: multichannel input spectrogram, shape (B, C, F, T)
            input_length: Length of valid entries along the time dimension, shape (B,)

        Returns:
            Masks stacked along the channel dimension, shape (B, num_outputs, num_subbands, T),
            and the corresponding lengths, shape (B,).
        """
        # get input features from a complex-valued spectrogram, (B, C, F, T)
        output, output_length = self.features(input=input, input_length=input_length)

        # batch and num channels
        B, M = input.size(0), input.size(1)

        # process all blocks
        for n in range(self.num_blocks):
            if n < self.channel_reduction_position:
                # apply multichannel block
                output = self.channel_blocks[n](input=output)
                # change to a single-stream format
                F, T = output.size(-2), output.size(-1)
                # (B, M, F, T) -> (B * M, F, T)
                output = output.reshape(-1, F, T)
                if M > 1:
                    # adjust the lengths accordingly
                    output_length = output_length.repeat_interleave(M)

            elif n == self.channel_reduction_position:
                # apply channel reduction
                # (B, M, F, T) -> (B, F, T)
                output = self.channel_reduction(input=output)

            # apply temporal model on each channel independently
            with typecheck.disable_checks():
                # output is AcousticEncodedRepresentation, conformer encoder requires SpectrogramType
                output, output_length = self.temporal_blocks[n](audio_signal=output, length=output_length)

            # if channel reduction has not been applied yet, go back to multichannel layout
            if n < self.channel_reduction_position:
                # back to multi-channel format with possibly a different number of features
                T = output.size(-1)
                # (B * M, F, T) -> (B, M, F, T)
                output = output.reshape(B, M, -1, T)
                if M > 1:
                    # undo `repeat_interleave(M)`: keep the first of every M repeated lengths
                    output_length = output_length[::M]

        if self.channel_reduction_position == self.num_blocks:
            # apply channel reduction after the last layer
            # (B, M, F, T) -> (B, F, T)
            output = self.channel_reduction(input=output)

        # final mask for each output
        masks = []
        for output_layer in self.output_layers:
            # calculate mask
            with typecheck.disable_checks():
                # output is AcousticEncodedRepresentation, conformer encoder requires SpectrogramType
                mask, mask_length = output_layer(audio_signal=output, length=output_length)
            # append to all masks
            masks.append(self.output_nonlinearity(mask))

        # stack masks along channel dimensions
        masks = torch.stack(masks, dim=1)

        return masks, mask_length


class MaskEstimatorGSS(NeuralModule):
    """Estimate masks using guided source separation with a complex
    angular Central Gaussian Mixture Model (cACGMM) [1].

    This module corresponds to `GSS` in Fig. 2 in [2].

    Notation is approximately following [1], where `gamma` denotes
    the time-frequency mask, `alpha` denotes the mixture weights,
    and `BM` denotes the shape matrix. Additionally, the provided
    source activity is denoted as `activity`.

    Args:
        num_iterations: Number of iterations for the EM algorithm
        eps: Small value for regularization
        dtype: Data type for internal computations (default `torch.cdouble`)

    Raises:
        ValueError: If `num_iterations` or `eps` is not positive, or `dtype` is not cfloat/cdouble.

    References:
        [1] Ito et al., Complex Angular Central Gaussian Mixture Model for Directional Statistics in Mask-Based Microphone Array Signal Processing, 2016
        [2] Boeddeker et al., Front-End Processing for the CHiME-5 Dinner Party Scenario, 2018
    """

    def __init__(self, num_iterations: int = 3, eps: float = 1e-8, dtype: torch.dtype = torch.cdouble):
        super().__init__()

        if num_iterations <= 0:
            raise ValueError(f'Number of iterations must be positive, got {num_iterations}')

        # number of iterations for the EM algorithm
        self.num_iterations = num_iterations

        if eps <= 0:
            raise ValueError(f'eps must be positive, got {eps}')

        # small regularization constant
        self.eps = eps

        # internal calculations
        if dtype not in [torch.cfloat, torch.cdouble]:
            raise ValueError(f'Unsupported dtype {dtype}, expecting cfloat or cdouble')
        self.dtype = dtype

        logging.debug('Initialized %s', self.__class__.__name__)
        logging.debug('\tnum_iterations: %s', self.num_iterations)
        logging.debug('\teps:            %g', self.eps)
        logging.debug('\tdtype:          %s', self.dtype)

    def normalize(self, x: torch.Tensor, dim: int = 1) -> torch.Tensor:
        """Normalize input to have a unit L2-norm across `dim`.
        By default, normalizes across the input channels.

        Args:
            x: C-channel input signal, shape (B, C, F, T)
            dim: Dimension for normalization, defaults to 1 (the channel dimension for a 4D input)

        Returns:
            Normalized signal, shape (B, C, F, T)
        """
        norm_x = torch.linalg.vector_norm(x, ord=2, dim=dim, keepdim=True)
        # `eps` avoids division by zero for all-zero vectors
        x = x / (norm_x + self.eps)
        return x

    @typecheck(
        input_types={
            'alpha': NeuralType(('B', 'C', 'D')),
            'activity': NeuralType(('B', 'C', 'T')),
            'log_pdf': NeuralType(('B', 'C', 'D', 'T')),
        },
        output_types={
            'gamma': NeuralType(('B', 'C', 'D', 'T')),
        },
    )
    def update_masks(self, alpha: torch.Tensor, activity: torch.Tensor, log_pdf: torch.Tensor) -> torch.Tensor:
        """Update masks for the cACGMM.

        Args:
            alpha: component weights, shape (B, num_outputs, F)
            activity: temporal activity for the components, shape (B, num_outputs, T)
            log_pdf: logarithm of the PDF, shape (B, num_outputs, F, T)

        Returns:
            Masks for the components of the model, shape (B, num_outputs, F, T)
        """
        # subtract the max across outputs in the log domain before exponentiating,
        # for numerical stability; the common factor cancels in the normalization below
        # shape (B, num_outputs, F, T)
        log_gamma = log_pdf - torch.max(log_pdf, axis=-3, keepdim=True)[0]

        gamma = torch.exp(log_gamma)

        # calculate the mask using weight, pdf and source activity
        gamma = alpha[..., None] * gamma * activity[..., None, :]

        # normalize across components/output channels
        gamma = gamma / (torch.sum(gamma, dim=-3, keepdim=True) + self.eps)

        return gamma

    @typecheck(
        input_types={
            'gamma': NeuralType(('B', 'C', 'D', 'T')),
        },
        output_types={
            'alpha': NeuralType(('B', 'C', 'D')),
        },
    )
    def update_weights(self, gamma: torch.Tensor) -> torch.Tensor:
        """Update weights for the individual components
        in the mixture model.

        Args:
            gamma: masks, shape (B, num_outputs, F, T)

        Returns:
            Component weights, shape (B, num_outputs, F)
        """
        # average of the masks over time
        alpha = torch.mean(gamma, dim=-1)
        return alpha

    @typecheck(
        input_types={
            'z': NeuralType(('B', 'C', 'D', 'T')),
            'gamma': NeuralType(('B', 'C', 'D', 'T')),
            'zH_invBM_z': NeuralType(('B', 'C', 'D', 'T')),
        },
        output_types={
            'log_pdf': NeuralType(('B', 'C', 'D', 'T')),
            'zH_invBM_z': NeuralType(('B', 'C', 'D', 'T')),
        },
    )
    def update_pdf(
        self, z: torch.Tensor, gamma: torch.Tensor, zH_invBM_z: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Update PDF of the cACGMM.

        Args:
            z: directional statistics, shape (B, num_inputs, F, T)
            gamma: masks, shape (B, num_outputs, F, T)
            zH_invBM_z: energy weighted by shape matrices, shape (B, num_outputs, F, T)

        Returns:
            Logarithm of the PDF, shape (B, num_outputs, F, T), the energy term, shape (B, num_outputs, F, T)
        """
        num_inputs = z.size(-3)

        # shape (B, num_outputs, F, T)
        scale = gamma / (zH_invBM_z + self.eps)

        # scale outer product and sum over time
        # shape (B, num_outputs, F, num_inputs, num_inputs)
        BM = num_inputs * torch.einsum('bmft,bift,bjft->bmfij', scale.to(z.dtype), z, z.conj())

        # normalize across time
        denom = torch.sum(gamma, dim=-1)
        BM = BM / (denom[..., None, None] + self.eps)

        # make sure the matrix is Hermitian
        BM = (BM + BM.conj().transpose(-1, -2)) / 2

        # use eigenvalue decomposition to calculate the log determinant
        # and the inverse-weighted energy term
        L, Q = torch.linalg.eigh(BM)

        # BM is positive definite, so all eigenvalues should be positive
        # However, small negative values may occur due to a limited precision
        L = torch.clamp(L.real, min=self.eps)

        # PDF is invariant to scaling of the shape matrix [1], so
        # eigenvalues can be normalized (across num_inputs)
        L = L / (torch.max(L, axis=-1, keepdim=True)[0] + self.eps)

        # small regularization to avoid numerical issues
        L = L + self.eps

        # calculate the log determinant using the eigenvalues
        log_detBM = torch.sum(torch.log(L), dim=-1)

        # calculate the energy term using the inverse eigenvalues
        # NOTE: keeping an alternative implementation for reference (slower)
        # zH_invBM_z = torch.einsum('bift,bmfij,bmfj,bmfkj,bkft->bmft', z.conj(), Q, (1 / L).to(Q.dtype), Q.conj(), z)
        # zH_invBM_z = zH_invBM_z.abs() + self.eps # small regularization

        # calc L^(-1/2) * Q^H * z
        zH_invBM_z = torch.einsum('bmfj,bmfkj,bkft->bmftj', (1 / L.sqrt()).to(Q.dtype), Q.conj(), z)
        # calc squared norm
        zH_invBM_z = zH_invBM_z.abs().pow(2).sum(-1)
        # small regularization
        zH_invBM_z = zH_invBM_z + self.eps

        # final log PDF
        log_pdf = -num_inputs * torch.log(zH_invBM_z) - log_detBM[..., None]

        return log_pdf, zH_invBM_z

    @property
    def input_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module input ports."""
        return {
            "input": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
            "activity": NeuralType(('B', 'C', 'T')),
        }

    @property
    def output_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module output ports."""
        return {
            "gamma": NeuralType(('B', 'C', 'D', 'T')),
        }

    @typecheck()
    def forward(self, input: torch.Tensor, activity: torch.Tensor) -> torch.Tensor:
        """Apply GSS to estimate the time-frequency masks for each output source.

        Args:
            input: batched C-channel input signal, shape (B, num_inputs, F, T)
            activity: batched frame-wise activity for each output source, shape (B, num_outputs, T)

        Returns:
            Masks for the components of the model, shape (B, num_outputs, F, T)

        Raises:
            ValueError: On batch/time dimension mismatch, or if only a single output is requested.
            RuntimeError: If the estimated masks contain NaNs.
        """
        B, num_inputs, F, T = input.shape
        num_outputs = activity.size(1)
        device = input.device.type

        if activity.size(0) != B:
            raise ValueError(f'Batch dimension mismatch: activity {activity.shape} vs input {input.shape}')

        if activity.size(-1) != T:
            raise ValueError(f'Time dimension mismatch: activity {activity.shape} vs input {input.shape}')

        if num_outputs == 1:
            raise ValueError(f'Expecting multiple outputs, got {num_outputs}')

        # run the EM algorithm in the configured (high) precision, outside of autocast
        with torch.amp.autocast(device, enabled=False):
            input = input.to(dtype=self.dtype)

            # NOTE: `self.dtype` is complex, so this holds after the conversion above
            assert input.is_complex(), f'Expecting complex input, got {input.dtype}'

            # convert input to directional statistics by normalizing across channels
            z = self.normalize(input, dim=-3)

            # initialize masks
            gamma = torch.clamp(activity, min=self.eps)
            # normalize across channels
            gamma = gamma / torch.sum(gamma, dim=-2, keepdim=True)
            # expand to input shape
            gamma = gamma.unsqueeze(2).expand(-1, -1, F, -1)

            # initialize the energy term; it is real-valued, so use the real counterpart of the
            # complex working dtype (avoids needless complex arithmetic in the first iteration)
            zH_invBM_z = torch.ones(B, num_outputs, F, T, dtype=input.real.dtype, device=input.device)

            # EM iterations
            for _ in range(self.num_iterations):
                alpha = self.update_weights(gamma=gamma)
                log_pdf, zH_invBM_z = self.update_pdf(z=z, gamma=gamma, zH_invBM_z=zH_invBM_z)
                gamma = self.update_masks(alpha=alpha, activity=activity, log_pdf=log_pdf)

        if torch.any(torch.isnan(gamma)):
            raise RuntimeError(f'gamma contains NaNs: {gamma}')

        return gamma


class MaskReferenceChannel(NeuralModule):
    """Mask processor that applies masks to a single reference channel of the input.

    Every mask channel is multiplied with the selected reference channel, so an
    M-channel mask yields an M-channel output.

    Args:
        ref_channel: Index of the input channel the masks are applied to.
        mask_min_db: Lower threshold for the mask before applying it, defaults to -200dB
        mask_max_db: Upper threshold for the mask before applying it, defaults to 0dB
    """

    def __init__(self, ref_channel: int = 0, mask_min_db: float = -200, mask_max_db: float = 0):
        super().__init__()
        self.ref_channel = ref_channel
        # Thresholds are converted from dB and stored for use in forward
        self.mask_min, self.mask_max = db2mag(mask_min_db), db2mag(mask_max_db)

        debug_lines = (
            ('Initialized %s with', self.__class__.__name__),
            ('\tref_channel: %d', self.ref_channel),
            ('\tmask_min:    %f', self.mask_min),
            ('\tmask_max:    %f', self.mask_max),
        )
        for fmt, value in debug_lines:
            logging.debug(fmt, value)

    @property
    def input_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module input ports."""
        return dict(
            input=NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
            input_length=NeuralType(('B',), LengthsType()),
            mask=NeuralType(('B', 'C', 'D', 'T'), FloatType()),
        )

    @property
    def output_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module output ports."""
        return dict(
            output=NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
            output_length=NeuralType(('B',), LengthsType()),
        )

    @typecheck()
    def forward(
        self,
        input: torch.Tensor,
        input_length: torch.Tensor,
        mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Multiply each mask channel with the `ref_channel` of the input.

        This can be used to generate multi-channel output: an `M`-channel mask
        produces an `M`-channel output.

        Args:
            input: Input signal complex-valued spectrogram, shape (B, C, F, N)
            input_length: Length of valid entries along the time dimension, shape (B,)
            mask: Mask for M outputs, shape (B, M, F, N)

        Returns:
            M-channel output complex-valued spectrogram with shape (B, M, F, N),
            and the unchanged `input_length`.
        """
        # Restrict the mask to the configured range
        clamped_mask = mask.clamp(min=self.mask_min, max=self.mask_max)
        # Slicing (not indexing) keeps a singleton channel axis so the mask broadcasts over it
        reference = input[:, self.ref_channel : self.ref_channel + 1]
        return clamped_mask * reference, input_length


class MaskBasedBeamformer(NeuralModule):
    """Multi-channel processor using masks to estimate signal statistics.

    For each input mask, a multichannel filter is built from the desired-signal
    mask and an undesired-signal mask, and applied to the input spectrogram.

    Args:
        filter_type: string denoting the type of the filter, `pmwf` or `mvdr_souden`. Defaults to `mvdr_souden`.
            For `mvdr_souden`, beta is forced to zero and rank to `one`.
        filter_beta: Parameter of the parameteric multichannel Wiener filter
        filter_rank: Parameter of the parametric multichannel Wiener filter
        filter_postfilter: Optional, postprocessing of the filter
        ref_channel: Optional, reference channel. If None, it will be estimated automatically
        ref_hard: If true, hard (one-hot) reference. If false, a soft reference
        ref_hard_use_grad: If true, use straight-through gradient when using the hard reference
        ref_subband_weighting: If true, use subband weighting when estimating reference channel
        num_subbands: Optional, used to determine the parameter size for reference estimation
        mask_min_db: Threshold mask to a minimal value before applying it, defaults to -200dB
        mask_max_db: Threshold mask to a maximal value before applying it, defaults to 0dB
        postmask_min_db: Lower threshold for the optional postmask applied on the filter output, defaults to 0dB
        postmask_max_db: Upper threshold for the optional postmask applied on the filter output, defaults to 0dB.
            The postmask is only applied when the lower threshold is smaller than the upper one.
        diag_reg: Optional, diagonal regularization for the multichannel filter
        eps: Small regularization constant to avoid division by zero

    Raises:
        ValueError: for an unknown `filter_type`, if `mask_min_db >= mask_max_db`,
            or if `postmask_min_db > postmask_max_db`.
    """

    def __init__(
        self,
        filter_type: str = 'mvdr_souden',
        filter_beta: float = 0.0,
        filter_rank: str = 'one',
        filter_postfilter: Optional[str] = None,
        ref_channel: Optional[int] = 0,
        ref_hard: bool = True,
        ref_hard_use_grad: bool = False,
        ref_subband_weighting: bool = False,
        num_subbands: Optional[int] = None,
        mask_min_db: float = -200,
        mask_max_db: float = 0,
        postmask_min_db: float = 0,
        postmask_max_db: float = 0,
        diag_reg: Optional[float] = 1e-6,
        eps: float = 1e-8,
    ):
        super().__init__()
        if filter_type not in ['pmwf', 'mvdr_souden']:
            raise ValueError(f'Unknown filter type {filter_type}')

        self.filter_type = filter_type
        # `mvdr_souden` is the PMWF with beta=0 and rank one. Both parameters must be
        # forced, not only beta: a non-default rank alone would otherwise slip through.
        if self.filter_type == 'mvdr_souden' and (filter_beta != 0 or filter_rank != 'one'):
            logging.warning(
                'Using filter type %s: beta will be automatically set to zero (current beta %f) and rank to one (current rank %s).',
                self.filter_type,
                filter_beta,
                filter_rank,
            )
            filter_beta = 0.0
            filter_rank = 'one'
        # Prepare filter
        self.filter = ParametricMultichannelWienerFilter(
            beta=filter_beta,
            rank=filter_rank,
            postfilter=filter_postfilter,
            ref_channel=ref_channel,
            ref_hard=ref_hard,
            ref_hard_use_grad=ref_hard_use_grad,
            ref_subband_weighting=ref_subband_weighting,
            num_subbands=num_subbands,
            diag_reg=diag_reg,
            eps=eps,
        )
        # Mask thresholding
        if mask_min_db >= mask_max_db:
            raise ValueError(
                f'Lower bound for the mask {mask_min_db}dB must be smaller than the upper bound {mask_max_db}dB'
            )
        self.mask_min = db2mag(mask_min_db)
        self.mask_max = db2mag(mask_max_db)
        # Postmask thresholding
        if postmask_min_db > postmask_max_db:
            raise ValueError(
                f'Lower bound for the postmask {postmask_min_db}dB must be smaller or equal to the upper bound {postmask_max_db}dB'
            )
        self.postmask_min = db2mag(postmask_min_db)
        self.postmask_max = db2mag(postmask_max_db)

        logging.debug('Initialized %s', self.__class__.__name__)
        logging.debug('\tfilter_type:  %s', self.filter_type)
        logging.debug('\tmask_min:     %e', self.mask_min)
        logging.debug('\tmask_max:     %e', self.mask_max)
        logging.debug('\tpostmask_min: %e', self.postmask_min)
        logging.debug('\tpostmask_max: %e', self.postmask_max)

    @property
    def input_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module input ports."""
        return {
            "input": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
            "mask": NeuralType(('B', 'C', 'D', 'T'), FloatType()),
            "mask_undesired": NeuralType(('B', 'C', 'D', 'T'), FloatType(), optional=True),
            "input_length": NeuralType(('B',), LengthsType(), optional=True),
        }

    @property
    def output_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module output ports."""
        return {
            "output": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
            "output_length": NeuralType(('B',), LengthsType(), optional=True),
        }

    @typecheck()
    def forward(
        self,
        input: torch.Tensor,
        mask: torch.Tensor,
        mask_undesired: Optional[torch.Tensor] = None,
        input_length: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Apply a mask-based beamformer to the input spectrogram.
        This can be used to generate multi-channel output.
        If `mask` has multiple channels, a multichannel filter is created for each mask,
        and the output is concatenation of individual outputs along the channel dimension.
        The total number of outputs is `num_masks * M`, where `M` is the number of channels
        at the filter output.

        Args:
            input: Input signal complex-valued spectrogram, shape (B, C, F, N)
            mask: Mask for M output signals, shape (B, num_masks, F, N)
            mask_undesired: Optional mask for the undesired signal, shape (B, num_masks, F, N).
                If not provided, the complement of the mask is used for a single mask,
                or the sum of all other masks when there are multiple masks.
            input_length: Length of valid entries along the time dimension, shape (B,)

        Returns:
            Multichannel output signal complex-valued spectrogram, shape (B, num_masks * M, F, N),
            and `input_length` passed through unchanged.
        """
        # Length mask (valid_ones=False). Entries selected by it are zeroed with
        # masked_fill below, i.e. frames outside the valid length.
        length_mask: Optional[torch.Tensor] = None
        if input_length is not None:
            length_mask = make_seq_mask_like(
                lengths=input_length, like=mask[:, 0, ...], time_dim=-1, valid_ones=False
            )

        num_masks = mask.size(1)
        # Sum over all masks; computed once and reused for every undesired-signal mask
        mask_sum = torch.sum(mask, dim=1) if mask_undesired is None and num_masks > 1 else None

        # Use each mask to generate an output
        outputs = []
        for m in range(num_masks):
            # Desired signal mask
            mask_d = mask[:, m, ...]
            # Undesired signal mask
            if mask_undesired is not None:
                mask_u = mask_undesired[:, m, ...]
            elif num_masks == 1:
                # If a single mask is estimated, use the complement
                mask_u = 1 - mask_d
            else:
                # Use sum of all other sources
                mask_u = mask_sum - mask_d

            # Threshold masks
            mask_d = torch.clamp(mask_d, min=self.mask_min, max=self.mask_max)
            mask_u = torch.clamp(mask_u, min=self.mask_min, max=self.mask_max)

            if length_mask is not None:
                mask_d = mask_d.masked_fill(length_mask, 0.0)
                mask_u = mask_u.masked_fill(length_mask, 0.0)

            # Apply filter
            output_m = self.filter(input=input, mask_s=mask_d, mask_n=mask_u)

            # Optional: apply a postmask with min and max thresholds
            if self.postmask_min < self.postmask_max:
                postmask_m = torch.clamp(mask[:, m, ...], min=self.postmask_min, max=self.postmask_max)
                output_m = output_m * postmask_m.unsqueeze(1)

            # Save the current output (B, M, F, T)
            outputs.append(output_m)

        # Combine outputs along the channel dimension
        # Each output is (B, M, F, T)
        output = torch.cat(outputs, dim=1)

        # Apply masking
        if length_mask is not None:
            output = output.masked_fill(length_mask[:, None, ...], 0.0)

        return output, input_length


class MaskBasedDereverbWPE(NeuralModule):
    """Multi-channel linear prediction-based dereverberation using
    weighted prediction error for filter estimation.

    An optional mask to estimate the signal power can be provided.
    If a time-frequency mask is not provided, the algorithm corresponds
    to the conventional WPE algorithm.

    Args:
        filter_length: Length of the convolutional filter for each channel in frames.
        prediction_delay: Delay of the input signal for multi-channel linear prediction in frames.
        num_iterations: Number of iterations for reweighting
        mask_min_db: Threshold mask to a minimal value before applying it, defaults to -200dB
        mask_max_db: Threshold mask to a maximal value before applying it, defaults to 0dB
        diag_reg: Diagonal regularization for WPE
        eps: Small regularization constant
        dtype: Data type for internal computations, `torch.cfloat` or `torch.cdouble`

    Raises:
        ValueError: if `dtype` is not `torch.cfloat` or `torch.cdouble`

    References:
        - Kinoshita et al, Neural network-based spectrum estimation for online WPE dereverberation, 2017
        - Yoshioka and Nakatani, Generalization of Multi-Channel Linear Prediction Methods for Blind MIMO Impulse Response Shortening, 2012
    """

    def __init__(
        self,
        filter_length: int,
        prediction_delay: int,
        num_iterations: int = 1,
        mask_min_db: float = -200,
        mask_max_db: float = 0,
        diag_reg: Optional[float] = 1e-6,
        eps: float = 1e-8,
        dtype: torch.dtype = torch.cdouble,
    ):
        super().__init__()
        # Filter setup
        self.filter = WPEFilter(
            filter_length=filter_length, prediction_delay=prediction_delay, diag_reg=diag_reg, eps=eps
        )
        self.num_iterations = num_iterations
        # Mask thresholding
        self.mask_min = db2mag(mask_min_db)
        self.mask_max = db2mag(mask_max_db)
        # Internal calculations
        if dtype not in [torch.cfloat, torch.cdouble]:
            raise ValueError(f'Unsupported dtype {dtype}, expecting torch.cfloat or torch.cdouble')
        self.dtype = dtype

        logging.debug('Initialized %s', self.__class__.__name__)
        logging.debug('\tnum_iterations: %s', self.num_iterations)
        logging.debug('\tmask_min:       %g', self.mask_min)
        logging.debug('\tmask_max:       %g', self.mask_max)
        logging.debug('\tdtype:          %s', self.dtype)

    @property
    def input_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module input ports."""
        return {
            "input": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
            "input_length": NeuralType(('B',), LengthsType(), optional=True),
            "mask": NeuralType(('B', 'C', 'D', 'T'), FloatType(), optional=True),
        }

    @property
    def output_types(self) -> Dict[str, NeuralType]:
        """Returns definitions of module output ports."""
        return {
            "output": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()),
            "output_length": NeuralType(('B',), LengthsType(), optional=True),
        }

    @typecheck()
    def forward(
        self, input: torch.Tensor, input_length: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Given an input signal `input`, apply the WPE dereverberation algorithm.

        The mask, if provided, is thresholded and applied to the magnitude only in
        the first iteration. Later iterations estimate the power from the current output.

        Args:
            input: C-channel complex-valued spectrogram, shape (B, C, F, T)
            input_length: Optional length for each signal in the batch, shape (B,)
            mask: Optional mask, shape (B, 1, F, N) or (B, C, F, T)

        Returns:
            Processed tensor with the same number of channels as the input,
            shape (B, C, F, T), cast back to the input dtype, and the output length
            (equal to `input_length` if no iteration is run).

        Raises:
            RuntimeError: if `input` is not complex-valued
        """
        # Validate before casting: `self.dtype` is always complex, so casting a real
        # tensor would succeed and hide the invalid input.
        if not input.is_complex():
            raise RuntimeError(f'Expecting complex input, got {input.dtype}')

        io_dtype = input.dtype
        device = input.device.type
        # Returned unchanged when the loop body never runs (num_iterations < 1)
        output_length = input_length

        with torch.amp.autocast(device, enabled=False):
            output = input.to(dtype=self.dtype)

            for i in range(self.num_iterations):
                magnitude = torch.abs(output)
                if i == 0 and mask is not None:
                    # Apply thresholds
                    mask = torch.clamp(mask, min=self.mask_min, max=self.mask_max)
                    # Mask magnitude
                    magnitude = mask * magnitude
                # Calculate power
                power = magnitude**2
                # Apply filter
                output, output_length = self.filter(input=output, input_length=input_length, power=power)

        return output.to(io_dtype), output_length