This repository was archived by the owner on Sep 24, 2025. It is now read-only.

Commit d443a1b

simple prototype is done and working on single node
1 parent 2240f6d commit d443a1b

3 files changed: +11 -117 lines changed

rollout.py

Lines changed: 3 additions & 3 deletions
@@ -32,7 +32,7 @@
     group_name="model_update_group",
 )
 experience_buffer_group = init_process_group(
-    backend="nccl",
+    backend="gloo",
     init_method=f"tcp://localhost:{EXPERIENCE_BUFFER_PORT}",
     world_size=1 + NUM_INFERENCE_ENGINES,
     rank=rank,
@@ -58,7 +58,7 @@
 log.info(f"Weights are ready to update")
 
 log.info("Updating the model weights") # this is a blocking operation, we need to wait until the weights are updated before we can start generating rollouts
-weights = torch.tensor([i]).to('cuda')
+weights = torch.tensor([0]).to('cuda')
 torch.distributed.broadcast(group=model_update_group, src=0,tensor=weights)
 log.info(f"Updating the weights to {weights}")
 # rest the update check
@@ -68,7 +68,7 @@
 
 # TODO: start generating rollouts and put it in the experience buffer
 
-experience_buffer = torch.tensor([6]).to('cuda')
+experience_buffer = torch.tensor([20+i])
 experience_buffer_work = torch.distributed.broadcast(group=experience_buffer_group, src=1,tensor=experience_buffer, async_op=True) # don't block, send it off and continue generating rollouts
 log.info(f"Sent experience buffer {experience_buffer}")
 
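Note (not part of the commit): a minimal, self-contained sketch of the pattern rollout.py relies on above — a gloo process group initialized over TCP, with the rollout side acting as src=1 and broadcasting the experience buffer with async_op=True while the receiver blocks on wait(). The port number, two-rank world size, and tensor contents are placeholder assumptions.

# Hypothetical standalone sketch of the gloo async-broadcast pattern; port and values are placeholders.
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank: int, world_size: int) -> None:
    # gloo keeps the experience-buffer exchange on CPU, so no GPU is needed here
    dist.init_process_group(
        backend="gloo",
        init_method="tcp://localhost:29512",  # placeholder port
        world_size=world_size,
        rank=rank,
    )
    # rank 1 plays the rollout process (source of experiences), rank 0 the trainer
    experience_buffer = torch.tensor([20]) if rank == 1 else torch.tensor([0])
    work = dist.broadcast(tensor=experience_buffer, src=1, async_op=True)
    work.wait()  # the rollout side could defer this wait and keep generating
    print(f"rank {rank} sees experience buffer {experience_buffer}")
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)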
test_no_ray.py

Lines changed: 3 additions & 106 deletions
@@ -1,114 +1,10 @@
-import logging
-import os
 import subprocess
-import tempfile
 import traceback
-from typing import Any, Dict, Optional
-from composer.cli.launcher import _launch_processes, _monitor_processes, _cleanup_processes, _aggregate_process_returncode
-
-def run_distributed_training(
-    nproc: int,
-    world_size: int,
-    base_rank: int,
-    node_rank: int,
-    master_addr: str,
-    master_port: int,
-    training_script: str,
-    training_script_args: Any=None,
-    module_mode: bool=False,
-    command_mode: bool=False,
-    stdout: Optional[str]=None,
-    stderr: Optional[str]=None,
-    verbose: bool = False,
-) -> int:
-    """
-    Run distributed training with the given parameters.
-
-    Args:
-        nproc (int): Number of processes to launch.
-        world_size (int): Total number of processes across all nodes.
-        base_rank (int): Base rank of the current node.
-        node_rank (int): Rank of the current node.
-        master_addr (str): Address of the master node.
-        master_port (int): Port of the master node.
-        module_mode (bool): Whether to use module mode.
-        command_mode (bool): Whether to use command mode.
-        stdout (Optional[str]): Stdout file format.
-        stderr (Optional[str]): Stderr file format.
-        training_script (str): Training script to run.
-        training_script_args (Any): Arguments for the training script.
-        verbose (bool): Whether to use verbose logging.
-
-    Returns:
-        int: Aggregated return code from all processes.
-    """
-    if training_script_args is None:
-        training_script_args = []
-
-    MOSAICML_PLATFORM_ENV_VAR = "MOSAICML_PLATFORM"
-    MOSAICML_LOG_DIR_ENV_VAR = "MOSAICML_LOG_DIR"
-    MOSAICML_GPU_LOG_FILE_PREFIX_ENV_VAR = "MOSAICML_GPU_LOG_FILE_PREFIX"
-
-    logger = logging.getLogger("distributed_training")
-    logging.basicConfig()
-    logger.setLevel(logging.INFO if verbose else logging.WARNING)
-
-    processes: Dict[Any, Any] = {}
-
-    log_tmpdir = tempfile.TemporaryDirectory()
-    if stdout is None:
-        stdout = f'{log_tmpdir.name}/rank{{rank}}.stdout.txt'
-    if stderr is None:
-        stderr = f'{log_tmpdir.name}/rank{{rank}}.stderr.txt'
-
-    # If running on the Mosaic platform, log all gpu ranks' stderr and stdout to Mosaic platform
-    if (
-        os.environ.get(MOSAICML_PLATFORM_ENV_VAR, 'false').lower() == 'true'
-        and str(os.environ.get(MOSAICML_LOG_DIR_ENV_VAR, 'false')).lower() != 'false'
-        and os.environ.get(MOSAICML_GPU_LOG_FILE_PREFIX_ENV_VAR, 'false').lower() != 'false'
-    ):
-        logger.info('Logging all GPU ranks to Mosaic AI Training.')
-        log_file_format = (
-            f"{os.environ.get(MOSAICML_LOG_DIR_ENV_VAR)}/"
-            f"{os.environ.get(MOSAICML_GPU_LOG_FILE_PREFIX_ENV_VAR)}{{local_rank}}.txt"
-        )
-        stdout = log_file_format
-        stderr = None
-
-    try:
-        _launch_processes(
-            nproc=nproc,
-            world_size=world_size,
-            base_rank=base_rank,
-            node_rank=node_rank,
-            master_addr=master_addr,
-            master_port=master_port,
-            module_mode=module_mode,
-            command_mode=command_mode,
-            stdout_file_format=stdout,
-            stderr_file_format=stderr,
-            training_script=training_script,
-            training_script_args=training_script_args,
-            processes=processes,
-        )
-        _monitor_processes(processes)
-    except Exception:
-        # Print the exception first, then kill the training processes, since killing
-        # may take up to CLEANUP_TIMEOUT seconds, and the user should know immediately
-        # what failed. No need to re-raise the exception, as `aggregate_process_returncode`
-        # will return an appropriate error code, which will cause the script to exit.
-        logger.error("Exception occurred during distributed training", exc_info=True)
-        traceback.print_exc()
-        print('Killing training processes')
-    finally:
-        _cleanup_processes(processes)
-        log_tmpdir.cleanup()
-        return _aggregate_process_returncode(processes)
 
 
 if __name__ == "__main__":
     # test on 4 gpus!
-
+    # for multinode, we should determine which command to launch on which node
     try:
         p1 = subprocess.Popen('CUDA_VISIBLE_DEVICES=0,1 composer -n 2 train.py', shell=True)
         p2 = subprocess.Popen('CUDA_VISIBLE_DEVICES=2,3 python rollout.py', shell=True)
@@ -119,5 +15,6 @@ def run_distributed_training(
         print(traceback.format_exc())
         print('Killing training processes')
     finally:
-        _cleanup_processes({0: p1, 1: p2})
+        p1.terminate()
+        p2.terminate()
 
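Note (not part of the commit): since the commit drops composer's _cleanup_processes helper in favor of bare Popen handles, a slightly hardened version of the same launcher would also wait on both children so return codes propagate and no zombies are left. This is a hedged sketch under the commit's assumptions (the same CUDA_VISIBLE_DEVICES split and the train.py / rollout.py entry points), not the author's code.

# Hypothetical launcher sketch; script names and GPU split mirror the commit, the rest is assumed.
import subprocess
import sys
import traceback

def main() -> int:
    procs = []
    try:
        procs.append(subprocess.Popen('CUDA_VISIBLE_DEVICES=0,1 composer -n 2 train.py', shell=True))
        procs.append(subprocess.Popen('CUDA_VISIBLE_DEVICES=2,3 python rollout.py', shell=True))
        # Block until both children exit and surface the worst return code.
        return max(p.wait() for p in procs)
    except Exception:
        print(traceback.format_exc())
        print('Killing training processes')
        return 1
    finally:
        for p in procs:
            if p.poll() is None:  # child still running, ask it to stop
                p.terminate()
        for p in procs:
            p.wait()  # reap the child so no zombie process is left behind

if __name__ == "__main__":
    sys.exit(main())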
train.py

Lines changed: 5 additions & 8 deletions
@@ -26,8 +26,6 @@
 torch.distributed.init_process_group(backend="nccl")
 log.info(f"Hello from rank {dist.get_global_rank()}")
 
-
-
 model_update_group = None
 experience_buffer_group = None
 if dist.get_global_rank() == 0:
@@ -40,30 +38,29 @@
         group_name="model_update_group",
     )
     experience_buffer_group = init_process_group(
-        backend="nccl",
+        backend="gloo",
         init_method=f"tcp://localhost:{EXPERIENCE_BUFFER_PORT}",
         world_size=1 + NUM_INFERENCE_ENGINES,
        rank=0,
         group_name="experience_buffer_group",
     )
 
-# TODO: broadcast the model weights to the inference engines
 for i in range(MAX_ITERATIONS):
     # Update global iteration tracker
     log.info(f"Starting iteration {i + 1}/{MAX_ITERATIONS}")
 
     if model_update_group is not None:
         is_ready_to_update = torch.tensor([1]).to('cuda')
         torch.distributed.broadcast(group=model_update_group, src=0,tensor=is_ready_to_update) # BLOCKING, let the other process know that we're ready to update the model weights
-        log.info(f"Rank {dist.get_global_rank()} Broadcasted is_ready_to_update{is_ready_to_update}")
+        log.info(f"Rank {dist.get_global_rank()} Broadcasted is_ready_to_update {is_ready_to_update}")
 
         # Actually broadcast the model weights
-        weights = torch.tensor([5]).to('cuda')
+        weights = torch.tensor([10+i]).to('cuda')
         torch.distributed.broadcast(group=model_update_group, src=0,tensor=weights) # broadcast all the model weights, BLOCKING
-        log.info(f"Rank {dist.get_global_rank()} Broadcasted model weights{weights}") # TODO: update the model weights
+        log.info(f"Rank {dist.get_global_rank()} Broadcasted model weights {weights}") # TODO: update the model weights
 
     # Get the experience buffer results from the rollout process
-    experience_buffer = torch.tensor([0]).to('cuda')
+    experience_buffer = torch.tensor([0])
     if experience_buffer_group is not None:
         torch.distributed.broadcast(group=experience_buffer_group, src=1,tensor=experience_buffer) # block until the broadcast is complete, need to get the new experiences
         log.info(f"Rank {dist.get_global_rank()} Got experience buffer {experience_buffer}")
