File size: 11,513 Bytes
1c8d125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96e1a32
 
1c8d125
 
 
 
 
 
96e1a32
1c8d125
 
 
 
 
 
96e1a32
1c8d125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96e1a32
 
1c8d125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96e1a32
1c8d125
 
 
96e1a32
1c8d125
 
 
 
 
96e1a32
1c8d125
 
 
 
 
 
96e1a32
1c8d125
 
 
 
 
96e1a32
1c8d125
 
 
 
 
 
96e1a32
1c8d125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96e1a32
 
1c8d125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96e1a32
1c8d125
 
 
 
 
 
 
 
 
96e1a32
1c8d125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96e1a32
 
1c8d125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96e1a32
1c8d125
 
 
 
 
 
96e1a32
1c8d125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96e1a32
 
1c8d125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96e1a32
1c8d125
 
 
 
 
 
96e1a32
1c8d125
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
from abc import ABC, abstractmethod

import torch


class BaseScaler(ABC):
    """
    Abstract base class for time series scalers.

    Defines the interface for scaling multivariate time series data with support
    for masked values and channel-wise scaling.
    """

    @abstractmethod
    def compute_statistics(
        self, history_values: torch.Tensor, history_mask: torch.Tensor | None = None
    ) -> dict[str, torch.Tensor]:
        """
        Compute scaling statistics from historical data.
        """
        pass

    @abstractmethod
    def scale(self, data: torch.Tensor, statistics: dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Apply scaling transformation to data.
        """
        pass

    @abstractmethod
    def inverse_scale(self, scaled_data: torch.Tensor, statistics: dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Apply inverse scaling transformation to recover original scale.
        """
        pass


class RobustScaler(BaseScaler):
    """
    Robust scaler using median and IQR for normalization.

    Statistics are computed independently for every (batch item, channel) series,
    using only unmasked, finite history values.
    """

    # Class-level default so instances whose ``__init__`` never set this attribute
    # (e.g. objects unpickled from a version that predates the option) keep the
    # original +/-50 clamp.
    clip_value: float | None = 50.0

    def __init__(self, epsilon: float = 1e-6, min_scale: float = 1e-3, clip_value: float | None = 50.0):
        """
        Args:
            epsilon: Positive constant added to the IQR before dividing.
            min_scale: Positive lower bound for the IQR and for the divisor used by
                ``scale``/``inverse_scale``; avoids blow-ups on near-constant series.
            clip_value: Scaled values are clamped to ``[-clip_value, clip_value]``.
                ``None`` disables clamping. Defaults to 50.0.

        Raises:
            ValueError: If ``epsilon`` or ``min_scale`` is not positive, or if
                ``clip_value`` is given and not positive.
        """
        if epsilon <= 0:
            raise ValueError("epsilon must be positive")
        if min_scale <= 0:
            raise ValueError("min_scale must be positive")
        if clip_value is not None and clip_value <= 0:
            raise ValueError("clip_value must be positive or None")
        self.epsilon = epsilon
        self.min_scale = min_scale
        self.clip_value = clip_value

    def compute_statistics(
        self, history_values: torch.Tensor, history_mask: torch.Tensor | None = None
    ) -> dict[str, torch.Tensor]:
        """
        Compute per-channel median and IQR statistics from historical data.

        Args:
            history_values: Tensor unpacked as ``(batch, seq_len, channels)``.
            history_mask: Optional mask indexed per batch item as ``history_mask[b, :]``.
                Nonzero entries mark time steps to use, and the same mask applies to
                every channel of that batch item.

        Returns:
            ``{"median": ..., "iqr": ...}``, each of shape ``(batch, 1, channels)``:
            - A series with no finite unmasked values keeps median 0 and IQR 1.
            - A series with exactly one such value gets IQR ``min_scale``.
            - Otherwise the IQR is ``max(q75 - q25, min_scale)``. It falls back to the
              standard deviation when ``torch.quantile`` rejects the input.
        """
        batch_size, _, num_channels = history_values.shape
        device = history_values.device

        medians = torch.zeros(batch_size, 1, num_channels, device=device)
        iqrs = torch.ones(batch_size, 1, num_channels, device=device)
        # Loop invariant: the IQR floor, built once instead of once per series.
        floor = torch.tensor(self.min_scale, device=device)

        for b in range(batch_size):
            mask = history_mask[b, :].bool() if history_mask is not None else None
            for c in range(num_channels):
                valid_data = history_values[b, :, c]
                if mask is not None:
                    valid_data = valid_data[mask]
                # NaN/Inf would corrupt both the median and the quartiles.
                valid_data = valid_data[torch.isfinite(valid_data)]

                if valid_data.numel() == 0:
                    continue  # keep the neutral defaults (median 0, IQR 1)

                medians[b, 0, c] = torch.median(valid_data)

                if valid_data.numel() == 1:
                    # Quartiles of a single point coincide; use the smallest allowed scale.
                    iqrs[b, 0, c] = self.min_scale
                    continue

                iqrs[b, 0, c] = torch.max(self._spread(valid_data), floor)

        return {"median": medians, "iqr": iqrs}

    @staticmethod
    def _spread(valid_data: torch.Tensor) -> torch.Tensor:
        """Return ``q75 - q25`` of a 1-D tensor, or its std if ``torch.quantile`` rejects it."""
        try:
            # A single call sorts the data once for both quartiles. ``new_tensor`` gives q
            # the input's dtype/device, which torch.quantile requires for tensor-valued q.
            q25, q75 = torch.quantile(valid_data, valid_data.new_tensor([0.25, 0.75]))
        except RuntimeError:
            # torch.quantile raises RuntimeError (or its subclass NotImplementedError)
            # for non float32/float64 inputs and for inputs that are too large.
            return torch.std(valid_data)
        return q75 - q25

    def _denominator(self, iqr: torch.Tensor) -> torch.Tensor:
        """Divisor shared by ``scale`` and ``inverse_scale``: ``max(iqr + epsilon, min_scale)``."""
        return torch.max(iqr + self.epsilon, torch.tensor(self.min_scale, device=iqr.device))

    def scale(self, data: torch.Tensor, statistics: dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Apply robust scaling: ``(data - median) / max(iqr + epsilon, min_scale)``.

        The result is clamped to ``[-clip_value, clip_value]`` unless ``clip_value``
        is ``None``. Values that were clamped cannot be recovered exactly by
        ``inverse_scale``.
        """
        median = statistics["median"]
        scaled_data = (data - median) / self._denominator(statistics["iqr"])
        if self.clip_value is not None:
            scaled_data = torch.clamp(scaled_data, -self.clip_value, self.clip_value)
        return scaled_data

    def inverse_scale(self, scaled_data: torch.Tensor, statistics: dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Apply inverse robust scaling: ``scaled_data * max(iqr + epsilon, min_scale) + median``.

        Accepts 3D tensors, or 4D tensors with a trailing extra axis. For 4D input the
        statistics get a singleton last dimension so they broadcast over that axis.
        """
        median = statistics["median"]
        denominator = self._denominator(statistics["iqr"])

        if scaled_data.ndim == 4:
            denominator = denominator.unsqueeze(-1)
            median = median.unsqueeze(-1)

        return scaled_data * denominator + median


class MinMaxScaler(BaseScaler):
    """
    Min-Max scaler that normalizes data to the range [-1, 1].

    Statistics are computed per batch item and per channel from the unmasked,
    finite history values. Data outside the historical range maps outside [-1, 1].
    """

    def __init__(self, epsilon: float = 1e-8):
        """
        Args:
            epsilon: Positive constant added to the range when dividing. It is also the
                threshold below which a series is treated as constant.

        Raises:
            ValueError: If ``epsilon`` is not positive.
        """
        if epsilon <= 0:
            raise ValueError("epsilon must be positive")
        self.epsilon = epsilon

    def compute_statistics(
        self, history_values: torch.Tensor, history_mask: torch.Tensor | None = None
    ) -> dict[str, torch.Tensor]:
        """
        Compute per-channel min and max statistics from historical data.

        Args:
            history_values: Tensor unpacked as ``(batch, seq_len, channels)``.
            history_mask: Optional mask indexed per batch item as ``history_mask[b, :]``.
                Nonzero entries mark time steps to use, and the same mask applies to
                every channel of that batch item.

        Returns:
            ``{"min": ..., "max": ...}``, each of shape ``(batch, 1, channels)``:
            - A series with no finite unmasked values keeps min 0 and max 1.
            - A series whose range is below ``epsilon`` gets ``max = min + 1`` so that
              ``scale`` does not divide by (almost) zero.
        """
        batch_size, _, num_channels = history_values.shape
        device = history_values.device

        mins = torch.zeros(batch_size, 1, num_channels, device=device)
        maxs = torch.ones(batch_size, 1, num_channels, device=device)

        for b in range(batch_size):
            mask = history_mask[b, :].bool() if history_mask is not None else None
            for c in range(num_channels):
                valid_data = history_values[b, :, c]
                if mask is not None:
                    valid_data = valid_data[mask]
                # Drop NaN/Inf, as the other scalers do. A single NaN would otherwise
                # make min/max NaN and poison every scaled value of this channel.
                valid_data = valid_data[torch.isfinite(valid_data)]

                if valid_data.numel() == 0:
                    continue

                min_val = torch.min(valid_data)
                max_val = torch.max(valid_data)

                mins[b, 0, c] = min_val
                # Near-constant series: widen to a unit range to keep the divisor sane.
                maxs[b, 0, c] = min_val + 1.0 if torch.abs(max_val - min_val) < self.epsilon else max_val

        return {"min": mins, "max": maxs}

    def scale(self, data: torch.Tensor, statistics: dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Apply min-max scaling: ``2 * (data - min) / (max - min + epsilon) - 1``.

        Values inside the historical range land in [-1, 1].
        """
        min_val = statistics["min"]
        max_val = statistics["max"]

        normalized = (data - min_val) / (max_val - min_val + self.epsilon)
        return normalized * 2.0 - 1.0

    def inverse_scale(self, scaled_data: torch.Tensor, statistics: dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Apply inverse min-max scaling: ``(scaled + 1) / 2 * (max - min + epsilon) + min``.

        Accepts 3D tensors, or 4D tensors with a trailing extra axis. For 4D input the
        statistics get a singleton last dimension so they broadcast over that axis.
        """
        min_val = statistics["min"]
        max_val = statistics["max"]

        if scaled_data.ndim == 4:
            min_val = min_val.unsqueeze(-1)
            max_val = max_val.unsqueeze(-1)

        normalized = (scaled_data + 1.0) / 2.0
        return normalized * (max_val - min_val + self.epsilon) + min_val


class MeanScaler(BaseScaler):
    """
    Centering-only scaler that removes the per-channel mean of the history.

    The spread of the data is left as is: ``scale`` subtracts the mean and
    ``inverse_scale`` adds it back.
    """

    def compute_statistics(
        self, history_values: torch.Tensor, history_mask: torch.Tensor | None = None
    ) -> dict[str, torch.Tensor]:
        """
        Estimate the mean of every (batch item, channel) series.

        Only time steps that are selected by ``history_mask`` (when given) and that
        hold finite values contribute. A series with no such values keeps a mean of 0.

        Args:
            history_values: Tensor unpacked as ``(batch, seq_len, channels)``.
            history_mask: Optional mask indexed per batch item as ``history_mask[b, :]``.
                Nonzero entries mark usable time steps for every channel of that item.

        Returns:
            ``{"mean": tensor}`` with shape ``(batch, 1, channels)``.
        """
        n_items, _, n_channels = history_values.shape
        means = torch.zeros(n_items, 1, n_channels, device=history_values.device)

        for item in range(n_items):
            for chan in range(n_channels):
                series = history_values[item, :, chan]
                if history_mask is not None:
                    series = series[history_mask[item, :].bool()]
                # Ignore NaN/Inf so they cannot contaminate the average.
                finite = series[torch.isfinite(series)]
                if finite.numel() > 0:
                    means[item, 0, chan] = torch.mean(finite)

        return {"mean": means}

    def scale(self, data: torch.Tensor, statistics: dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Center ``data`` by subtracting the stored mean (``data - mean``).
        """
        return data - statistics["mean"]

    def inverse_scale(self, scaled_data: torch.Tensor, statistics: dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Restore the mean removed by ``scale`` (``scaled_data + mean``).

        Works for 3D input and for 4D input laid out as
        (batch, seq_len, channels, samples).
        """
        offset = statistics["mean"]
        # A 4D tensor carries a trailing samples axis; add a matching singleton dim.
        return scaled_data + (offset.unsqueeze(-1) if scaled_data.ndim == 4 else offset)


class MedianScaler(BaseScaler):
    """
    Centering-only scaler that removes the per-channel median of the history.

    Like ``MeanScaler`` it leaves the spread of the data unchanged. Using the
    median instead of the mean makes the offset less sensitive to outliers.
    """

    def compute_statistics(
        self, history_values: torch.Tensor, history_mask: torch.Tensor | None = None
    ) -> dict[str, torch.Tensor]:
        """
        Estimate the median of every (batch item, channel) series.

        Only time steps that are selected by ``history_mask`` (when given) and that
        hold finite values contribute. A series with no such values keeps a median of 0.

        Args:
            history_values: Tensor unpacked as ``(batch, seq_len, channels)``.
            history_mask: Optional mask indexed per batch item as ``history_mask[b, :]``.
                Nonzero entries mark usable time steps for every channel of that item.

        Returns:
            ``{"median": tensor}`` with shape ``(batch, 1, channels)``.
        """
        n_items, _, n_channels = history_values.shape
        medians = torch.zeros(n_items, 1, n_channels, device=history_values.device)

        for item in range(n_items):
            for chan in range(n_channels):
                series = history_values[item, :, chan]
                if history_mask is not None:
                    series = series[history_mask[item, :].bool()]
                # Ignore NaN/Inf before ranking the values.
                finite = series[torch.isfinite(series)]
                if finite.numel() > 0:
                    medians[item, 0, chan] = torch.median(finite)

        return {"median": medians}

    def scale(self, data: torch.Tensor, statistics: dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Center ``data`` by subtracting the stored median (``data - median``).
        """
        return data - statistics["median"]

    def inverse_scale(self, scaled_data: torch.Tensor, statistics: dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Restore the median removed by ``scale`` (``scaled_data + median``).

        Works for 3D input and for 4D input laid out as
        (batch, seq_len, channels, samples).
        """
        offset = statistics["median"]
        # A 4D tensor carries a trailing samples axis; add a matching singleton dim.
        return scaled_data + (offset.unsqueeze(-1) if scaled_data.ndim == 4 else offset)