You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
{{ message }}
This repository was archived by the owner on Sep 24, 2025. It is now read-only.
log.info("Updating the model weights") # this is a blocking operation, we need to wait until the weights are updated before we can start generating rollouts
# TODO: start generating rollouts and put it in the experience buffer
70
70
71
-
experience_buffer=torch.tensor([6]).to('cuda')
71
+
experience_buffer=torch.tensor([20+i])
72
72
experience_buffer_work=torch.distributed.broadcast(group=experience_buffer_group, src=1,tensor=experience_buffer, async_op=True) # don't block, send it off and continue generating rollouts
torch.distributed.broadcast(group=model_update_group, src=0,tensor=is_ready_to_update) # BLOCKING, let the other process know that we're ready to update the model weights
torch.distributed.broadcast(group=model_update_group, src=0,tensor=weights) # broadcast all the model weights, BLOCKING
63
-
log.info(f"Rank {dist.get_global_rank()} Broadcasted model weights{weights}") # TODO: update the model weights
60
+
log.info(f"Rank {dist.get_global_rank()} Broadcasted model weights{weights}") # TODO: update the model weights
64
61
65
62
# Get the experience buffer results from the rollout process
66
-
experience_buffer=torch.tensor([0]).to('cuda')
63
+
experience_buffer=torch.tensor([0])
67
64
ifexperience_buffer_groupisnotNone:
68
65
torch.distributed.broadcast(group=experience_buffer_group, src=1,tensor=experience_buffer) # block until the broadcast is complete, need to get the new experiences
0 commit comments