
Commit ffcdbf0

[autoparallel] integrate auto parallel feature with new tracer (#3408)
* [autoparallel] integrate new analyzer in module level
* unify the profiling method
* polish
* fix no codegen bug
* fix pass bug
* fix liveness test
* polish
1 parent 573af84 commit ffcdbf0


46 files changed (+397, -471 lines)

colossalai/_analyzer/_subclasses/flop_tensor.py

Lines changed: 22 additions & 1 deletion
@@ -235,7 +235,28 @@ def matmul_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
     # Inputs contains the shapes of two matrices.
     input_shapes = [v.shape for v in inputs]
     assert len(input_shapes) == 2, input_shapes
-    assert input_shapes[0][-1] == input_shapes[1][-2], input_shapes
+
+    # There are three cases: 1) gemm, 2) gemv, 3) dot
+    if all(len(shape) == 2 for shape in input_shapes):
+        # gemm
+        assert input_shapes[0][-1] == input_shapes[1][-2], input_shapes
+    elif all(len(shape) == 1 for shape in input_shapes):
+        # dot
+        assert input_shapes[0][0] == input_shapes[1][0], input_shapes
+
+        # expand shape
+        input_shapes[0] = torch.Size([1, input_shapes[0][0]])
+        input_shapes[1] = torch.Size([input_shapes[1][0], 1])
+    else:
+        # gemv
+        if len(input_shapes[0]) == 1:
+            assert input_shapes[0][0] == input_shapes[1][-2], input_shapes
+            input_shapes.reverse()
+        else:
+            assert input_shapes[1][0] == input_shapes[0][-1], input_shapes
+
+        # expand the shape of the vector to [batch size, 1]
+        input_shapes[-1] = torch.Size([input_shapes[-1][-1], 1])
     flops = reduce(operator.mul, input_shapes[0]) * input_shapes[-1][-1]
     return flops
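The new branches normalize dot and matrix-vector products to an (M, K) x (K, N) shape pair before applying the existing count, which yields one count per multiply-accumulate. A minimal sketch of that rule, written outside the library in plain Python/PyTorch with the shape checks omitted:

import operator
from functools import reduce

import torch


def matmul_flops(a_shape, b_shape):
    # Normalize dot and gemv to a 2D gemm, mirroring the branches above.
    shapes = [torch.Size(a_shape), torch.Size(b_shape)]
    if all(len(s) == 1 for s in shapes):        # dot: (K,) . (K,)
        shapes[0] = torch.Size([1, shapes[0][0]])
        shapes[1] = torch.Size([shapes[1][0], 1])
    elif any(len(s) == 1 for s in shapes):      # gemv: matrix-vector
        if len(shapes[0]) == 1:                 # vector @ matrix: swap operands
            shapes.reverse()
        shapes[-1] = torch.Size([shapes[-1][-1], 1])
    # gemm needs no normalization
    return reduce(operator.mul, shapes[0]) * shapes[-1][-1]


print(matmul_flops((8, 16), (16, 4)))    # gemm: 8 * 16 * 4 = 512
print(matmul_flops((16,), (16,)))        # dot:  1 * 16 * 1 = 16
print(matmul_flops((8, 16), (16,)))      # gemv: 8 * 16 * 1 = 128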

colossalai/_analyzer/fx/codegen.py

Lines changed: 11 additions & 8 deletions
@@ -1,8 +1,12 @@
 from typing import Any, Callable, Dict, Iterable, List, Tuple
 
 import torch
+
+try:
+    from torch.fx.graph import CodeGen
+except:
+    pass
 from torch.fx.graph import (
-    CodeGen,
     PythonCode,
     _custom_builtins,
     _format_target,
@@ -48,8 +52,8 @@ def _end_of_ckpt(node: Node, ckpt_level: int) -> bool:
     """
     Check if the node could end the ckpt region at `ckpt_level`
     """
-    if len(node.meta['info'].to_recompute) > ckpt_level:
-        return node.meta['info'].to_recompute[ckpt_level] is not None
+    if len(node.meta['info'].activation_checkpoint) > ckpt_level:
+        return node.meta['info'].activation_checkpoint[ckpt_level] is not None
     return True
 
 
@@ -90,8 +94,8 @@ def _find_nested_ckpt_regions(node_list: List[Node], ckpt_level: int = 0):
     current_region = None
 
     for idx, node in enumerate(node_list):
-        if len(node.meta['info'].to_recompute) > ckpt_level:
-            act_ckpt_label = node.meta['info'].to_recompute[ckpt_level]
+        if len(node.meta['info'].activation_checkpoint) > ckpt_level:
+            act_ckpt_label = node.meta['info'].activation_checkpoint[ckpt_level]
 
             # this activation checkpoint label is not set yet
             # meaning this is the first node of the activation ckpt region
@@ -152,12 +156,12 @@ def emit_ckpt_func(body,
 
     # label given by each layer, e.g. if you are currently at level (0, 1, 1)
     # the label will be '0_1_1'
-    label = "_".join([str(idx) for idx in node_list[0].meta['info'].to_recompute[:ckpt_level + 1]])
+    label = "_".join([str(idx) for idx in node_list[0].meta['info'].activation_checkpoint[:ckpt_level + 1]])
     ckpt_fn_def = _gen_ckpt_fn_def(label, inputs)
     ckpt_func.append(f'{ckpt_fn_def}\n')
 
     # if there is more level to fetch
-    if ckpt_level + 1 < max(map(lambda node: len(node.meta['info'].to_recompute), node_list)):
+    if ckpt_level + 1 < max(map(lambda node: len(node.meta['info'].activation_checkpoint), node_list)):
         ckpt_regions = _find_nested_ckpt_regions(node_list, ckpt_level + 1)
         start_idx = [item[0] for item in ckpt_regions]
         end_idx = [item[1] for item in ckpt_regions]
@@ -215,7 +219,6 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
     ckpt_regions = _find_nested_ckpt_regions(nodes, 0)
     start_idx = [item[0] for item in ckpt_regions]
     end_idx = [item[1] for item in ckpt_regions]
-
     node_list = list(nodes)
 
     node_idx = 0
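For context, the renamed `activation_checkpoint` tuple stores one region index per nesting level, and the comment in the hunk above maps it to the generated function label. A tiny standalone illustration of that mapping (the tuple value below is made up):

# Illustration only: a node tagged with nested checkpoint regions (0, 1, 1)
activation_checkpoint = (0, 1, 1)
ckpt_level = 2
label = "_".join(str(idx) for idx in activation_checkpoint[:ckpt_level + 1])
print(label)    # '0_1_1', used to name the emitted checkpoint function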

colossalai/_analyzer/fx/node_util.py

Lines changed: 1 addition & 1 deletion
@@ -112,7 +112,7 @@ class MetaInfo:
 
     # should keep the same whenever manipulated
     # ============================= Invariant ==================================
-    to_recompute: Tuple[torch.Tensor] = ()    # (region_0, region_1, ...) support nested codegen
+    activation_checkpoint: Tuple[torch.Tensor] = ()    # (region_0, region_1, ...) support nested codegen
     to_offload: Optional[bool] = False
     sharding_spec: str = 'RR'

colossalai/_analyzer/fx/passes/shape_prop.py

Lines changed: 8 additions & 1 deletion
@@ -237,7 +237,14 @@ def propagate(self, *args, device=None):
         Returns:
             Any: The value returned from executing the Module
         """
-        wrap_fn = lambda elem: MetaTensor(elem, device=device)
+
+        # wrap_fn = lambda elem: MetaTensor(elem, device=device)
+        def wrap_fn(elem, device=device):
+            if isinstance(elem, torch.Tensor):
+                return MetaTensor(elem, device=device)
+            else:
+                return elem
+
         with self._mode:
             return super().run(*tree_map(wrap_fn, args))
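The old lambda wrapped every leaf, which breaks when `tree_map` hands it non-tensor arguments such as ints or strings; the new `wrap_fn` passes those through. A toy, self-contained illustration of the pass-through, using a plain `.to('meta')` as a stand-in for the library's `MetaTensor`:

import torch
from torch.utils._pytree import tree_map


def wrap_fn(elem):
    # Only tensors are moved to the meta device; everything else is returned untouched.
    return elem.to('meta') if isinstance(elem, torch.Tensor) else elem


args = (torch.randn(2, 2), 3, 'labels')
out = tree_map(wrap_fn, args)
print(out[0].device, out[1], out[2])    # meta 3 labels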

colossalai/_analyzer/fx/tracer/bias_addition.py

Lines changed: 78 additions & 36 deletions
@@ -21,69 +21,111 @@ def linear_impl(input, weight, bias=None):
 
 
 @register_tracer_impl(F.conv1d, name='_bias_addition_impl')
-def conv1d_impl(input, weight, **kwargs):
-    bias = getattr(kwargs, 'bias', None)
+def conv1d_impl(input, weight, bias=None, stride=_single(1), padding=_single(0), dilation=_single(1), groups=1):
     if bias is None:
-        return F.conv1d(input, weight, **kwargs)
+        return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
     else:
-        new_kwargs = kwargs
-        new_kwargs['bias'] = None
-        return F.conv1d(input, weight, **kwargs) + bias.reshape((-1, 1))
+        return F.conv1d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
+            (-1, 1))
 
 
 @register_tracer_impl(F.conv2d, name='_bias_addition_impl')
-def conv2d_impl(input, weight, **kwargs):
-    bias = getattr(kwargs, 'bias', None)
+def conv2d_impl(input, weight, bias=None, stride=_pair(1), padding=_pair(0), dilation=_pair(1), groups=1):
     if bias is None:
-        return F.conv2d(input, weight, **kwargs)
+        return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
     else:
-        new_kwargs = kwargs
-        new_kwargs['bias'] = None
-        return F.conv2d(input, weight, **kwargs) + bias.reshape((-1, 1, 1))
+        return F.conv2d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
+            (-1, 1, 1))
 
 
 @register_tracer_impl(F.conv3d, name='_bias_addition_impl')
-def conv3d_impl(input, weight, **kwargs):
-    bias = getattr(kwargs, 'bias', None)
+def conv3d_impl(input, weight, bias=None, stride=_triple(1), padding=_triple(0), dilation=_triple(1), groups=1):
     if bias is None:
-        return F.conv3d(input, weight, **kwargs)
+        return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups)
     else:
-        new_kwargs = kwargs
-        new_kwargs['bias'] = None
-        return F.conv3d(input, weight, **new_kwargs) + bias.reshape((-1, 1, 1, 1))
+        return F.conv3d(input, weight, stride=stride, padding=padding, dilation=dilation, groups=groups) + bias.reshape(
+            (-1, 1, 1, 1))
 
 
 @register_tracer_impl(F.conv_transpose1d, name='_bias_addition_impl')
-def conv_transpose1d_impl(input, weight, **kwargs):
-    bias = getattr(kwargs, 'bias', None)
+def conv_transpose1d_impl(input,
+                          weight,
+                          bias=None,
+                          stride=_single(1),
+                          padding=_single(0),
+                          output_padding=_single(0),
+                          groups=1,
+                          dilation=_single(1)):
     if bias is None:
-        return F.conv_transpose1d(input, weight, **kwargs)
+        return F.conv_transpose1d(input,
+                                  weight,
+                                  stride=stride,
+                                  padding=padding,
+                                  output_padding=output_padding,
+                                  groups=groups,
+                                  dilation=dilation)
     else:
-        new_kwargs = kwargs
-        new_kwargs['bias'] = None
-        return F.conv_transpose1d(input, weight, **new_kwargs) + bias.reshape((-1, 1))
+        return F.conv_transpose1d(input,
+                                  weight,
+                                  stride=stride,
+                                  padding=padding,
+                                  output_padding=output_padding,
+                                  groups=groups,
+                                  dilation=dilation) + bias.reshape((-1, 1))
 
 
 @register_tracer_impl(F.conv_transpose2d, name='_bias_addition_impl')
-def conv_transpose2d_impl(input, weight, **kwargs):
-    bias = getattr(kwargs, 'bias', None)
+def conv_transpose2d_impl(input,
+                          weight,
+                          bias=None,
+                          stride=_pair(1),
+                          padding=_pair(0),
+                          output_padding=_pair(0),
+                          groups=1,
+                          dilation=_pair(1)):
     if bias is None:
-        return F.conv_transpose2d(input, weight, **kwargs)
+        return F.conv_transpose2d(input,
+                                  weight,
+                                  stride=stride,
+                                  padding=padding,
+                                  output_padding=output_padding,
+                                  groups=groups,
+                                  dilation=dilation)
     else:
-        new_kwargs = kwargs
-        new_kwargs['bias'] = None
-        return F.conv_transpose2d(input, weight, **new_kwargs) + bias.reshape((-1, 1, 1))
+        return F.conv_transpose2d(input,
+                                  weight,
+                                  stride=stride,
+                                  padding=padding,
+                                  output_padding=output_padding,
+                                  groups=groups,
+                                  dilation=dilation) + bias.reshape((-1, 1, 1))
 
 
 @register_tracer_impl(F.conv_transpose3d, name='_bias_addition_impl')
-def conv_transpose3d_impl(input, weight, **kwargs):
-    bias = getattr(kwargs, 'bias', None)
+def conv_transpose3d_impl(input,
+                          weight,
+                          bias=None,
+                          stride=_triple(1),
+                          padding=_triple(0),
+                          output_padding=_triple(0),
+                          groups=1,
+                          dilation=_triple(1)):
     if bias is None:
-        return F.conv_transpose3d(input, weight, **kwargs)
+        return F.conv_transpose3d(input,
+                                  weight,
+                                  stride=stride,
+                                  padding=padding,
+                                  output_padding=output_padding,
+                                  groups=groups,
+                                  dilation=dilation)
     else:
-        new_kwargs = kwargs
-        new_kwargs['bias'] = None
-        return F.conv_transpose3d(input, weight, **new_kwargs) + bias.reshape((-1, 1, 1, 1))
+        return F.conv_transpose3d(input,
+                                  weight,
+                                  stride=stride,
+                                  padding=padding,
+                                  output_padding=output_padding,
+                                  groups=groups,
+                                  dilation=dilation) + bias.reshape((-1, 1, 1, 1))
 
 
 @register_tracer_impl(torch.addmm, name='_bias_addition_impl')
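These rewritten impls compute the convolution without bias and add the reshaped bias explicitly, so the tracer records the bias addition as its own node instead of hiding it inside the fused call. Outside the tracer the two forms should agree numerically; a quick plain-PyTorch check with made-up shapes:

import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 8, 8)
w = torch.randn(4, 3, 3, 3)
b = torch.randn(4)

fused = F.conv2d(x, w, bias=b)
split = F.conv2d(x, w) + b.reshape((-1, 1, 1))    # bias broadcast over H and W

print(torch.allclose(fused, split))    # True, up to floating-point reordering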

colossalai/_analyzer/fx/tracer/tracer.py

Lines changed: 1 addition & 1 deletion
@@ -155,7 +155,7 @@ def create_proxy(self,
 
     def create_node(self, *args, **kwargs) -> Node:
         node = super().create_node(*args, **kwargs)
-        n_info = MetaInfo(node, mod_dir=self.mod_dir, to_recompute=tuple(self.ckpt_regions))
+        n_info = MetaInfo(node, mod_dir=self.mod_dir, activation_checkpoint=tuple(self.ckpt_regions))
         return node
 
     def trace(self,
Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
 from .meta_registry import *
-from .metainfo import *
 from .registry import meta_register
+from .shard_metainfo import *

colossalai/auto_parallel/meta_profiler/meta_registry/activation.py

Lines changed: 2 additions & 2 deletions
@@ -2,9 +2,9 @@
 
 import torch
 
+from colossalai._analyzer._subclasses.flop_tensor import ewise_flop_counter as elementwise_flop_counter
+from colossalai._analyzer.fx.node_util import compute_size_in_bytes as activation_size
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
-from colossalai.fx.profiler.memory_utils import activation_size
-from colossalai.fx.profiler.opcount import elementwise_flop_counter
 
 from ..registry import meta_register
 

colossalai/auto_parallel/meta_profiler/meta_registry/binary_elementwise_ops.py

Lines changed: 3 additions & 3 deletions
@@ -2,9 +2,9 @@
 
 import torch
 
+from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
+from colossalai._analyzer.fx.node_util import compute_size_in_bytes as activation_size
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, OperationDataType, TrainCycleItem
-from colossalai.fx.profiler.memory_utils import activation_size
-from colossalai.fx.profiler.opcount import flop_mapping
 
 from ..constants import BCAST_FUNC_OP, NO_SAVE_ACTIVATION
 from ..registry import meta_register
@@ -17,7 +17,7 @@ def binary_elementwise_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, Train
     """Meta information generator for binary elementwise operations
     NOTE: Some of the binary elementwise operations will discard the input activation after computation, as they
     don't need those tensors for back propagation, for example, if there are two tensors being sent for `torch.add`,
-    they will be discarded right after add operation is done. We create a simple API in `MetaInfo` class to identify
+    they will be discarded right after add operation is done. We create a simple API in `ShardMetaInfo` class to identify
     this behavior, it is critical for better memory estimation.
 
     Returns:

colossalai/auto_parallel/meta_profiler/meta_registry/conv.py

Lines changed: 14 additions & 14 deletions
@@ -2,6 +2,8 @@
 
 import torch
 
+from colossalai._analyzer._subclasses.flop_tensor import flop_mapping
+from colossalai._analyzer.fx.node_util import compute_size_in_bytes
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
     MemoryCost,
     OperationData,
@@ -10,8 +12,6 @@
     StrategiesVector,
     TrainCycleItem,
 )
-from colossalai.fx.profiler.memory_utils import activation_size
-from colossalai.fx.profiler.opcount import flop_mapping
 from colossalai.tensor.sharding_spec import ShardingSpec
 
 from ..registry import meta_register
@@ -110,18 +110,18 @@ def convnd_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
     # calculate memory cost
     # TODO: use profiler to check conv temp memory
     # NOTE: currently in SPMD solver we always believe that there will be a new tensor created in forward
-    fwd_memory_cost = MemoryCost(
-        activation=activation_size([input_tensor, output_tensor]),
-        parameter=activation_size([weight_tensor, bias_tensor]) if has_bias else activation_size(weight_tensor),
-        temp=0,
-        buffer=0)
-
-    bwd_memory_cost = MemoryCost(
-        activation=activation_size([input_tensor, weight_tensor, bias_tensor])
-        if has_bias else activation_size([input_tensor, weight_tensor]),
-        parameter=activation_size([weight_tensor, bias_tensor]) if has_bias else activation_size(weight_tensor),
-        temp=0,
-        buffer=0)
+    fwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, output_tensor]),
+                                 parameter=compute_size_in_bytes([weight_tensor, bias_tensor])
+                                 if has_bias else compute_size_in_bytes(weight_tensor),
+                                 temp=0,
+                                 buffer=0)
+
+    bwd_memory_cost = MemoryCost(activation=compute_size_in_bytes([input_tensor, weight_tensor, bias_tensor])
+                                 if has_bias else compute_size_in_bytes([input_tensor, weight_tensor]),
+                                 parameter=compute_size_in_bytes([weight_tensor, bias_tensor])
+                                 if has_bias else compute_size_in_bytes(weight_tensor),
+                                 temp=0,
+                                 buffer=0)
 
     # total cost is the sum of forward and backward cost
     total_cost = MemoryCost(activation=fwd_memory_cost.activation + bwd_memory_cost.activation,
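`compute_size_in_bytes` replaces the old `activation_size` helper but plays the same role: summing the byte size of the tensors that stay alive. A self-contained approximation of the forward cost for a small conv, with a stand-in helper rather than the library function:

import torch
import torch.nn.functional as F


def size_in_bytes(tensors):
    # Stand-in for compute_size_in_bytes: total bytes of one tensor or a list of tensors.
    if isinstance(tensors, torch.Tensor):
        tensors = [tensors]
    return sum(t.numel() * t.element_size() for t in tensors)


x = torch.randn(2, 3, 32, 32)       # input
w = torch.randn(16, 3, 3, 3)        # weight, no bias
out = F.conv2d(x, w, padding=1)     # padding=1 keeps the 32x32 spatial size

print(size_in_bytes([x, out]))      # forward activation cost: 24576 + 131072 = 155648 bytes
print(size_in_bytes(w))             # parameter cost: 432 * 4 = 1728 bytes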
