from typing import Optional
import torch
from deepspeed import comm as dist
-from .layers import LinearAllreduce, LinearLayer, LmHeadLinearAllreduce
+from .layers import LinearAllreduce, LinearLayer, LmHeadLinearAllreduce, Yuan_LinearAllreduce, Yuan_LinearLayer, GateUpPack_LinearLayer, Conv_LinearALlreduce, fused_LinearLayer, conv_LinearLayer
from deepspeed.accelerator import get_accelerator
-from .fusedqkv_utils import require_tp_fused_qkvw, prepare_tp_fused_qkvw, shard_value_with_share_qk, shard_chunk_mlp
+from .fusedqkv_utils import require_tp_fused_qkvw
from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list
+from deepspeed.utils import groups
+from deepspeed.module_inject.layers import is_autotp_training_mode


def move(tensor, device, copy=True):
@@ -333,10 +335,18 @@ def tp_parser(model):
        return policy_list

    def set_tensor_parallel_config(self, mp_size, mp_group):
+
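+        # In AutoTP training mode, the tensor-parallel group and world size are owned
+        # by deepspeed.utils.groups, so the values passed in by the caller are ignored.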
+        if is_autotp_training_mode():
+            self.mp_group = groups.get_tensor_model_parallel_group()
+            self.mp_size = groups.get_tensor_model_parallel_world_size()
+            return
+
        self.mp_size = mp_size
        self.mp_group = mp_group

    def _replace(self, child, name, conv_linear_layer):
+        # This function should clearly define the routing rules for specific layers
+        # and avoid any complex shard-related logic.
        if getattr(child, "replaced", False) == True:
            return
        device_name = 'cpu' if self.keep_module_on_host else get_accelerator().current_device_name()
@@ -352,80 +362,41 @@ def _replace(self, child, name, conv_linear_layer):
        # For Yuan model
        if 'Yuan' in str(self.module):
            if 'v_proj' in name:
-                weight, bias = shard_value_with_share_qk(child.weight.data, child.bias, dist.get_rank(),
-                                                         dist.get_world_size(), True)
-                return LinearLayer(weight=weight, bias=bias)
+                return Yuan_LinearLayer(child, self.mp_group)
+
            elif 'o_proj' in name:
-                weight, bias = shard_value_with_share_qk(child.weight.data, child.bias, dist.get_rank(),
-                                                         dist.get_world_size(), False)
-                return LinearAllreduce(weight, bias, self.mp_group)
-        # For Arctic model, bypass to all_reduce replacement for w2 weights
+                return Yuan_LinearAllreduce(child, self.mp_group)
+
+        # For MLP including chunk layer.
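+        # gate_up_proj (and GLM's dense_h_to_4h) pack two projections into one weight,
+        # so the wrapper shards each packed chunk separately instead of splitting the
+        # matrix as a whole.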
+        if 'gate_up_proj' in name or ('dense_h_to_4h' in name and 'GLM' in str(self.module)):
+            return GateUpPack_LinearLayer(child, self.mp_group)
+        # For Arctic model, bypass to all_reduce replacement for w2 weights
        arctic_w2_all_reduce_linear = False
        if 'Arctic' in str(self.module) and 'w2' in name:
            arctic_w2_all_reduce_linear = True
        # For MoE MLP model, e.g., deepseek and jamba
        down_proj = False
        if 'down_proj' in name:
            down_proj = True
-        # For MLP including chunk layer.
-        if 'gate_up_proj' in name or ('dense_h_to_4h' in name and 'GLM' in str(self.module)):
-            weight, bias = shard_chunk_mlp(child.weight.data, child.bias, dist.get_rank(), dist.get_world_size())
-            return LinearLayer(weight=weight, bias=bias)
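+        # Row-parallel layers shard the input dimension, so each rank produces a
+        # partial output that must be summed across ranks with an all-reduce.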
        if name in self.all_reduce_linears or arctic_w2_all_reduce_linear or down_proj:
-            # if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size]
-            # else [weight_shape[0], weight_shape[1] // mp_size]

+            setattr(child, "replaced", True)
            if self.conv_linear_layer:
-                child.weight.data = child.weight.data.transpose(-1, -2).contiguous()
-            data = child.weight.data.split(get_shard_size_list(
-                weight_shape[0] if self.conv_linear_layer else weight_shape[1], self.mp_size, name),
-                                           dim=1)
-            data_dc = move(data[mp_replace.gpu_index], device_name, return_new_copy).detach()
-            del data
+                return Conv_LinearALlreduce(child, self.mp_group, name=name)
+            elif name == "lm_head" or name == 'embed_out':
+                return LmHeadLinearAllreduce(child, self.mp_group)

-            setattr(child, "replaced", True)
-            if name == "lm_head" or name == 'embed_out':
-                return LmHeadLinearAllreduce(
-                    torch.nn.parameter.Parameter(data_dc, requires_grad=False), dist.get_rank(), dist.get_world_size(),
-                    child.bias if child.bias is None else torch.nn.parameter.Parameter(
-                        move(child.bias, device_name, return_new_copy)), self.mp_group)
-            return LinearAllreduce(torch.nn.parameter.Parameter(data_dc, requires_grad=False), child.bias if child.bias is None else \
-                torch.nn.parameter.Parameter(move(child.bias, device_name, return_new_copy)), self.mp_group)
+            return LinearAllreduce(child, self.mp_group, name=name)
        else:

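+            # Column-parallel layers shard the output dimension, so no collective is
+            # needed here; a later row-parallel layer performs the reduction.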
-            # if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size]
-            # else [weight_shape[0] // mp_size, weight_shape[1]]
+            setattr(child, "replaced", True)
            if self.conv_linear_layer:
-                child.weight.data = child.weight.data.transpose(-1, -2).contiguous()
-
-            if require_tp_fused_qkvw(name, self.mp_size):
+                return conv_LinearLayer(child, self.mp_group)
+            elif require_tp_fused_qkvw(name, self.mp_size):
                #Check and handle fused qkv for TP
-                #The copy is a regular copy, The shape of dst and src is the same
-                data_dc = move(
-                    prepare_tp_fused_qkvw(self.module, child.weight.data, self.mp_size, mp_replace.gpu_index),
-                    device_name, return_new_copy)
-
-                bias_data_dc = None if child.bias is None else move(
-                    prepare_tp_fused_qkvw(self.module, child.bias.data, self.mp_size, mp_replace.gpu_index),
-                    device_name, return_new_copy)
-            else:
-                data = child.weight.data.split(get_shard_size_list(weight_shape[0], self.mp_size, name),
-                                               dim=1 if self.conv_linear_layer else 0)
-                data_dc = move(data[mp_replace.gpu_index], device_name, return_new_copy).detach()
-                del data
-
-                if child.bias is not None:
-                    bias_data = child.bias.data.split(get_shard_size_list(
-                        weight_shape[1] if self.conv_linear_layer else weight_shape[0], self.mp_size, name),
-                                                      dim=0)
-                    bias_data = move(bias_data[mp_replace.gpu_index], device_name, return_new_copy)
-                    bias_data_dc = torch.nn.parameter.Parameter(bias_data, requires_grad=False)
-                    del bias_data
-                else:
-                    bias_data_dc = None
+                return fused_LinearLayer(child, self.mp_group, fused_module=self.module)

-            setattr(child, "replaced", True)
-            return LinearLayer(weight=torch.nn.parameter.Parameter(data_dc, requires_grad=False), bias=bias_data_dc)
+            return LinearLayer(child, self.mp_group, name=name)

    def _slice_embedding(self, child, name, conv_linear_layer):
        if getattr(child, "replaced", False) == True:
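
For reference, below is a minimal, self-contained sketch of the "routing rules only" shape the rewritten _replace moves toward: match on layer name and model type, then delegate all shard logic to a wrapper class. The Rule/route names are illustrative only, not DeepSpeed API.

from dataclasses import dataclass
from typing import Callable, Optional

@dataclass
class Rule:
    # (layer_name, model_repr) -> whether this rule applies
    predicate: Callable[[str, str], bool]
    # wrapper class that owns the shard logic for matching layers
    wrapper: str

RULES = [
    Rule(lambda n, m: 'Yuan' in m and 'v_proj' in n, 'Yuan_LinearLayer'),
    Rule(lambda n, m: 'Yuan' in m and 'o_proj' in n, 'Yuan_LinearAllreduce'),
    Rule(lambda n, m: 'gate_up_proj' in n, 'GateUpPack_LinearLayer'),
    Rule(lambda n, m: n in ('lm_head', 'embed_out'), 'LmHeadLinearAllreduce'),
]

def route(layer_name: str, model_repr: str) -> Optional[str]:
    # First matching rule wins; None means "fall through to the generic wrappers".
    for rule in RULES:
        if rule.predicate(layer_name, model_repr):
            return rule.wrapper
    return None

assert route('v_proj', 'YuanForCausalLM(...)') == 'Yuan_LinearLayer'
assert route('up_proj', 'LlamaForCausalLM(...)') is None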