22 changes: 20 additions & 2 deletions nemo_rl/utils/flops_formulas.py
@@ -70,7 +70,7 @@ def gpt3(config: FLOPSConfig):


def llama(config: FLOPSConfig):
"""Model FLOPs for llama3 family."""
"""Model FLOPs for llama2/3 family."""
return (
config.gbs
* config.enc_seq_len
@@ -87,6 +87,24 @@ def llama(config: FLOPSConfig):
)


def llama4(config: FLOPSConfig):
"""Model FLOPs for llama4 family."""
return (
config.gbs
* config.enc_seq_len
* config.layers
* config.hs
* config.hs
* (
12
+ (12 * config.query_groups / config.attention_heads)
+ (18 * config.moe_router_topk * config.ffn_hs / config.hs)
+ (6 * config.enc_seq_len / config.hs)
+ (6 * config.vocab_size / (config.layers * config.hs))
)
)


def nemotron(config: FLOPSConfig):
"""Model FLOPs for nemotron family."""
return (
@@ -117,7 +135,7 @@ def mixtral(config: FLOPSConfig):
12
+ (12 * config.query_groups / config.attention_heads)
+ (18 * config.moe_router_topk * config.ffn_hs / config.hs)
+ (12 * config.enc_seq_len / config.hs)
+ (6 * config.enc_seq_len / config.hs)
+ (6 * config.vocab_size / (config.layers * config.hs))
)
)
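
The new `llama4` formula follows the same MoE-aware structure as `mixtral` above (dense attention term, GQA K/V term, routed-expert MLP term, attention-score term, LM-head term). A minimal sketch of calling it directly is shown below; the numeric values are illustrative only, not a real Llama-4 configuration, and it assumes `gbs` and `enc_seq_len` are ordinary `FLOPSConfig` fields set by the caller:

```python
from nemo_rl.utils.flops_formulas import FLOPSConfig, llama4

# Illustrative values only -- not taken from any real Llama-4 checkpoint.
cfg = FLOPSConfig(
    gbs=128,             # global batch size
    enc_seq_len=8192,    # sequence length
    hs=5120,             # hidden size
    layers=48,
    ffn_hs=16384,        # per-expert FFN hidden size
    query_groups=8,      # KV heads (GQA)
    attention_heads=40,
    vocab_size=200_000,
    moe_router_topk=1,   # mirrors what convert_config_to_flops_config sets for Llama4Config
)
print(f"llama4 model FLOPs: {llama4(cfg):.3e}")
```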
47 changes: 45 additions & 2 deletions nemo_rl/utils/flops_tracker.py
@@ -20,12 +20,21 @@
from transformers import AutoConfig
from transformers.configuration_utils import PretrainedConfig
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama4.configuration_llama4 import Llama4Config
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
from transformers.models.qwen3_moe.configuration_qwen3_moe import Qwen3MoeConfig

from nemo_rl.models.policy.utils import sliding_window_overwrite
from nemo_rl.utils.flops_formulas import FLOPSConfig, llama, qwen2, qwen3
from nemo_rl.utils.flops_formulas import (
FLOPSConfig,
llama,
llama4,
mixtral,
qwen2,
qwen3,
)


def get_default_hf_config(model_name: str) -> PretrainedConfig:
@@ -54,7 +63,7 @@ def convert_config_to_flops_config(
ffn_hs=config.intermediate_size,
vocab_size=config.vocab_size,
), qwen2
elif isinstance(config, (Qwen3Config, Qwen3MoeConfig)):
elif isinstance(config, Qwen3Config):
return FLOPSConfig(
gbs=0,
hs=config.hidden_size,
@@ -67,6 +76,18 @@
moe_ffn_hidden_size=config.intermediate_size,
moe_router_topk=1,
), qwen3
elif isinstance(config, Qwen3MoeConfig):
return FLOPSConfig(
gbs=0,
hs=config.hidden_size,
layers=config.num_hidden_layers,
ffn_hs=config.intermediate_size,
vocab_size=config.vocab_size,
query_groups=config.num_key_value_heads,
attention_heads=config.num_attention_heads,
moe_ffn_hidden_size=config.moe_intermediate_size,
moe_router_topk=config.num_experts_per_tok,
), qwen3
elif isinstance(config, LlamaConfig):
return FLOPSConfig(
gbs=0,
@@ -77,6 +98,28 @@
attention_heads=config.num_attention_heads,
vocab_size=config.vocab_size,
), llama
elif isinstance(config, Llama4Config):
return FLOPSConfig(
gbs=0,
hs=config.text_config.hidden_size,
layers=config.text_config.num_hidden_layers,
ffn_hs=config.text_config.intermediate_size_mlp,
query_groups=config.text_config.num_key_value_heads,
attention_heads=config.text_config.num_attention_heads,
vocab_size=config.text_config.vocab_size,
moe_router_topk=1,
), llama4
elif isinstance(config, MixtralConfig):
return FLOPSConfig(
gbs=0,
hs=config.hidden_size,
layers=config.num_hidden_layers,
ffn_hs=config.intermediate_size,
vocab_size=config.vocab_size,
query_groups=config.num_key_value_heads,
attention_heads=config.num_attention_heads,
moe_router_topk=config.num_experts_per_tok,
), mixtral
else:
raise ValueError(f"Unsupported config type: {type(config)}")
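
Downstream, `convert_config_to_flops_config` pairs a Hugging Face config with the matching formula. A rough sketch of that flow for one of the newly supported models follows; it assumes `FLOPSConfig` is a dataclass (so `dataclasses.replace` works), that the caller fills in `gbs` and `enc_seq_len`, and that `get_default_hf_config` can reach the Hugging Face Hub:

```python
from dataclasses import replace

from nemo_rl.utils.flops_tracker import (
    convert_config_to_flops_config,
    get_default_hf_config,
)

# Fetch the default HF config and pick the matching FLOPS formula.
hf_config = get_default_hf_config("mistralai/Mixtral-8x7B-v0.1")
base_cfg, formula = convert_config_to_flops_config(hf_config)

# The conversion leaves gbs at 0; fill in batch size and sequence length.
cfg = replace(base_cfg, gbs=128, enc_seq_len=4096)
print(f"{formula(cfg):.3e}")  # ~4.18e16 for this batch/sequence, per the unit test below
```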

6 changes: 5 additions & 1 deletion tests/unit/utils/test_flops_counter.py
@@ -26,8 +26,12 @@
("meta-llama/Meta-Llama-3-8B", 128, 8192, 5.31e16),
("meta-llama/Llama-3.1-70B-Instruct", 128, 8192, 4.71e17),
("meta-llama/Llama-3.1-405B-Instruct", 128, 8192, 2.65e18),
("meta-llama/Llama-4-Scout-17B-16E", 128, 8192, 1.14e17),
("meta-llama/Llama-4-Maverick-17B-128E", 128, 8192, 1.14e17),
("Qwen/Qwen3-30B-A3B", 128, 4096, 9.37e15),
("Qwen/Qwen3-235B-A22B", 128, 4096, 6.21e16),
("mistralai/Mixtral-8x7B-v0.1", 128, 4096, 4.18e16),
("mistralai/Mixtral-8x22B-v0.1", 128, 65536, 3.1e18),
],
)
def test_flops_counter(model_name, gbs, seqlen, expected_flops):
@@ -37,5 +41,5 @@ def test_flops_counter(model_name, gbs, seqlen, expected_flops):

# check within 5% relative difference
assert abs(flops_tracker.total_flops - expected_flops) / expected_flops <= 0.05, (
f"Expected {expected_flops} flops, got {flops_tracker.total_flops}"
f"{model_name}: Expected {expected_flops} flops, got {flops_tracker.total_flops}"
)
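
The Mixtral cases exercise the corrected attention-score term. As a sanity check, the Mixtral-8x7B expectation can be reproduced by hand with the snippet below; the config values are quoted from the model's public Hugging Face config and the per-term comments reflect the usual reading of this formula:

```python
# Hand-check of the mixtral formula for mistralai/Mixtral-8x7B-v0.1
# (hidden_size 4096, 32 layers, intermediate_size 14336, vocab 32000,
#  8 KV heads, 32 attention heads, top-2 expert routing).
gbs, s, L, h = 128, 4096, 32, 4096
ffn, vocab, kv_heads, heads, topk = 14336, 32000, 8, 32, 2

flops = gbs * s * L * h * h * (
    12
    + 12 * kv_heads / heads    # K/V projections under GQA
    + 18 * topk * ffn / h      # routed expert MLPs
    + 6 * s / h                # attention-score term (12 -> 6 in this diff)
    + 6 * vocab / (L * h)      # LM head
)
print(f"{flops:.3e}")  # ~4.18e16, matching the expected value above
```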