This repository was archived by the owner on Sep 24, 2025. It is now read-only.

Commit 52d02b5

correct training to be blocking
1 parent 49400b4 commit 52d02b5

File tree: rollout.py, train.py

2 files changed: +5 -7 lines changed


rollout.py

Lines changed: 1 addition & 1 deletion
@@ -60,7 +60,7 @@
 assert is_ready_to_update_work.is_completed()
 log.info(f"Weights are ready to update")
 
-log.info("Updating the model weights")
+log.info("Updating the model weights")  # this is a blocking operation, we need to wait until the weights are updated before we can start generating rollouts
 weights = torch.tensor([i]).to('cuda')
 torch.distributed.broadcast(group=model_update_group, src=0, tensor=weights)
 log.info(f"Updating the weights to {weights}")

train.py

Lines changed: 4 additions & 6 deletions
@@ -54,17 +54,15 @@
 
 if model_update_group is not None:
     is_ready_to_update = torch.tensor([1]).to('cuda')
-    is_ready_to_update_work = torch.distributed.broadcast(group=model_update_group, src=0, tensor=is_ready_to_update, async_op=True)
+    torch.distributed.broadcast(group=model_update_group, src=0, tensor=is_ready_to_update)  # BLOCKING, let the other process know that we're ready to update the model weights
     log.info(f"Rank {dist.get_global_rank()} Broadcasted is_ready_to_update {is_ready_to_update}")
 
-    is_ready_to_update_work.wait()  # wait until the broadcast is complete (the rollout process has received the message) before we update the model weights
-
     # Actually broadcast the model weights
     weights = torch.tensor([5]).to('cuda')
-    torch.distributed.broadcast(group=model_update_group, src=0, tensor=weights, async_op=True)  # broadcast all the model weights
+    torch.distributed.broadcast(group=model_update_group, src=0, tensor=weights)  # broadcast all the model weights, BLOCKING
     log.info(f"Rank {dist.get_global_rank()} Broadcasted model weights {weights}")  # TODO: update the model weights
 
-# TODO: get the experience buffer results from the rollout process
+# Get the experience buffer results from the rollout process
 experience_buffer = torch.tensor([0]).to('cuda')
 if experience_buffer_group is not None:
     torch.distributed.broadcast(group=experience_buffer_group, src=1, tensor=experience_buffer)  # block until the broadcast is complete, need to get the new experiences
@@ -75,7 +73,7 @@
 
 # distribute the experience results to each of the training ranks
 
-# TODO: train the model
+# TODO: train the model TRAINING CODE HERE
 
 log.info(f"Completed iteration {i + 1}/{MAX_ITERATIONS}")
 
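
The net effect of the commit in train.py: the async form (async_op=True) returns a Work handle immediately, and correctness then hinges on calling wait() at the right point, which the old weights broadcast never did; the blocking form synchronizes inside the call itself. A hedged side-by-side sketch of the two patterns (function and argument names here are placeholders, not from the repo):

import torch
import torch.distributed as dist

def publish_blocking(tensor, group):
    # Pattern the commit adopts: the call returns once the broadcast
    # has completed on this rank, so the caller can immediately rely
    # on the other process having been signaled.
    dist.broadcast(tensor=tensor, src=0, group=group)

def publish_async(tensor, group):
    # Pattern the commit removes: async_op=True hands back a Work
    # handle at once; forgetting work.wait() (as the old weights
    # broadcast did) lets later code race ahead of the collective.
    work = dist.broadcast(tensor=tensor, src=0, group=group, async_op=True)
    work.wait()  # required before depending on the broadcast's effects

Either pattern is correct when wait() is actually called; the bug being fixed was a fire-and-forget async broadcast of the weights.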
