Skip to content

Commit a2570f2

Browse files
committed
rebase
1 parent 8844963 commit a2570f2

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

torchtitan/distributed/activation_checkpoint.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -91,13 +91,13 @@ def _apply_op_sac(
9191

9292
def _get_custom_policy(meta):
9393
def _custom_policy(ctx, func, *args, **kwargs):
94-
if (
95-
func == torch.ops.aten._to_copy.default
96-
and "cuda" in str(args[0].device)
97-
and "device" in kwargs
98-
and str(kwargs["device"]) == "cpu"
99-
):
100-
return CheckpointPolicy.MUST_SAVE
94+
if (
95+
func == torch.ops.aten._to_copy.default
96+
and "cuda" in str(args[0].device)
97+
and "device" in kwargs
98+
and str(kwargs["device"]) == "cpu"
99+
):
100+
return CheckpointPolicy.MUST_SAVE
101101

102102
mode = "recompute" if ctx.is_recompute else "forward"
103103
mm_count_key = f"{mode}_mm_count"

0 commit comments

Comments
 (0)