Skip to content
This repository was archived by the owner on Sep 24, 2025. It is now read-only.

Commit 2240f6d

Browse files
committed
wait for the last experience buffer to be received
1 parent 52d02b5 commit 2240f6d

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

rollout.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,6 @@
99
NUM_INFERENCE_ENGINES=1
1010
MAX_ITERATIONS=2
1111

12-
# Global iteration tracker
13-
CURRENT_ITERATION = 0
14-
1512
logging.basicConfig(
1613
# Example of format string
1714
# 2022-06-29 11:22:26,152: rank0[822018][MainThread]: INFO: composer.trainer.trainer: Using precision Precision.FP32
@@ -72,10 +69,16 @@
7269
# TODO: start generating rollouts and put it in the experience buffer
7370

7471
experience_buffer = torch.tensor([6]).to('cuda')
75-
torch.distributed.broadcast(group=experience_buffer_group, src=1,tensor=experience_buffer, async_op=True) # don't block, send it off and continue generating rollouts
76-
log.info(f"Rank {dist.get_global_rank()} Sent experience buffer {experience_buffer}")
72+
experience_buffer_work = torch.distributed.broadcast(group=experience_buffer_group, src=1,tensor=experience_buffer, async_op=True) # don't block, send it off and continue generating rollouts
73+
log.info(f"Sent experience buffer {experience_buffer}")
7774

7875
log.info(f"Completed iteration {i + 1}/{MAX_ITERATIONS}")
7976

77+
if i == MAX_ITERATIONS - 1:
78+
assert experience_buffer_work is not None
79+
log.info(f"Waiting for the last experience buffer to be received")
80+
experience_buffer_work.wait()
81+
82+
8083

8184

0 commit comments

Comments
 (0)