@@ -513,14 +513,38 @@ def apply_compile(model: nn.Module, compile_config: CompileConfig):
     # torch._dynamo.config.capture_scalar_outputs = True
     for layer_id, transformer_block in model.layers.named_children():
         # TODO: remove when torch.compile supports fullgraph=True for MoE
-        fullgraph = True
         if transformer_block.moe_enabled:
-            fullgraph = False
-        transformer_block = torch.compile(
-            transformer_block,
-            backend=compile_config.backend,
-            fullgraph=fullgraph,
-        )
+            transformer_block.moe.experts = torch.compile(
+                transformer_block.moe.experts,
+                backend=compile_config.backend,
+                fullgraph=True,
+            )
+            transformer_block.moe.shared_experts = torch.compile(
+                transformer_block.moe.shared_experts,
+                backend=compile_config.backend,
+                fullgraph=True,
+            )
+            # transformer_block.attention = torch.compile(
+            #     transformer_block.attention,
+            #     backend=compile_config.backend,
+            #     fullgraph=True,
+            # )
+            # transformer_block.attention_norm = torch.compile(
+            #     transformer_block.attention_norm,
+            #     backend=compile_config.backend,
+            #     fullgraph=True,
+            # )
+            # transformer_block.ffn_norm = torch.compile(
+            #     transformer_block.ffn_norm,
+            #     backend=compile_config.backend,
+            #     fullgraph=True,
+            # )
+        else:
+            transformer_block = torch.compile(
+                transformer_block,
+                backend=compile_config.backend,
+                fullgraph=True,
+            )
         model.layers.register_module(layer_id, transformer_block)

     logger.info("Compiling each TransformerBlock with torch.compile")
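A minimal standalone sketch of the pattern this hunk applies: when a parent block contains data-dependent MoE routing that breaks fullgraph=True, compile the dense submodules individually and leave the routing in eager mode. The ToyExperts/ToyMoEBlock classes below are illustrative stand-ins, not the repo's actual TransformerBlock or MoE classes.

# Sketch only: compile routing-free submodules with fullgraph=True,
# keep the parent forward (with its data-dependent argmax routing) eager.
import torch
import torch.nn as nn


class ToyExperts(nn.Module):
    def __init__(self, dim: int, num_experts: int):
        super().__init__()
        self.w = nn.Parameter(torch.randn(num_experts, dim, dim) * 0.02)

    def forward(self, x: torch.Tensor, expert_idx: torch.Tensor) -> torch.Tensor:
        # Gather one expert weight matrix per token and apply it.
        return torch.einsum("td,tdo->to", x, self.w[expert_idx])


class ToyMoEBlock(nn.Module):
    def __init__(self, dim: int = 16, num_experts: int = 4):
        super().__init__()
        self.router = nn.Linear(dim, num_experts)
        self.experts = ToyExperts(dim, num_experts)
        self.shared_experts = nn.Sequential(
            nn.Linear(dim, dim), nn.SiLU(), nn.Linear(dim, dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        expert_idx = self.router(x).argmax(dim=-1)  # data-dependent routing
        return self.experts(x, expert_idx) + self.shared_experts(x)


block = ToyMoEBlock()
# Compile the submodules separately, each with a full graph;
# the routing in ToyMoEBlock.forward stays uncompiled.
block.experts = torch.compile(block.experts, fullgraph=True)
block.shared_experts = torch.compile(block.shared_experts, fullgraph=True)

out = block(torch.randn(8, 16))
print(out.shape)  # torch.Size([8, 16])

In the diff above, the same idea is applied per transformer block: MoE-enabled blocks get their experts and shared_experts compiled separately (with attention and norm compilation left commented out for now), while non-MoE blocks are still compiled whole with fullgraph=True.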