141 changes: 118 additions & 23 deletions compose_rl/algorithms/online/callback.py
@@ -53,14 +53,17 @@
from compose_rl.utils import (
add_right_padding,
compute_advantages,
concat_resolved_outputs,
dist_compute_masked_mean_and_var,
filter_resolved_outputs,
flatten,
get_decoded_sequence,
get_entropies,
get_log_probs,
mask_eos,
masked_mean,
masked_sum,
partition_global_batch,
switch_left_to_right_padding,
)

@@ -424,14 +427,16 @@ def __init__(
self.generations_per_prompt,
)

self.global_iter_batch_size = self.num_batches_per_update * self.global_train_batch_size

log.info(
f'Per iteration using: {self.num_unique_prompts_per_iter} prompts.',
)

if self.num_unique_prompts_per_iter * self.generations_per_prompt != self.global_train_batch_size * self.num_batches_per_update:
raise ValueError(
f'{self.num_unique_prompts_per_iter=} * {self.generations_per_prompt=} must equal {self.global_train_batch_size=} * {self.num_batches_per_update=}',
)
# if self.num_unique_prompts_per_iter * self.generations_per_prompt != self.global_train_batch_size * self.num_batches_per_update:
# raise ValueError(
# f'{self.num_unique_prompts_per_iter=} * {self.generations_per_prompt=} must equal {self.global_train_batch_size=} * {self.num_batches_per_update=}',
# )

self.epochs_per_iteration = ensure_time(
var_config.get('epoch_per_iteration', 1),
@@ -441,7 +446,9 @@ def __init__(

# Programmatically setting the max buffer size instead of the yaml
var_config['buffer']['max_buffer_size'] = self.num_batches_per_update

self.buffer = MinibatchRolloutBuffer(var_config['buffer'])
self.global_sample_list = []

# Build the KL controller through registries
kl_ctl_name = var_config['kl_controller'].pop('kl_ctl_type')
@@ -470,6 +477,11 @@ def __init__(
train_config['python_log_level'].upper(),
)

self.same_reward_filter_threshold = var_config.get(
'same_reward_filter_threshold',
None,
)

self.vllm_engines = None
self.num_vllm_engines = 0
self.vllm_tensor_parallel_size = var_config.get(
@@ -507,7 +519,11 @@ def __init__(
self.test_prompt = 'Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences and must-see attractions.'
else:
# HF generate route extra checks
num_gen_calls = self.num_batches_per_update * self.device_train_batch_size // self.device_generate_batch_size
world_size = dist.get_world_size()
num_gen_calls = self.global_iter_batch_size // self.device_generate_batch_size // world_size
log.info(
f'Using {num_gen_calls} generation calls per iteration with HuggingFace generate.',
)
if num_gen_calls <= 0:
raise ValueError(f'{num_gen_calls=} must be greater than 0')

@@ -602,13 +618,69 @@ def after_load(self, state: State, logger: Logger):
def iteration_start(self, state: State, logger: Logger):
del logger # unused

batch = self._get_next_iter_prompts()
batch = state.device.batch_to_device(batch)

batch = self._get_next_iter_prompts(state)
if self.vllm_engines is not None:
self._update_inference_model(batch)

self._interact_with_env(batch)
num_env_interactions = 0
while len(self.buffer) < self.num_batches_per_update:
if num_env_interactions > 0:
batch = self._get_next_iter_prompts(state)

num_env_interactions += 1

# TODO: the case where we are not filtering
# We do not do an all gather, so this logic is slightly wrong right now
self._interact_with_env(batch)

cur_global_samples = concat_resolved_outputs(
self.global_sample_list,
self.pad_token_idx, # type: ignore
)

bs = cur_global_samples['prompt_id'].shape[0]

log.info(f'Current global batch size is {bs}.')
log.info(
f'Current global iter batch size is {self.global_iter_batch_size}.',
)

if bs >= self.global_iter_batch_size:
log.info(
'We have enough samples, adding samples to the buffer.',
)
rank = dist.get_global_rank()
world_size = dist.get_world_size()

local_samples = {}
for key, value in cur_global_samples.items():
local_samples[key] = partition_global_batch(
value,
world_size=world_size,
rank=rank,
device_train_batch_size=self.device_train_batch_size,
)

local_bs = local_samples['prompt_id'].shape[0]
# Add the local samples to the buffer
for idx in range(local_bs // self.device_train_batch_size):
minibatch = self._extract_minibatch(
batch=local_samples,
idx=idx,
minibatch_size=self.device_train_batch_size,
)
self.buffer.add(minibatch)

log.info(
f'For iteration {self.iter_num}, we have {len(self.buffer)} samples in the buffer. Starting training.',
)
log.info(
f'It took {num_env_interactions} environment interactions to fill the buffer.',
)

# Making sure we correctly parsed the minibatches
assert len(self.buffer) >= self.num_batches_per_update

# Reset and initialize state train dataloader
log.warning(
'trainer._train_data_spec should be updated whenever the dataloader is updated',
@@ -635,16 +707,25 @@ def iteration_end(self, state: State, logger: Logger):
del logger # unused
self._log_generations_to_logger(state)
self._increment_rl_iter()

self.buffer.reset()

# A list of all samples across ranks
# These can be filtered or unfiltered
self.global_sample_list = []

self.buffer.set_state_dict(
self.train_prompt_loader.state_dict(), # pyright: ignore
0,
)

def _get_next_iter_prompts(self):
def _get_next_iter_prompts(self, state: State):
"""Gets the next iteration's batch of prompts."""
# Sample fewer batches for the Online RL iteration depending on the number of generations per prompt
n_unique_batches = self.num_unique_prompts_per_iter // self.global_train_batch_size
log.info(
f'Getting {n_unique_batches} unique batches of prompts for the current iteration.',
)

batches = [
self._get_single_batch_prompts() for _ in range(n_unique_batches)
]
@@ -696,7 +777,7 @@ def _get_next_iter_prompts(self):
# this is an edge case that we will not hit currently, but just handling it as needed
ret_batch[key] = curr_values

return ret_batch
return state.device.batch_to_device(ret_batch)

def _get_single_batch_prompts(self):
"""Gets a single batch of prompts from the dataloader."""
@@ -770,6 +851,7 @@ def _interact_with_env(self, batch: dict[str, torch.Tensor]):
)
padded_sequences.append(padded_sequence)
sequences = torch.cat(padded_sequences, dim=0)

# Add the prepared sequences to the batch again
batch['sequences'] = sequences

@@ -795,8 +877,6 @@ def _interact_with_env(self, batch: dict[str, torch.Tensor]):
f'Finished reward computation for the rollout in {total_reward_time:.4f} seconds.',
)

self.prompts_and_gens.extend(prompts_and_gens)

gen_batch_partial_outputs = (env_outputs, ref_outputs, all_rewards_dict)
# For every partial output we want to resolve them together
# And compute the global per iteration batch advantage's mean and variance
@@ -805,17 +885,32 @@ def _interact_with_env(self, batch: dict[str, torch.Tensor]):
gen_batch_partial_outputs,
)

# We need to split the resolved outputs into minibatches
for idx in range(bs // self.device_train_batch_size):
minibatch = self._extract_minibatch(
resolved_outputs,
idx,
self.device_train_batch_size,
if self.same_reward_filter_threshold is not None:
log.info(
f'in reward thresholding, trying to filter with: {self.same_reward_filter_threshold}',
)
self.buffer.add(minibatch)
start_time = time.time()
all_gathered_outputs = dist.all_gather_object(resolved_outputs)

# Making sure we correctly parsed the minibatches
assert len(self.buffer) == self.num_batches_per_update
log.info(
f'It took {time.time() - start_time} seconds to gather all resolved outputs.',
)

all_resolved_outputs = concat_resolved_outputs(
all_gathered_outputs, # type: ignore
self.pad_token_idx, # type: ignore
)

# Filter the resolved outputs based on the generation filtering values
resolved_outputs = filter_resolved_outputs(
all_resolved_outputs,
self.same_reward_filter_threshold,
)

self.global_sample_list.append(resolved_outputs)

# TODO: bcui fix
self.prompts_and_gens.extend(prompts_and_gens)

self.actor_critic.train()

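The same_reward_filter_threshold path above gathers every rank's resolved rollouts and filters them before anything reaches the buffer. The implementation of filter_resolved_outputs is not part of this diff; the sketch below only illustrates one plausible contract under assumptions: rollouts arrive as contiguous groups of generations_per_prompt rows, per-sequence scalar rewards live under a hypothetical 'rewards' key, and the threshold is the maximum tolerated fraction of a group sharing the same reward.

import torch


def filter_resolved_outputs_sketch(
    outputs: dict[str, torch.Tensor],
    same_reward_filter_threshold: float,
    generations_per_prompt: int,
) -> dict[str, torch.Tensor]:
    # Hypothetical filter: keep a prompt group only when the fraction of its
    # generations sharing the modal reward stays below the threshold, so
    # zero-signal (all-identical-reward) groups never enter the buffer.
    rewards = outputs['rewards']  # assumed per-sequence scalar rewards
    keep_rows: list[int] = []
    for start in range(0, rewards.shape[0], generations_per_prompt):
        group = rewards[start:start + generations_per_prompt]
        _, counts = group.unique(return_counts=True)
        if counts.max().item() / group.shape[0] < same_reward_filter_threshold:
            keep_rows.extend(range(start, start + group.shape[0]))
    keep = torch.tensor(keep_rows, dtype=torch.long)
    return {key: value[keep] for key, value in outputs.items()}

Dropping uniform-reward groups is why iteration_start may need several environment interactions to fill the buffer; the real helper in compose_rl.utils is the source of truth for the key names and threshold semantics.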
2 changes: 1 addition & 1 deletion compose_rl/algorithms/online/model_methods.py
@@ -238,7 +238,7 @@ def policy_loss(
logits=gen_logits,
)
assert token_entropies.shape == batch['action_mask'].shape, (
f'Token entropies shape {token_entropies.shape} does not match action mask shape {batch["action_mask"].shape}.',
f"Token entropies shape {token_entropies.shape} does not match action mask shape {batch['action_mask'].shape}.",
)
seq_entropies = utils.get_sequence_entropies(
token_entropies=token_entropies,
6 changes: 6 additions & 0 deletions compose_rl/utils/__init__.py
@@ -17,8 +17,10 @@
batch_process_fine_granularities,
clear_mb_load_balancing_loss,
compute_advantages,
concat_resolved_outputs,
dist_compute_masked_mean_and_var,
extract_packed_chosen_rejected,
filter_resolved_outputs,
flatten,
flip_pad_token_usage_for_generate,
flip_pad_token_usage_in_ffn,
@@ -42,6 +44,7 @@
masked_sum,
masked_var,
masked_whiten,
partition_global_batch,
process_fine_granularities,
remove_left_padding,
rescale,
@@ -91,6 +94,9 @@
'make_action_mask',
'flatten',
'sample_wise_masked_mean',
'filter_resolved_outputs',
'concat_resolved_outputs',
'partition_global_batch',
'extract_gsm8k_answer',
'extract_math_answer',
'is_equiv',
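concat_resolved_outputs, filter_resolved_outputs, and partition_global_batch are exported here but defined elsewhere in compose_rl.utils. The callback's iteration_start implies a partitioning contract: each rank receives a contiguous slice of the gathered batch whose size is a whole multiple of device_train_batch_size. A minimal sketch of that contract, assuming plain tensors with the batch dimension first:

import torch


def partition_global_batch_sketch(
    value: torch.Tensor,
    world_size: int,
    rank: int,
    device_train_batch_size: int,
) -> torch.Tensor:
    # Truncate the gathered batch so it splits evenly into per-rank chunks
    # that are themselves whole multiples of the per-device minibatch size,
    # then hand this rank its contiguous slice.
    chunk = world_size * device_train_batch_size
    usable = (value.shape[0] // chunk) * chunk
    per_rank = usable // world_size
    return value[rank * per_rank:(rank + 1) * per_rank]

The real helper may shuffle or balance samples differently; this only shows why the buffer loop can split local_samples into device_train_batch_size-sized minibatches without a remainder.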