vlad-moroshan committed
Commit 03621a4 · 1 Parent(s): 87d1677

Add custom GatedDeltaProduct implementation with state weaving

src/models/blocks.py CHANGED
@@ -1,7 +1,10 @@
 import torch
 import torch.nn as nn
-from fla.models.gated_deltaproduct import GatedDeltaProductConfig
-from fla.models.gated_deltaproduct.modeling_gated_deltaproduct import GatedDeltaProductBlock
+
+from src.models.gated_deltaproduct import GatedDeltaProductConfig
+from src.models.gated_deltaproduct.modeling_gated_deltaproduct import (
+    GatedDeltaProductBlock,
+)
 
 
 class GatedDeltaProductEncoder(nn.Module):
src/models/gated_deltaproduct/README.md ADDED
@@ -0,0 +1,344 @@
# Custom GatedDeltaProduct Implementation

This directory contains a custom implementation of the GatedDeltaProduct layer, based on the [Flash Linear Attention (FLA)](https://github.com/fla-org/flash-linear-attention) library, with modifications specifically designed for **time series forecasting** tasks.

## Overview

Our custom implementation adds **hidden state weaving**, which lets information flow across encoder layers and preserves temporal continuity. This is crucial for time series forecasting and sets our version apart from the general-purpose language modeling focus of the official FLA implementation.

## Reference

This implementation is based on:
- **Official FLA Repository**: [https://github.com/fla-org/flash-linear-attention](https://github.com/fla-org/flash-linear-attention)
- **Original Paper**: [DeltaProduct: Improving State-Tracking in Linear RNNs via Householder Products](https://arxiv.org/html/2502.10297v3) (Siems et al., 2025)

---

## What is DeltaProduct?

DeltaProduct is a linear RNN architecture that uses **diagonal plus rank-nₕ** state-transition matrices, formed as products of `nₕ` generalized Householder transformations. This provides a tunable mechanism to balance expressivity and efficiency compared to diagonal-only architectures such as Mamba or GLA.

### Key Concepts

- **Householder transformations**: Enable simultaneous token-channel mixing, overcoming the expressivity limitations of purely diagonal state-transition matrices
- **Rank-nₕ structure**: Allows better expressivity than rank-1 (DeltaNet) while maintaining training efficiency. The parameter `nₕ` (the number of Householder transformations) provides a tunable trade-off between expressivity and computational cost; see the sketch after this list
- **Gated variant**: Adds gating mechanisms for improved performance, allowing the model to control information flow through forget gates and output gates

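To make the rank-nₕ update concrete, here is a minimal, self-contained sketch of a single token's state transition written as a product of generalized Householder steps. It is not taken from this repository; the `deltaproduct_state_update` helper, tensor names, and shapes are ours, for illustration only:

```python
import torch

def deltaproduct_state_update(S, keys, values, betas, gate):
    """One token's state update, sketched as n_h generalized Householder steps.

    S:      (d_k, d_v) running state (maps keys to values)
    keys:   (n_h, d_k) unit-norm keys k_{t,1..n_h}
    values: (n_h, d_v) values v_{t,1..n_h}
    betas:  (n_h,) step sizes; in (0, 2) when negative eigenvalues are allowed
    gate:   scalar forget gate in (0, 1] (the "gated" part)
    """
    S = gate * S  # scalar decay of the previous state
    for k, v, beta in zip(keys, values, betas):
        # (I - beta k k^T) S + beta k v^T  -- one generalized Householder step
        S = S - beta * torch.outer(k, k @ S) + beta * torch.outer(k, v)
    return S

# Tiny smoke test with random data
d_k, d_v, n_h = 4, 6, 2
S = torch.zeros(d_k, d_v)
keys = torch.nn.functional.normalize(torch.randn(n_h, d_k), dim=-1)
values = torch.randn(n_h, d_v)
betas = torch.rand(n_h)
S = deltaproduct_state_update(S, keys, values, betas, gate=0.95)
print(S.shape)  # torch.Size([4, 6])
```

With `nₕ = 1` this reduces to the (gated) DeltaNet update; larger `nₕ` buys a higher-rank, more expressive per-token transition at proportionally higher cost.
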
### Architecture Overview

DeltaProduct improves upon earlier linear RNN architectures:

- **Diagonal architectures** (Mamba, GLA, mLSTM): Use diagonal state-transition matrices for fast runtime but suffer from limited expressivity
- **Rank-1 architectures** (DeltaNet, RWKV-7): Use a diagonal plus rank-1 structure, enabling simultaneous token-channel mixing with only a slight decrease in training efficiency
- **DeltaProduct**: Extends this to a diagonal plus rank-nₕ structure, where multiple Householder transformations (nₕ ≥ 1) provide greater expressivity while maintaining computational efficiency

The architecture interprets DeltaNet's recurrence as performing one step of online gradient descent per token on an associative recall loss. DeltaProduct instead takes multiple (`nₕ`) steps per token, which naturally leads to the rank-nₕ structure.

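For readers who want the derivation behind that interpretation, the delta rule can be read as one gradient step on an associative recall loss (our notation, used here only as a reading aid):

```
L(S) = ½ ‖Sᵀ kₜ − vₜ‖²
Sₜ   = Sₜ₋₁ − βₜ ∇L(Sₜ₋₁)
     = (I − βₜ kₜ kₜᵀ) Sₜ₋₁ + βₜ kₜ vₜᵀ
```

DeltaProduct takes `nₕ` such steps per token (with keys kₜ,₁ … kₜ,ₙₕ), so the per-token state-transition matrix becomes a product of `nₕ` generalized Householder factors.
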
---

## State Weaving Mechanism

Unlike DeltaProduct's original design for autoregressive language modeling, time series forecasting across a full horizon does not require causal masking. To exploit this property, we introduce **state weaving**, a mechanism that enables bidirectional information flow across the entire sequence length without additional parameters or computational overhead.

<div align="center">
<img src="https://iili.io/Ks86Z0X.png" alt="State Weaving Architecture" width="450"/>
</div>

*Figure: The TempoPFN architecture using stacked GatedDeltaProduct blocks with learnable initial states H₀ⁱ and state weaving. The final hidden state of each layer Hₜⁱ is added to the learnable initial state of the next layer H₀ⁱ⁺¹, enabling bidirectional information flow.*

### How State Weaving Works

In our implementation, state weaving operates as follows:

1. **Learnable Initial States**: Each encoder layer `i` has a learnable initial hidden state `H₀ⁱ` that is optimized during training.

2. **State Propagation**: The final hidden state from layer `i`, denoted `Hₜⁱ`, is propagated forward and combined with the learnable initial state of the next layer:
   ```
   H₀ⁱ⁺¹ = H₀ⁱ⁺¹ + Hₜⁱ
   ```

3. **Bidirectional Information Flow**: This mechanism effectively lifts the causal constraint while maintaining computational efficiency. Because the final state of layer `i` has seen the whole sequence, every position in layer `i + 1` can draw on information from later time steps, enabling the model to process the entire sequence (history + future horizon) coherently.

4. **No Extra Overhead**: Unlike explicit bidirectional architectures, state weaving requires no additional parameters or computation beyond the existing forward pass.

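Conceptually, the weaving loop boils down to the following toy sketch (the `DummyLayer` stand-in and all shapes are ours; the actual encoder loop appears in the "Hidden State Weaving Implementation" section below):

```python
import torch
import torch.nn as nn

# Illustrative stand-in for a GatedDeltaProduct encoder layer: it mixes the
# input with the provided initial state and returns (output, final_state).
class DummyLayer(nn.Module):
    def forward(self, x, initial_state):
        final_state = initial_state + x.mean(dim=1, keepdim=True)
        return x + final_state, final_state

layers = nn.ModuleList([DummyLayer() for _ in range(3)])
initial_states = nn.ParameterList(
    [nn.Parameter(torch.zeros(2, 1, 8)) for _ in range(3)]  # learnable H₀ⁱ per layer
)

x = torch.randn(2, 16, 8)                      # (batch, time, dim)
carried = torch.zeros_like(initial_states[0])  # nothing is woven into layer 0 yet
for i, layer in enumerate(layers):
    # each layer starts from its own H₀ⁱ plus the final state of the previous layer
    x, carried = layer(x, initial_state=initial_states[i] + carried)
```
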
This design is particularly powerful for time series forecasting, where:
- The full prediction horizon is known at inference time
- Coherent predictions across all future time steps are desired
- Historical context should inform all future predictions simultaneously

---

## Key Differences from Official FLA

### 1. **`initial_state` Parameter in Forward Method**

#### Official FLA (`fla/layers/gated_deltaproduct.py`)
```python
def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: torch.Tensor | None = None,
    past_key_values: Cache | None = None,
    use_cache: bool | None = False,
    output_attentions: bool | None = False,
    **kwargs: Unpack[dict],
) -> tuple[torch.Tensor, torch.Tensor | None, Cache | None]:
```
**No `initial_state` parameter** - The official implementation only uses `recurrent_state` from `past_key_values`.

#### Our Custom Implementation (`gated_deltaproduct.py`)
```python
def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    past_key_values: Optional[Cache] = None,
    initial_state: Optional[torch.Tensor] = None,  # ← ADDED
    use_cache: Optional[bool] = False,
    output_attentions: Optional[bool] = False,
    **kwargs: Unpack[Dict],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
```
**Added `initial_state` parameter** - Allows external control of the initial recurrent state, enabling layer-to-layer state propagation.

---

### 2. **Usage of `initial_state` in Chunk Mode**

#### Official FLA
```python
if mode == 'chunk':
    o, recurrent_state = chunk_gated_delta_product(
        q=q, k=k, v=v, g=g, beta=beta,
        initial_state=recurrent_state,  # ← Only from past_key_values
        output_final_state=use_cache,
        cu_seqlens=cu_seqlens,
        num_householder=self.num_householder,
        use_qk_l2norm_in_kernel=True,
    )
```

#### Our Custom Implementation
```python
if mode == "chunk":
    o, recurrent_state = chunk_gated_delta_product(
        q=q, k=k, v=v, g=g, beta=beta,
        initial_state=initial_state,  # ← Uses external initial_state if provided
        output_final_state=output_attentions,
        cu_seqlens=cu_seqlens,
        num_householder=self.num_householder,
        use_qk_l2norm_in_kernel=True,
    )
```

**Key Difference**: Our implementation prioritizes the externally provided `initial_state` over `recurrent_state` from `past_key_values`, enabling layer-to-layer state propagation.

---

### 3. **Return Value: Hidden State Output**

#### Official FLA (`fla/models/gated_deltaproduct/modeling_gated_deltaproduct.py`)
```python
def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: torch.Tensor | None = None,
    past_key_values: Cache | list[torch.FloatTensor] | None = None,
    use_cache: bool | None = False,
    output_attentions: bool | None = False,
    **kwargs: Unpack[dict],
) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
    # ...
    return outputs  # Returns (hidden_states, attentions, past_key_values)
```

**No `initial_state` parameter** - The block doesn't accept an external initial state or expose the recurrent hidden state explicitly.

#### Our Custom Implementation (`modeling_gated_deltaproduct.py`)
```python
def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
    use_cache: Optional[bool] = False,
    output_attentions: Optional[bool] = False,
    initial_state: Optional[torch.FloatTensor] = None,  # ← ADDED
    **kwargs: Unpack[Dict],
) -> Tuple[
    torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
]:
    # ...
    hidden_states, attentions, past_key_values = self.attn(
        # ...
        initial_state=initial_state,  # ← Passed through
        **kwargs,
    )
    # ...
    return outputs  # Returns (hidden_states, attentions, past_key_values)
```

**Added `initial_state` parameter** - The block accepts and forwards `initial_state` to the attention layer.

---

### 4. **Hidden State Weaving Implementation**

Our implementation supports two modes of hidden state weaving (controlled by the `weaving` parameter in the encoder config):

#### **Mode 1: Weaving Enabled (`weaving=True`)** - Default
```python
if self.encoder_config.get("weaving", True):
    # initial hidden state is learnable
    hidden_state = torch.zeros_like(
        self.initial_hidden_state[0].repeat(batch_size * num_channels, 1, 1, 1)
    )
    for layer_idx, encoder_layer in enumerate(self.encoder_layers):
        x, hidden_state = encoder_layer(
            x,
            hidden_state + self.initial_hidden_state[layer_idx].repeat(
                batch_size * num_channels, 1, 1, 1
            ),
        )
```

**Key Features**:
- Hidden state accumulates across layers
- Each layer receives: `previous_hidden_state + learnable_initial_state[layer_idx]`
- State persists between layers, allowing information to flow through the network

#### **Mode 2: No Weaving (`weaving=False`)**
```python
else:
    # initial hidden state is separately learnable for each layer
    for layer_idx, encoder_layer in enumerate(self.encoder_layers):
        initial_hidden_state = self.initial_hidden_state[layer_idx].repeat(
            batch_size * num_channels, 1, 1, 1
        )
        x, _ = encoder_layer(x, initial_hidden_state)
```

**Key Features**:
- Each layer uses its own independent learnable initial state
- No accumulation between layers
- Hidden state is discarded after each layer

---

### 5. **Learnable Initial Hidden States**

Our implementation includes learnable initial states managed at the model level:

```python
num_initial_hidden_states = self.num_encoder_layers
self.initial_hidden_state = nn.ParameterList(
    [
        nn.Parameter(
            torch.randn(
                1, self.encoder_config["num_heads"], head_k_dim, head_v_dim
            )
            / head_k_dim,
            requires_grad=True,
        )
        for _ in range(num_initial_hidden_states)
    ]
)
```

**Key Features**:
- One learnable parameter per encoder layer
- Shape: `[1, num_heads, head_k_dim, head_v_dim]`
- Initialized with small random values, scaled down by `head_k_dim`
- These parameters are trained jointly with the rest of the model

---

### 6. **Parameter Name Differences**

- **Official FLA**: Uses the `use_output_gate` parameter
- **Our Implementation**: Uses the `use_gate` parameter (renamed for clarity)

---

### 7. **Return Value Differences**

#### Official FLA (`fla/layers/gated_deltaproduct.py`)
```python
return o, None, past_key_values  # Returns (output, None, past_key_values)
```

#### Our Custom Implementation (`gated_deltaproduct.py`)
```python
return o, recurrent_state, past_key_values  # Returns (output, recurrent_state, past_key_values)
```

**Key Difference**: Our implementation returns `recurrent_state` (the final hidden state) instead of `None`, enabling state propagation.

---

### 8. **Encoder Wrapper Return Values**

Our `GatedDeltaProductEncoder` (in `src/models/blocks.py`) returns both the output and the hidden state:

```python
x, last_hidden_state, _ = self.encoder_layer(
    x, output_attentions=True, initial_state=initial_state
)
return x, last_hidden_state  # ← Returns hidden state for weaving
```

This allows state propagation between layers in the `TimeSeriesModel`.

---

## Summary Table

| Feature | Official FLA | Our Custom Implementation |
|---------|-------------|---------------------------|
| `initial_state` in `forward()` | ❌ No | ✅ Yes |
| `initial_state` in `GatedDeltaProductBlock.forward()` | ❌ No | ✅ Yes |
| Hidden state weaving | ❌ No | ✅ Yes (configurable) |
| Learnable initial states | ❌ No | ✅ Yes (`nn.ParameterList`) |
| Returns `recurrent_state` | ❌ No (returns `None`) | ✅ Yes |
| Layer-to-layer state propagation | ❌ No | ✅ Yes (when `weaving=True`) |
| Parameter name | `use_output_gate` | `use_gate` |

---

## Why These Differences Matter for Time Series Forecasting

1. **Temporal Continuity**: Hidden state weaving allows information to flow across layers, maintaining temporal patterns across the encoder stack. This is crucial for time series where historical context matters.

2. **Learnable Initialization**: Learnable initial states let the model learn optimal starting points for the recurrent computation, which helps it capture time series patterns.

3. **Flexible State Management**: The `weaving` parameter allows switching between:
   - **Weaving mode**: Better for capturing long-term dependencies across layers
   - **Independent mode**: Each layer processes independently, potentially more stable

4. **State Propagation**: Returning and propagating hidden states lets the model maintain context across multiple encoder layers when predicting the full horizon.

These modifications make our implementation better suited for time series forecasting than the general-purpose language modeling focus of the official FLA implementation.

---

## Files in This Directory

- **`gated_deltaproduct.py`**: Core GatedDeltaProduct layer implementation with `initial_state` support
- **`modeling_gated_deltaproduct.py`**: GatedDeltaProductBlock wrapper that integrates the layer
- **`configuration_gated_deltaproduct.py`**: Configuration class for the model
- **`__init__.py`**: Module exports

---

## Usage

See `src/models/model.py` and `src/models/blocks.py` for examples of how to use this custom implementation with hidden state weaving.

To enable/disable weaving, set the `weaving` parameter in your encoder configuration:
```python
encoder_config = {
    "weaving": True,  # Enable state propagation across layers
    # ... other config parameters
}
```

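For a fuller picture, below is a hypothetical end-to-end sketch of driving stacked `GatedDeltaProductBlock` layers with woven states, pieced together from the signatures documented above. It assumes a GPU environment with the FLA Triton kernels available; the hyperparameters, shapes, and dtype/device handling are illustrative only, and the real encoder lives in `src/models/blocks.py` / `src/models/model.py`.

```python
import torch
import torch.nn as nn

from src.models.gated_deltaproduct import GatedDeltaProductBlock, GatedDeltaProductConfig

config = GatedDeltaProductConfig(
    hidden_size=256,
    head_dim=64,
    num_heads=4,
    expand_v=1.0,
    num_householder=2,
    use_forget_gate=True,
)
num_layers = 3
blocks = nn.ModuleList(
    [GatedDeltaProductBlock(config, layer_idx=i) for i in range(num_layers)]
)

batch, seq_len = 2, 128
head_k_dim = config.head_dim
head_v_dim = int(config.head_dim * config.expand_v)
# One learnable H₀ⁱ per layer, mirroring section 5 above
initial_states = nn.ParameterList(
    [
        nn.Parameter(torch.randn(1, config.num_heads, head_k_dim, head_v_dim) / head_k_dim)
        for _ in range(num_layers)
    ]
)

x = torch.randn(batch, seq_len, config.hidden_size)
carried = torch.zeros(batch, config.num_heads, head_k_dim, head_v_dim)
for i, block in enumerate(blocks):
    # with output_attentions=True the second return value carries the final recurrent state
    x, final_state, _ = block(
        x,
        output_attentions=True,
        initial_state=carried + initial_states[i].repeat(batch, 1, 1, 1),
    )
    carried = final_state  # woven into the next layer's learnable initial state
```
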
src/models/gated_deltaproduct/__init__.py ADDED
@@ -0,0 +1,11 @@
from src.models.gated_deltaproduct.configuration_gated_deltaproduct import (
    GatedDeltaProductConfig,
)
from src.models.gated_deltaproduct.modeling_gated_deltaproduct import (
    GatedDeltaProductBlock,
)

__all__ = [
    "GatedDeltaProductConfig",
    "GatedDeltaProductBlock",
]
src/models/gated_deltaproduct/configuration_gated_deltaproduct.py ADDED
@@ -0,0 +1,108 @@
import warnings

from transformers.configuration_utils import PretrainedConfig


class GatedDeltaProductConfig(PretrainedConfig):
    model_type = "gated_deltaproduct"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        attn_mode: str = "chunk",
        conv_size: int = 4,
        head_dim: int = 256,
        num_heads: int = 6,
        hidden_size: int = 2048,
        expand_v: float = 2.0,
        use_gate: bool = True,  # Changed from use_output_gate to use_gate for custom implementation
        use_short_conv: bool = True,
        max_position_embeddings: int = 2048,
        hidden_ratio: int | None = 4,
        intermediate_size: int | None = None,
        hidden_act: str = "swish",
        num_hidden_layers: int = 21,
        norm_eps: float = 1e-6,
        attn: dict | None = None,
        use_cache: bool = True,
        pad_token_id: int | None = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = False,
        initializer_range: float = 0.02,
        fuse_norm: bool = True,
        fuse_swiglu: bool = True,
        fuse_cross_entropy: bool = True,
        fuse_linear_cross_entropy: bool = False,
        use_l2warp: bool = False,
        vocab_size: int = 32000,
        use_forget_gate: bool = False,
        allow_neg_eigval: bool = False,
        num_householder: int = 1,
        **kwargs,
    ):
        self.attn_mode = attn_mode
        self.conv_size = conv_size
        self.head_dim = head_dim
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.expand_v = expand_v
        self.use_gate = use_gate  # Changed from use_output_gate to use_gate
        self.use_short_conv = use_short_conv
        self.max_position_embeddings = max_position_embeddings

        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.num_hidden_layers = num_hidden_layers
        self.norm_eps = norm_eps
        self.attn = attn
        self.use_cache = use_cache
        self.initializer_range = initializer_range

        self.fuse_norm = fuse_norm
        self.fuse_swiglu = fuse_swiglu
        self.fuse_cross_entropy = fuse_cross_entropy
        self.fuse_linear_cross_entropy = fuse_linear_cross_entropy
        self.use_l2warp = use_l2warp
        self.vocab_size = vocab_size

        if fuse_cross_entropy and fuse_linear_cross_entropy:
            raise ValueError(
                "`fuse_cross_entropy` and `fuse_linear_cross_entropy` cannot be True at the same time.",
            )
        if fuse_linear_cross_entropy:
            warnings.warn(
                "`fuse_linear_cross_entropy` is enabled, which can improve memory efficiency "
                "at the potential cost of reduced precision. "
                "If you observe issues like loss divergence, consider disabling this setting.",
            )

        # DeltaProduct specific
        self.allow_neg_eigval = allow_neg_eigval
        self.num_householder = num_householder
        self.use_forget_gate = use_forget_gate

        if attn is not None:
            if not isinstance(attn, dict):
                raise ValueError("attn must be a dictionary")
            if "layers" not in attn:
                raise ValueError(
                    "Layer indices must be provided to initialize hybrid attention layers"
                )
            if "num_heads" not in attn:
                raise ValueError(
                    "Number of heads must be provided to initialize hybrid attention layers"
                )
            attn["num_kv_heads"] = attn.get("num_kv_heads", attn["num_heads"])
            attn["qkv_bias"] = attn.get("qkv_bias", False)
            attn["window_size"] = attn.get("window_size", None)
            attn["rope_theta"] = attn.get("rope_theta", 10000.0)

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
src/models/gated_deltaproduct/gated_deltaproduct.py ADDED
@@ -0,0 +1,351 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

from __future__ import annotations

import math
import warnings
from typing import TYPE_CHECKING, Dict, Optional, Tuple

import torch
import torch.nn as nn
from einops import rearrange, repeat
from fla.layers.utils import get_unpad_data, index_first_axis, pad_input
from fla.modules import FusedRMSNormGated, RMSNorm, ShortConvolution
from fla.ops.delta_rule import fused_recurrent_delta_rule
from fla.ops.gated_delta_product import chunk_gated_delta_product
from fla.ops.gated_delta_rule import fused_recurrent_gated_delta_rule
from torch.nn import functional as F

if TYPE_CHECKING:
    from fla.models.utils import Cache
    from transformers.processing_utils import Unpack


class GatedDeltaProduct(nn.Module):
    """
    Generalized version of GatedDoubleDeltaNet that supports an arbitrary number of Householder transformations.
    """

    def __init__(
        self,
        hidden_size: int = 2048,
        expand_v: float = 2,
        head_dim: int = 256,
        num_heads: int = 6,
        num_v_heads: int = None,
        mode: str = "chunk",
        use_gate: bool = True,
        use_short_conv: bool = True,
        conv_size: int = 4,
        conv_bias: bool = False,
        layer_idx: int = None,
        norm_eps: float = 1e-5,
        use_forget_gate: bool = True,
        allow_neg_eigval: bool = True,
        num_householder: int = 2,
        **kwargs,
    ) -> GatedDeltaProduct:
        super().__init__()

        self.mode = mode

        self.hidden_size = hidden_size
        self.expand_v = expand_v

        self.use_forget_gate = use_forget_gate
        self.allow_neg_eigval = allow_neg_eigval
        self.num_householder = num_householder
        self.use_gate = use_gate
        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.conv_bias = conv_bias

        self.head_dim = head_dim
        self.num_heads = num_heads
        self.num_v_heads = num_v_heads if num_v_heads is not None else num_heads

        self.head_k_dim = head_dim
        self.head_v_dim = int(self.head_dim * self.expand_v)
        self.key_dim = int(self.num_heads * self.head_k_dim)
        self.value_dim = int(self.num_v_heads * self.head_v_dim)
        self.layer_idx = layer_idx
        self.init_hidden_state = nn.Parameter(
            torch.randn(self.num_heads, self.head_dim, self.head_dim)
        )

        # Consistency check: Ensure expand_v produces integer values
        if not math.isclose(
            self.num_v_heads * self.head_dim * expand_v, self.value_dim, rel_tol=1e-5
        ):
            raise ValueError(
                f"expand_v={expand_v} does not produce an integer value when multiplied by key_dim={self.key_dim}. "
                f"Resulting value_dim would be {self.num_v_heads * self.head_dim * expand_v}, which is invalid for nn.Linear."
            )
        if self.num_v_heads > self.num_heads and self.num_v_heads % self.num_heads != 0:
            raise ValueError(
                f"num_v_heads={self.num_v_heads} must be divisible by num_heads={self.num_heads}."
            )

        if not math.isclose(head_dim * expand_v, self.head_v_dim, rel_tol=1e-5):
            raise ValueError(
                f"expand_v={expand_v} does not produce an integer value when multiplied by head_dim={head_dim}. "
                f"Resulting head_v_dim would be {head_dim * expand_v}, which is invalid for FusedRMSNormGated."
            )
        assert mode in ["chunk", "fused_recurrent"], f"Not supported mode `{mode}`."

        self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
        self.k_proj = nn.Linear(hidden_size, self.key_dim * num_householder, bias=False)
        self.v_proj = nn.Linear(
            hidden_size, self.value_dim * num_householder, bias=False
        )
        self.b_proj = nn.Linear(
            hidden_size, self.num_v_heads * num_householder, bias=False
        )

        if self.use_forget_gate:
            self.a_proj = nn.Linear(hidden_size, self.num_v_heads, bias=False)
            A = torch.empty(self.num_v_heads, dtype=torch.float32).uniform_(0, 16)
            self.A_log = nn.Parameter(torch.log(A))
            self.A_log._no_weight_decay = True
            # hard coded for now
            dt_min = 0.001
            dt_max = 0.1
            dt_init_floor = 1e-4
            dt = torch.exp(
                torch.rand(self.num_v_heads) * (math.log(dt_max) - math.log(dt_min))
                + math.log(dt_min)
            )
            dt = torch.clamp(dt, min=dt_init_floor)
            # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
            inv_dt = dt + torch.log(-torch.expm1(-dt))
            self.dt_bias = nn.Parameter(inv_dt)
            # Just to be explicit. Without this we already don't put wd on dt_bias because of the check
            # name.endswith("bias") in param_grouping.py
            self.dt_bias._no_weight_decay = True

        if use_short_conv:
            self.conv_size = conv_size
            self.q_conv1d = ShortConvolution(
                hidden_size=self.key_dim,
                kernel_size=conv_size,
                bias=conv_bias,
                activation="silu",
            )
            self.k_conv1d = ShortConvolution(
                hidden_size=self.key_dim * num_householder,
                kernel_size=conv_size,
                bias=conv_bias,
                activation="silu",
            )
            self.v_conv1d = ShortConvolution(
                hidden_size=self.value_dim * num_householder,
                kernel_size=conv_size,
                bias=conv_bias,
                activation="silu",
            )
        else:
            warnings.warn(
                "ShortConvolution is crucial to the performance. "
                "Do not turn it off, i.e., setting `use_short_conv=False` unless you know what you are doing."
            )
        if use_gate:
            self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
            self.o_norm = FusedRMSNormGated(self.head_v_dim, eps=norm_eps)
        else:
            self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps)
        self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)

    def _initialize_weights(self, module: nn.Module):
        if getattr(module, "_is_hf_initialized", False):
            return
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight, gain=2**-2.5)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        module._is_hf_initialized = True

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        initial_state: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs: Unpack[Dict],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]:
        if attention_mask is not None:
            assert len(attention_mask.shape) == 2, (
                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
                "for padding purposes (0 indicating padding). "
                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
            )

        batch_size, q_len, _ = hidden_states.shape
        # change to inference mode.
        mode = self.mode

        if self.training:
            assert mode == "chunk", "Only chunk mode is supported in training."

        last_state = None
        if past_key_values is not None and len(past_key_values) > self.layer_idx:
            last_state = past_key_values[self.layer_idx]

        cu_seqlens = kwargs.get("cu_seqlens", None)
        if attention_mask is not None:
            indices, cu_seqlens, _ = get_unpad_data(attention_mask[:, -q_len:])
            hidden_states = index_first_axis(
                rearrange(hidden_states, "b s ... -> (b s) ..."), indices
            ).unsqueeze(0)

        if self.use_short_conv:
            conv_state_q, conv_state_k, conv_state_v = None, None, None
            if last_state is not None:
                conv_state_q, conv_state_k, conv_state_v = last_state["conv_state"]
            q, conv_state_q = self.q_conv1d(
                x=self.q_proj(hidden_states),
                cache=conv_state_q,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens,
            )
            k, conv_state_k = self.k_conv1d(
                x=self.k_proj(hidden_states),
                cache=conv_state_k,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens,
            )
            v, conv_state_v = self.v_conv1d(
                x=self.v_proj(hidden_states),
                cache=conv_state_v,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens,
            )
        else:
            q = F.silu(self.q_proj(hidden_states))
            k = F.silu(self.k_proj(hidden_states))
            v = F.silu(self.v_proj(hidden_states))

        q = rearrange(q, "... (h d) -> ... h d", d=self.head_k_dim)
        k = rearrange(
            k,
            "... l (n h d) -> ... (l n) h d",
            n=self.num_householder,
            d=self.head_k_dim,
        )
        v = rearrange(
            v,
            "... l (n h d) -> ... (l n) h d",
            n=self.num_householder,
            d=self.head_v_dim,
        )

        if self.num_v_heads > self.num_heads:
            q, k = map(
                lambda x: repeat(
                    x, "... h d -> ... (h g) d", g=self.num_v_heads // self.num_heads
                ),
                (q, k),
            )

        beta = self.b_proj(hidden_states).sigmoid()
        if self.allow_neg_eigval:
            beta = beta * 2.0

        beta = rearrange(beta, "... l (n h) -> ... (l n) h", n=self.num_householder)
        if self.use_forget_gate:
            g = -self.A_log.float().exp() * F.softplus(
                self.a_proj(hidden_states).float() + self.dt_bias
            )
        else:
            g = None

        recurrent_state = (
            last_state["recurrent_state"] if last_state is not None else None
        )
        if mode == "chunk":
            o, recurrent_state = chunk_gated_delta_product(
                q=q,
                k=k,
                v=v,
                g=g,
                beta=beta,
                initial_state=initial_state,
                output_final_state=output_attentions,
                cu_seqlens=cu_seqlens,
                num_householder=self.num_householder,
                use_qk_l2norm_in_kernel=True,
            )

        elif mode == "fused_recurrent":
            if self.use_forget_gate:
                g_new = torch.zeros(
                    g.shape[0],
                    g.shape[1],
                    self.num_householder,
                    g.shape[2],
                    device=g.device,
                    dtype=torch.float32,
                )
                g_new[:, :, 0] = g
                g = rearrange(g_new, "... l n h -> ... (l n) h")

            q_new = q.new_zeros(
                q.shape[0], q.shape[1], self.num_householder, q.shape[2], q.shape[3]
            )
            q_new[:, :, -1] = q
            q = rearrange(q_new, "... l n h d -> ... (l n) h d")
            if self.use_forget_gate:
                o, recurrent_state = fused_recurrent_gated_delta_rule(
                    q=q,
                    k=k,
                    v=v,
                    g=g,
                    beta=beta,
                    initial_state=recurrent_state,
                    output_final_state=use_cache,
                    cu_seqlens=cu_seqlens * self.num_householder
                    if cu_seqlens is not None
                    else None,
                    use_qk_l2norm_in_kernel=True,
                )
            else:
                o, recurrent_state = fused_recurrent_delta_rule(
                    q=q,
                    k=k,
                    v=v,
                    beta=beta,
                    initial_state=recurrent_state,
                    output_final_state=use_cache,
                    cu_seqlens=cu_seqlens * self.num_householder
                    if cu_seqlens is not None
                    else None,
                    use_qk_l2norm_in_kernel=True,
                )
            o = rearrange(o, "... (l n) h d -> ... l n h d", n=self.num_householder)[
                ..., -1, :, :
            ].contiguous()

        if past_key_values is not None:
            past_key_values.update(
                recurrent_state=recurrent_state,
                conv_state=(conv_state_q, conv_state_k, conv_state_v)
                if self.use_short_conv
                else None,
                layer_idx=self.layer_idx,
                offset=q_len,
            )

        if self.use_gate:
            g = rearrange(
                self.g_proj(hidden_states), "... (h d) -> ... h d", d=self.head_v_dim
            )
            o = self.o_norm(o, g)
        else:
            o = self.o_norm(o)
        o = rearrange(o, "b t h d -> b t (h d)")
        o = self.o_proj(o)
        if attention_mask is not None:
            o = pad_input(o.squeeze(0), indices, batch_size, q_len)
        return o, recurrent_state, past_key_values
src/models/gated_deltaproduct/modeling_gated_deltaproduct.py ADDED
@@ -0,0 +1,105 @@
# -*- coding: utf-8 -*-

from __future__ import annotations

from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from fla.layers.attn import Attention
from fla.models.utils import Cache
from fla.modules import GatedMLP as GatedDeltaProductMLP
from fla.modules import RMSNorm

from src.models.gated_deltaproduct.configuration_gated_deltaproduct import (
    GatedDeltaProductConfig,
)
from src.models.gated_deltaproduct.gated_deltaproduct import GatedDeltaProduct

if TYPE_CHECKING:
    from transformers.processing_utils import Unpack


class GatedDeltaProductBlock(nn.Module):
    def __init__(self, config: GatedDeltaProductConfig, layer_idx: int):
        super().__init__()

        self.config = config
        self.layer_idx = layer_idx

        self.attn_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(
            config.hidden_size, eps=config.norm_eps
        )
        if config.attn is not None and layer_idx in config.attn["layers"]:
            self.attn = Attention(
                hidden_size=config.hidden_size,
                num_heads=config.attn["num_heads"],
                num_kv_heads=config.attn["num_kv_heads"],
                qkv_bias=config.attn["qkv_bias"],
                window_size=config.attn["window_size"],
                rope_theta=config.attn["rope_theta"],
                max_position_embeddings=config.max_position_embeddings,
                layer_idx=layer_idx,
            )
        else:
            self.attn = GatedDeltaProduct(
                mode=config.attn_mode,
                hidden_size=config.hidden_size,
                expand_v=config.expand_v,
                head_dim=config.head_dim,
                num_heads=config.num_heads,
                use_gate=config.use_gate,
                use_forget_gate=config.use_forget_gate,
                use_short_conv=config.use_short_conv,
                conv_size=config.conv_size,
                norm_eps=config.norm_eps,
                allow_neg_eigval=config.allow_neg_eigval,
                num_householder=config.num_householder,
                layer_idx=layer_idx,
            )
        self.mlp_norm = (RMSNorm if config.fuse_norm else nn.RMSNorm)(
            config.hidden_size, eps=config.norm_eps
        )
        self.mlp = GatedDeltaProductMLP(
            hidden_size=config.hidden_size,
            hidden_ratio=config.hidden_ratio,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            fuse_swiglu=config.fuse_swiglu,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        initial_state: Optional[torch.FloatTensor] = None,
        **kwargs: Unpack[Dict],
    ) -> Tuple[
        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
    ]:
        residual = hidden_states
        hidden_states = self.attn_norm(hidden_states)
        hidden_states, attentions, past_key_values = self.attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            initial_state=initial_state,
            **kwargs,
        )
        if self.config.fuse_norm:
            hidden_states, residual = self.mlp_norm(hidden_states, residual, True)
        else:
            hidden_states = residual + hidden_states
            residual = hidden_states
            hidden_states = self.mlp_norm(hidden_states)
        hidden_states = self.mlp(hidden_states, **kwargs)
        hidden_states = residual + hidden_states

        outputs = (hidden_states, attentions, past_key_values)

        return outputs