# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2024 Arc Institute. All rights reserved.
# Copyright (c) 2024 Michael Poli. All rights reserved.
# Copyright (c) 2024 Stanford University. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Optional

# TODO(@cye): Merge MCore HyenaConfig with NeMo HyenaConfig to have all model params in 1 config.
from nemo.collections.llm.gpt.model.megatron.hyena.hyena_config import HyenaConfig
from nemo.utils.flops_formulas import FLOPSConfig


def hyena(config: FLOPSConfig):
    """Model FLOPs for the Hyena model family. FPL = 'FLOPs per layer'."""
    # TODO(@cye): For now, pull the Hyena defaults directly from a constant dataclass. Merge this
    # config with the NeMo model config.
    hyena_config = HyenaConfig()
    # Hyena parameters
    hyena_short_conv_L = hyena_config.short_conv_L
    hyena_short_conv_len = hyena_config.hyena_short_conv_len
    hyena_medium_conv_len = hyena_config.hyena_medium_conv_len
    def _hyena_layer_count(model_pattern: Optional[str]):
        """Count how many short (S), medium (D), and long (H) Hyena layers there are in the
        model, along with the number of attention (*) layers.
        """
        S, D, H, A = 0, 0, 0, 0
        if model_pattern is None:
            return 0, 0, 0, 0
        for layer in model_pattern:
            if layer == "S":
                S += 1
            elif layer == "D":
                D += 1
            elif layer == "H":
                H += 1
            elif layer == "*":
                A += 1
        return S, D, H, A
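
    # Illustration with a hypothetical pattern string: _hyena_layer_count("SDH*SDH*") returns
    # (2, 2, 2, 2), i.e. two short-conv, two medium-conv, and two long-conv Hyena layers plus
    # two attention layers.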
    # Count S, D, H, and * layers in HyenaModel.
    S, D, H, A = _hyena_layer_count(config.model_pattern)
    # Logits FLOPs per batch for a flattened L x H -> V GEMM.
    logits_fpl = 2 * config.gbs * config.enc_seq_len * config.hs * config.vocab_size
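    # A GEMM of shape (M, K) x (K, N) costs ~2 * M * K * N FLOPs (one multiply plus one add per
    # output element per contraction step); here M = gbs * enc_seq_len, K = hs, N = vocab_size.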
    # Hyena Mixer common FLOPs - pre-attention QKV projections, post-attention projections, and
    # GLU FFN FLOPs per layer.
    pre_attn_qkv_proj_fpl = 2 * 3 * config.gbs * config.enc_seq_len * config.hs**2
    post_attn_proj_fpl = 2 * config.gbs * config.enc_seq_len * config.hs**2
    # 3 batched GEMMs: y = A(gelu(Bx) * Cx), where B, C: H -> F and A: F -> H.
    glu_ffn_fpl = 2 * 3 * config.gbs * config.enc_seq_len * config.ffn_hs * config.hs
    # Transformer (self-)attention FLOPs - QK attention logits ((L, D) x (D, L)) and
    # attention-weighted values ((L, L) x (L, D)).
    attn_fpl = 2 * 2 * config.gbs * config.hs * config.enc_seq_len**2
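    # Derivation: QK^T is an (L, D) x (D, L) GEMM costing 2 * L^2 * D per head, and the
    # attention-weighted values GEMM (L, L) x (L, D) costs another 2 * L^2 * D; summed over
    # heads the per-head dims add up to hs, giving the 2 * 2 * hs * L^2 factor above.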
    # Hyena Projection
    hyena_proj_fpl = 2 * 3 * config.gbs * config.enc_seq_len * hyena_short_conv_L * config.hs
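    # Reading off the formula: a depthwise conv with kernel length short_conv_L applied to each
    # of the three projected streams, at 2 FLOPs (multiply + add) per tap per channel per
    # position.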
    # Hyena Short Conv
    hyena_short_conv_fpl = 2 * config.gbs * config.enc_seq_len * hyena_short_conv_len * config.hs
    # Hyena Medium Conv
    hyena_medium_conv_fpl = 2 * config.gbs * config.enc_seq_len * hyena_medium_conv_len * config.hs
    # Hyena Long Conv (FFT)
    hyena_long_conv_fft_fpl = config.gbs * 10 * config.enc_seq_len * math.log2(config.enc_seq_len) * config.hs
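    # 10 * L * log2(L) per channel is a common operation-count estimate for FFT-based
    # convolution (forward and inverse FFT at ~5 * L * log2(L) each, with the pointwise
    # multiply neglected); it is an approximation rather than an exact count.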
    # Based on https://gitlab-master.nvidia.com/clara-discovery/savanna/-/blob/main/savanna/mfu.py#L182
    # Assumption: the backward pass costs 2x the forward-pass FLOPs, hence the overall factor of 3.
    return 3 * (
        logits_fpl
        + config.layers * (pre_attn_qkv_proj_fpl + post_attn_proj_fpl + glu_ffn_fpl)
        + A * attn_fpl
        + (S + D + H) * hyena_proj_fpl
        + S * hyena_short_conv_fpl
        + D * hyena_medium_conv_fpl
        + H * hyena_long_conv_fft_fpl
    )
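

# A minimal usage sketch, under assumptions: the keyword arguments below are exactly the
# FLOPSConfig fields this formula reads (the dataclass may define more), and every value is
# illustrative rather than a real model configuration.
if __name__ == "__main__":
    cfg = FLOPSConfig(
        gbs=256,  # global batch size
        enc_seq_len=8192,  # sequence length L
        hs=4096,  # hidden size H
        ffn_hs=11008,  # FFN hidden size F
        layers=32,  # total layer count
        vocab_size=512,  # output vocabulary size V
        model_pattern="SDH*" * 8,  # hypothetical pattern: 8 each of S, D, H, and * layers
    )
    print(f"Hyena model FLOPs per step: {hyena(cfg):.3e}")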