@@ -54,17 +54,15 @@
 
     if model_update_group is not None:
         is_ready_to_update = torch.tensor([1]).to('cuda')
-        is_ready_to_update_work = torch.distributed.broadcast(group=model_update_group, src=0, tensor=is_ready_to_update, async_op=True)
+        torch.distributed.broadcast(group=model_update_group, src=0, tensor=is_ready_to_update)  # BLOCKING: let the other process know that we're ready to update the model weights
         log.info(f"Rank {dist.get_global_rank()} Broadcasted is_ready_to_update {is_ready_to_update}")
 
-        is_ready_to_update_work.wait()  # wait until the broadcast is complete (the rollout process has received the message) before we update the model weights
-
         # Actually broadcast the model weights
         weights = torch.tensor([5]).to('cuda')
-        torch.distributed.broadcast(group=model_update_group, src=0, tensor=weights, async_op=True)  # broadcast all the model weights
+        torch.distributed.broadcast(group=model_update_group, src=0, tensor=weights)  # broadcast all the model weights, BLOCKING
         log.info(f"Rank {dist.get_global_rank()} Broadcasted model weights {weights}")  # TODO: update the model weights
 
-    # TODO: get the experience buffer results from the rollout process
+    # Get the experience buffer results from the rollout process
     experience_buffer = torch.tensor([0]).to('cuda')
     if experience_buffer_group is not None:
         torch.distributed.broadcast(group=experience_buffer_group, src=1, tensor=experience_buffer)  # block until the broadcast is complete, need to get the new experiences
@@ -75,7 +73,7 @@
 
     # distribute the experience results to each of the training ranks
 
-    # TODO: train the model
+    # TODO: train the model TRAINING CODE HERE
 
     log.info(f"Completed iteration {i + 1}/{MAX_ITERATIONS}")
 
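
A note on the handshake above: the blocking broadcasts on model_update_group and experience_buffer_group only complete if the rollout process issues matching calls in the same order, otherwise both processes hang. The rollout side is not part of this diff, so the following is only a sketch of what its loop could look like, assuming it has already joined the same two process groups, that the trainer is rank 0 in model_update_group, and that the rollout process is rank 1 in experience_buffer_group (rollout_loop and its arguments are hypothetical names).

# Sketch of a hypothetical rollout-side loop (not in this diff). The broadcasts
# must mirror the trainer's calls in the same order, or both processes block forever.
import torch
import torch.distributed


def rollout_loop(model_update_group, experience_buffer_group, max_iterations):
    for i in range(max_iterations):
        # Receive the "ready to update" signal from the trainer (rank 0 in model_update_group).
        is_ready_to_update = torch.tensor([0]).to('cuda')
        torch.distributed.broadcast(group=model_update_group, src=0, tensor=is_ready_to_update)

        # Receive the (placeholder) model weights from the trainer.
        weights = torch.tensor([0]).to('cuda')
        torch.distributed.broadcast(group=model_update_group, src=0, tensor=weights)

        # ... generate rollouts with the updated weights here ...

        # Send the collected experiences back; this process is rank 1 in experience_buffer_group.
        experience_buffer = torch.tensor([7]).to('cuda')
        torch.distributed.broadcast(group=experience_buffer_group, src=1, tensor=experience_buffer)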
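
The weights = torch.tensor([5]) broadcast is still a placeholder (hence the "# TODO: update the model weights"). One possible way to replace it, sketched here purely as an assumption and not as what this commit implements, is to broadcast every tensor in the model's state_dict in a fixed key order so the receiving side can post matching broadcasts; broadcast_model_weights is a hypothetical helper.

# Hypothetical replacement for the placeholder weights tensor: broadcast each
# state_dict tensor from the trainer (src=0) over model_update_group.
import torch
import torch.distributed


def broadcast_model_weights(model, model_update_group):
    state_dict = model.state_dict()
    for name in sorted(state_dict.keys()):
        # Fixed key order so the receiver can post matching, same-shaped broadcasts.
        tensor = state_dict[name].detach().to('cuda')
        torch.distributed.broadcast(group=model_update_group, src=0, tensor=tensor)

On the receiving side, the same loop would broadcast into same-shaped buffers and copy them back into its own model; that part is omitted here.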
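
For the "distribute the experience results to each of the training ranks" step, one option (again only a sketch, not code from this commit) is a broadcast inside a process group that covers just the training ranks; training_group below is a hypothetical group, and src=0 is assumed to be the training rank that received the buffer from the rollout process.

# Hypothetical fan-out of the experience buffer to the other training ranks.
# training_group is assumed to contain only the training ranks (not the rollout process).
import torch
import torch.distributed


def distribute_experiences(experience_buffer, training_group):
    # Every training rank calls this with a same-shaped tensor; src=0 is the rank
    # that received the buffer from the rollout process.
    torch.distributed.broadcast(group=training_group, src=0, tensor=experience_buffer)
    return experience_buffer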