Skip to content

Commit 1c701d7

Browse files
author
Avinash
committed
Fix formatting issues for DataStates-LLM
Signed-off-by: Avinash Maurya <[email protected]>
1 parent 1acf0e3 commit 1c701d7

File tree

8 files changed

+34
-19
lines changed

8 files changed

+34
-19
lines changed

deepspeed/datastates/__init__.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,6 @@
1-
# Copyright by DataStates Team (Argonne National Laboratory): https://github.com/DataStates/
2-
# Maintained by DataStates Team (Argonne National Laboratory): https://github.com/DataStates/
1+
# Copyright (c) Microsoft Corporation.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
# MIT License Copyright (c) UChicago Argonne LLC, operator of Argonne National Laboratory.
5+
6+
# DeepSpeed Team

deepspeed/datastates/config.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,15 @@
1-
# Copyright DataStates Team (Argonne National Laboratory): https://github.com/DataStates/
2-
# Maintained by DataStates Team (Argonne National Laboratory): https://github.com/DataStates/
1+
# Copyright (c) Microsoft Corporation.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
# MIT License Copyright (c) UChicago Argonne LLC, operator of Argonne National Laboratory.
5+
6+
# DeepSpeed Team
37

48
from deepspeed.runtime.config_utils import DeepSpeedConfigObject
9+
10+
511
class DeepSpeedDataStatesConfig(DeepSpeedConfigObject):
12+
613
def __init__(self, param_dict):
714
super(DeepSpeedDataStatesConfig, self).__init__()
815

@@ -11,4 +18,4 @@ def __init__(self, param_dict):
1118

1219
if "datastates_ckpt" in param_dict.keys():
1320
self.enabled = True
14-
self.config = param_dict["datastates_ckpt"]
21+
self.config = param_dict["datastates_ckpt"]

deepspeed/runtime/checkpoint_engine/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ class CheckpointEngine(object):
4040

4141
### Asynchronous Lazy Checkpointing using DataStates-LLM
4242

43-
DataStates-LLM is an asynchrnous checkpointing approach optimized for LLM pre-training and can be obtained at https://github.com/DataStates/datastates-llm. To enable datastates-llm checkpointing, specify the `host_cache_size` (in gigabytes) which reserves pinned host memory for asynchronous checkpoint flushing, and `parser_threads` to parse multiple checkpoint file requests in parallel using the following lines in config.json supplied during the launch:
43+
DataStates-LLM is an asynchronous checkpointing approach optimized for LLM pre-training and can be obtained at https://github.com/DataStates/datastates-llm. To enable datastates-llm checkpointing, specify the `host_cache_size` (in gigabytes) which reserves pinned host memory for asynchronous checkpoint flushing, and `parser_threads` to parse multiple checkpoint file requests in parallel using the following lines in config.json supplied during the launch:
4444
```
4545
{
4646
... other deepspeed config options,

deepspeed/runtime/checkpoint_engine/checkpoint_engine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,4 @@ def commit(self, tag):
3131

3232
def wait(self):
3333
# To wait in asynchronous checkpoint engines (e.g. DataStates-LLM) for the previous snapshot to finish
34-
pass
34+
pass

deepspeed/runtime/checkpoint_engine/datastates_checkpoint_engine.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# SPDX-License-Identifier: Apache-2.0
13

2-
# Copyright by DataStates Team (Argonne National Laboratory): https://github.com/DataStates/
3-
# Maintained by DataStates Team (Argonne National Laboratory): https://github.com/DataStates/
4+
# MIT License Copyright (c) UChicago Argonne LLC, operator of Argonne National Laboratory.
45

5-
from deepspeed.utils import logger, log_dist
6+
# DeepSpeed Team
7+
8+
from deepspeed.utils import log_dist
69
from deepspeed.runtime.checkpoint_engine.checkpoint_engine import \
710
CheckpointEngine
811
from datastates.llm import Checkpointing
@@ -13,21 +16,19 @@ class DataStatesCheckpointEngine(CheckpointEngine):
1316
def __init__(self, deepspeed_config, rank):
1417
super().__init__(deepspeed_config)
1518
self.ckpt_engine = Checkpointing(deepspeed_config, rank)
16-
19+
1720
def create(self, tag):
1821
log_dist(f"[DataStates] Checkpoint {tag} is about to be saved!", ranks=[0])
1922
return None
2023

2124
def save(self, state_dict, path: str):
2225
return self.ckpt_engine.save(state_dict, path)
23-
26+
2427
def load(self, path: str, map_location=None):
2528
return self.ckpt_engine.load(path, map_location)
26-
29+
2730
def commit(self, tag):
2831
return self.ckpt_engine.commit(tag)
2932

3033
def wait(self):
3134
return self.ckpt_engine.wait()
32-
33-

deepspeed/runtime/engine.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1064,9 +1064,12 @@ def _configure_checkpointing(self, dist_init_required):
10641064
if self._config is not None and self._config.datastates_config.enabled:
10651065
try:
10661066
from deepspeed.runtime.checkpoint_engine.datastates_checkpoint_engine import DataStatesCheckpointEngine
1067-
self.checkpoint_engine = DataStatesCheckpointEngine(deepspeed_config=self._config, rank=dist.get_rank())
1067+
self.checkpoint_engine = DataStatesCheckpointEngine(deepspeed_config=self._config,
1068+
rank=dist.get_rank())
10681069
except ImportError as err:
1069-
raise Exception(f"The datastates-llm checkpoint engine was not found! Will fall back to torch.save. Details: {err}")
1070+
raise Exception(
1071+
f"The datastates-llm checkpoint engine was not found! Will fall back to torch.save. Details: {err}"
1072+
)
10701073

10711074
dp_rank = groups._get_sequence_data_parallel_rank()
10721075

deepspeed/runtime/pipe/module.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from deepspeed.accelerator import get_accelerator
2323
from deepspeed.checkpoint.utils import clone_tensors_for_torch_save
2424

25+
2526
class PipelineError(Exception):
2627
"""Errors related to the use of deepspeed.PipelineModule """
2728

@@ -617,7 +618,7 @@ def save_state_dict(self, save_dir, checkpoint_engine, exclude_frozen_params=Fal
617618
if exclude_frozen_params:
618619
for n in self._get_frozen_parameter_names(layer):
619620
del orig_state_dict[n]
620-
621+
621622
if debloat_memory:
622623
final_state_dict = clone_tensors_for_torch_save(orig_state_dict)
623624
else:

deepspeed/runtime/swap_tensor/pipelined_optimizer_swapper.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import torch
99
from deepspeed.ops.op_builder import AsyncIOBuilder
1010
from deepspeed import comm as dist
11-
import torch
1211

1312
from deepspeed.runtime.swap_tensor.constants import *
1413
from deepspeed.runtime.swap_tensor.utils import swap_in_tensors, swap_out_tensors, print_object

0 commit comments

Comments (0)