From 099aaef9d9bb133b8513e939d6b7c7e76e570443 Mon Sep 17 00:00:00 2001 From: brandon Date: Thu, 26 Jun 2025 00:34:36 +0000 Subject: [PATCH 1/7] update --- compose_rl/algorithms/online/callback.py | 141 ++++++++++-- compose_rl/utils/__init__.py | 6 + compose_rl/utils/utils.py | 262 +++++++++++++++++++++++ 3 files changed, 386 insertions(+), 23 deletions(-) diff --git a/compose_rl/algorithms/online/callback.py b/compose_rl/algorithms/online/callback.py index eb932e94..7a5eee61 100644 --- a/compose_rl/algorithms/online/callback.py +++ b/compose_rl/algorithms/online/callback.py @@ -53,7 +53,9 @@ from compose_rl.utils import ( add_right_padding, compute_advantages, + concat_resolved_outputs, dist_compute_masked_mean_and_var, + filter_resolved_outputs, flatten, get_decoded_sequence, get_entropies, @@ -61,6 +63,7 @@ mask_eos, masked_mean, masked_sum, + partition_global_batch, switch_left_to_right_padding, ) @@ -424,14 +427,16 @@ def __init__( self.generations_per_prompt, ) + self.global_iter_batch_size = self.num_batches_per_update * self.global_train_batch_size + log.info( f'Per iteration using: {self.num_unique_prompts_per_iter} prompts.', ) - if self.num_unique_prompts_per_iter * self.generations_per_prompt != self.global_train_batch_size * self.num_batches_per_update: - raise ValueError( - f'{self.num_unique_prompts_per_iter=} * {self.generations_per_prompt=} must equal {self.global_train_batch_size=} * {self.num_batches_per_update=}', - ) + # if self.num_unique_prompts_per_iter * self.generations_per_prompt != self.global_train_batch_size * self.num_batches_per_update: + # raise ValueError( + # f'{self.num_unique_prompts_per_iter=} * {self.generations_per_prompt=} must equal {self.global_train_batch_size=} * {self.num_batches_per_update=}', + # ) self.epochs_per_iteration = ensure_time( var_config.get('epoch_per_iteration', 1), @@ -441,7 +446,9 @@ def __init__( # Programmatically setting the max buffer size instead of the yaml var_config['buffer']['max_buffer_size'] = self.num_batches_per_update + self.buffer = MinibatchRolloutBuffer(var_config['buffer']) + self.global_sample_list = [] # Build the KL controller through registries kl_ctl_name = var_config['kl_controller'].pop('kl_ctl_type') @@ -470,6 +477,11 @@ def __init__( train_config['python_log_level'].upper(), ) + self.same_reward_filter_threshold = var_config.get( + 'same_reward_filter_threshold', + None, + ) + self.vllm_engines = None self.num_vllm_engines = 0 self.vllm_tensor_parallel_size = var_config.get( @@ -507,7 +519,11 @@ def __init__( self.test_prompt = 'Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences and must-see attractions.' 
else: # HF generate route extra checks - num_gen_calls = self.num_batches_per_update * self.device_train_batch_size // self.device_generate_batch_size + world_size = dist.get_world_size() + num_gen_calls = self.global_iter_batch_size // self.device_generate_batch_size // world_size + log.info( + f"Using {num_gen_calls} generation calls per iteration with HuggingFace generate.", + ) if num_gen_calls <= 0: raise ValueError(f'{num_gen_calls=} must be greater than 0') @@ -602,13 +618,69 @@ def after_load(self, state: State, logger: Logger): def iteration_start(self, state: State, logger: Logger): del logger # unused - batch = self._get_next_iter_prompts() - batch = state.device.batch_to_device(batch) - + batch = self._get_next_iter_prompts(state) if self.vllm_engines is not None: self._update_inference_model(batch) - self._interact_with_env(batch) + num_env_interactions = 0 + while len(self.buffer) < self.num_batches_per_update: + if num_env_interactions > 0: + batch = self._get_next_iter_prompts(state) + + num_env_interactions += 1 + + # TODO: the case where we are not filtering + # We do not do an all gather, so this logic is slightly wrong right now + self._interact_with_env(batch) + + cur_global_samples = concat_resolved_outputs( + self.global_sample_list, + self.pad_token_idx, # type: ignore + ) + + bs = cur_global_samples['prompt_id'].shape[0] + + log.info(f"Current global batch size is {bs}.") + log.info( + f"Current global iter batch size is {self.global_iter_batch_size}.", + ) + + if bs >= self.global_iter_batch_size: + log.info( + 'We have enough samples, adding samples to the buffer.', + ) + rank = dist.get_global_rank() + world_size = dist.get_world_size() + + local_samples = {} + for key, value in cur_global_samples.items(): + local_samples[key] = partition_global_batch( + value, + world_size=world_size, + rank=rank, + device_train_batch_size=self.device_train_batch_size, + ) + + local_bs = local_samples['prompt_id'].shape[0] + # Add the local samples to the buffer + for idx in range(local_bs // self.device_train_batch_size): + minibatch = self._extract_minibatch( + batch=local_samples, + idx=idx, + minibatch_size=self.device_train_batch_size, + ) + self.buffer.add(minibatch) + + log.info( + f"For iteration {self.iter_num}, we have {len(self.buffer)} samples in the buffer. 
Starting training.", + ) + log.info( + f"It took {num_env_interactions} environment interactions to fill the buffer.", + ) + + # Making sure we correctly parsed the minibatches + assert len(self.buffer) >= self.num_batches_per_update + # Reset and initialize state train dataloader log.warning( 'trainer._train_data_spec should be updated whenever the dataloader is updated', @@ -635,16 +707,25 @@ def iteration_end(self, state: State, logger: Logger): del logger # unused self._log_generations_to_logger(state) self._increment_rl_iter() + self.buffer.reset() + + # A list of all samples across ranks + # These can be filtered or unfiltered + self.global_sample_list = [] + self.buffer.set_state_dict( self.train_prompt_loader.state_dict(), # pyright: ignore 0, ) - def _get_next_iter_prompts(self): + def _get_next_iter_prompts(self, state: State): """Gets the next iteration's batch of prompts.""" - # Sample fewer batches for the Online RL interation depending on the number of generations per prompt n_unique_batches = self.num_unique_prompts_per_iter // self.global_train_batch_size + log.info( + f"Getting {n_unique_batches} unique batches of prompts for the current iteration.", + ) + batches = [ self._get_single_batch_prompts() for _ in range(n_unique_batches) ] @@ -696,7 +777,7 @@ def _get_next_iter_prompts(self): # this is an edge case that we will not hit currently, but just handling it as needed ret_batch[key] = curr_values - return ret_batch + return state.device.batch_to_device(ret_batch) def _get_single_batch_prompts(self): """Gets a single batch of prompts from the dataloader.""" @@ -770,6 +851,7 @@ def _interact_with_env(self, batch: dict[str, torch.Tensor]): ) padded_sequences.append(padded_sequence) sequences = torch.cat(padded_sequences, dim=0) + # Add the prepared sequences to the batch again batch['sequences'] = sequences @@ -795,8 +877,6 @@ def _interact_with_env(self, batch: dict[str, torch.Tensor]): f'Finished reward computation for the rollout in {total_reward_time:.4f} seconds.', ) - self.prompts_and_gens.extend(prompts_and_gens) - gen_batch_partial_outputs = (env_outputs, ref_outputs, all_rewards_dict) # For every partial output we want to resolve them together # And compute the global per iteration batch advantage's mean and variance @@ -805,17 +885,32 @@ def _interact_with_env(self, batch: dict[str, torch.Tensor]): gen_batch_partial_outputs, ) - # We need to split the resolved outputs into minibatches - for idx in range(bs // self.device_train_batch_size): - minibatch = self._extract_minibatch( - resolved_outputs, - idx, - self.device_train_batch_size, + if self.same_reward_filter_threshold is not None: + log.info( + f"in reward thresholding, trying to filter with: {self.same_reward_filter_threshold}", ) - self.buffer.add(minibatch) + start_time = time.time() + all_gathered_outputs = dist.all_gather_object(resolved_outputs) - # Making sure we correctly parsed the minibatches - assert len(self.buffer) == self.num_batches_per_update + log.info( + f"It took {time.time() - start_time} seconds to gather all resolved outputs.", + ) + + all_resolved_outputs = concat_resolved_outputs( + all_gathered_outputs, # type: ignore + self.pad_token_idx, # type: ignore + ) + + # Filter the resolved outputs based on the generation filtering values + resolved_outputs = filter_resolved_outputs( + all_resolved_outputs, + self.same_reward_filter_threshold, + ) + + self.global_sample_list.append(resolved_outputs) + + # TODO: bcui fix + self.prompts_and_gens.extend(prompts_and_gens) 
self.actor_critic.train() diff --git a/compose_rl/utils/__init__.py b/compose_rl/utils/__init__.py index ee28a4c2..1f3f1500 100644 --- a/compose_rl/utils/__init__.py +++ b/compose_rl/utils/__init__.py @@ -17,8 +17,10 @@ batch_process_fine_granularities, clear_mb_load_balancing_loss, compute_advantages, + concat_resolved_outputs, dist_compute_masked_mean_and_var, extract_packed_chosen_rejected, + filter_resolved_outputs, flatten, flip_pad_token_usage_for_generate, flip_pad_token_usage_in_ffn, @@ -42,6 +44,7 @@ masked_sum, masked_var, masked_whiten, + partition_global_batch, process_fine_granularities, remove_left_padding, rescale, @@ -91,6 +94,9 @@ 'make_action_mask', 'flatten', 'sample_wise_masked_mean', + 'filter_resolved_outputs', + 'concat_resolved_outputs', + 'partition_global_batch', 'extract_gsm8k_answer', 'extract_math_answer', 'is_equiv', diff --git a/compose_rl/utils/utils.py b/compose_rl/utils/utils.py index f24d972c..d01d47a3 100644 --- a/compose_rl/utils/utils.py +++ b/compose_rl/utils/utils.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import logging +import math import re import warnings from collections.abc import Generator, Iterable @@ -1222,3 +1223,264 @@ def flatten(coll: Union[Iterable[Any], str]) -> Generator[Any, None, None]: yield subc else: yield i + + +def filter_resolved_outputs( + outputs: dict[str, torch.Tensor], + filter_threshold: float, +): + """Filters the outputs based on the provided filter values. + + Filters out prompts where a high percentage of rewards are the same value. + This is done on a prompt-by-prompt basis. + + Args: + outputs (dict[str, torch.Tensor]): The outputs to filter. + filter_threshold (float): The percentage threshold for filtering. + + Returns: + dict: Filtered outputs with resolved prompts removed + """ + prompt_id = outputs['prompt_id'] + + # We are trying to resolve only the filtered + rewards = outputs['env_rewards'] + + generated_len = outputs['generated_len'] + + # Get unique prompt IDs and their indices + unique_prompt_ids, inverse_indices = torch.unique( + prompt_id, + return_inverse=True, + ) + + log.info(f"\nTotal unique prompts: {len(unique_prompt_ids)}") + log.info( + f"Threshold: Filter if > {filter_threshold:.0%} of rewards are the same", + ) + + prompts_to_filter = [] + prompt_stats = {} + + # Check each prompt individually + for i, unique_id in enumerate(unique_prompt_ids): + mask = inverse_indices == i + + cur_generated_lens = generated_len[mask] + + n_samples = cur_generated_lens.size(0) + + batch_tensor = torch.arange( + n_samples, + device=cur_generated_lens.device, + ) + + # For simplicity, we should always be getting the last generated reward + prompt_rewards = rewards[mask][batch_tensor, cur_generated_lens - 1] + + # print ("masked rewards is: ", rewards[mask], rewards[mask].shape) + # print ("prompt rewards shape is: ", prompt_rewards.shape) + # print ("rewards shape is: ", rewards.shape) + + # Find the most common reward value and its percentage + unique_values, counts = torch.unique(prompt_rewards, return_counts=True) + max_count = torch.max(counts).item() + max_percentage = max_count / n_samples + + # print ("rewards are: ", prompt_rewards) + + # print ("max count is: ", max_count, "n samples is: ", n_samples) + # print ("max percentage is: ", max_percentage) + + # print ("n samples is: ", n_samples) + # print ("max count is: ", max_count) + # print ("prompt rewards size is: ", prompt_rewards.size()) + + # Find which value is most common + most_common_idx = torch.argmax(counts) + most_common_value 
= unique_values[most_common_idx].item() + + # Decide whether to filter + if max_percentage > filter_threshold: + prompts_to_filter.append(i) + prompt_stats[unique_id.item()] = { + 'action': 'filtered', + 'most_common_value': most_common_value, + 'most_common_count': max_count, + 'percentage_same': max_percentage, + 'n_samples': n_samples, + 'unique_values': unique_values.tolist(), + 'counts': counts.tolist(), + } + log.info( + f" Prompt {unique_id.item()}: filtering as ({max_percentage:.0%} of the generations with the reward: {most_common_value}) as the most common value", + ) + else: + prompt_stats[unique_id.item()] = { + 'action': 'kept', + 'most_common_value': most_common_value, + 'most_common_count': max_count, + 'percentage_same': max_percentage, + 'n_samples': n_samples, + 'unique_values': unique_values.tolist(), + 'counts': counts.tolist(), + 'mean': torch.mean(prompt_rewards).item(), + 'std': torch.std(prompt_rewards).item(), + } + + # Create filter mask + keep_mask = torch.ones(len(prompt_id), dtype=torch.bool) + for prompt_idx in prompts_to_filter: + keep_mask[inverse_indices == prompt_idx] = False + + # print ("key mask shape is: ", keep_mask.shape) + + # get the integer indices of entries to keep + keep_indices = keep_mask.nonzero(as_tuple=True)[0] + + # Apply filter to all outputs + filtered_outputs = {} + for key, value in outputs.items(): + # print ("key is: ", key) + if isinstance(value, torch.Tensor): + # print('value shape is: ', value.shape) + filtered_outputs[key] = value[keep_mask] + + elif isinstance(value, list) and len(value) == keep_mask.shape[0]: + # keep only those elements whose mask is True + filtered_outputs[key] = [value[i] for i in keep_indices.tolist()] + + else: + assert False + + # Store statistics + filter_stats = { + 'total_prompts': len(unique_prompt_ids), + 'kept_prompts': len(unique_prompt_ids) - len(prompts_to_filter), + 'filtered_prompts': len(prompts_to_filter), + 'prompt_stats': prompt_stats, + 'threshold': filter_threshold, + } + + # filtered_outputs['filter_stats'] = filter_stats + + log.info(f" Kept: {filter_stats['kept_prompts']} prompts") + log.info(f" Filtered: {len(prompts_to_filter)} prompts") + log.info(f" Original samples: {len(prompt_id)}") + log.info(f" Remaining samples: {len(filtered_outputs['prompt_id'])}") + + return filtered_outputs + + +def concat_resolved_outputs( + output_list: list[dict[str, Union[torch.Tensor, list[Any]]]], + pad_token_id: int, +): + """Stacks a list of resolved outputs into a single dictionary. + + This function takes a list of dictionaries, where each dictionary contains tensors or lists + all of the keys should be the same. + + Args: + output_list (list[dict[str, Union[torch.Tensor, list]]]): A list of dictionaries containing tensors or lists. 
+ """ + all_resolved_outputs = {} + for key in output_list[0].keys(): + # Collect all tensors under this key from each rank + raw_vals = [d[key] for d in output_list] + + if key in ['verified_answer']: + all_resolved_outputs[key] = list(flatten(raw_vals)) + continue + + tensor_list = [v for v in raw_vals if isinstance(v, torch.Tensor)] + assert len(tensor_list) == len(raw_vals), ( + f"Expected all values for key {key!r} to be torch.Tensor, " + f"but got {[type(v) for v in raw_vals]}" + ) + + padding_key = pad_token_id + + if key == 'prompt_attention_mask': + padding_key = False + + # If you're one of these keys we need to do some extra padding + # to make sure that the tensors are all the same length + if key in [ + 'prompt', + 'prompt_attention_mask', + 'sequences', + 'obs', + 'right_padded_attn_mask', + ]: + max_len = max(t.size(-1) for t in tensor_list) + padded_tensors: list[torch.Tensor] = [] + for t in tensor_list: + if t.size(-1) < max_len: + # Pad the tensor to the max length + padding = torch.full( + (t.size(0), max_len - t.size(-1)), + padding_key, # type: ignore + dtype=t.dtype, + device=t.device, + ) + padded_tensors.append( + torch.cat([t, padding], dim=-1), + ) + else: + padded_tensors.append(t) + tensor_list = padded_tensors + + all_resolved_outputs[key] = torch.cat(tensor_list, dim=0) + + return all_resolved_outputs + + +def partition_global_batch( + batch: Union[torch.Tensor, list], + world_size: int, + rank: int, + device_train_batch_size: int, +): + """Partitions a batch of data evenly across all ranks. + + Pads `batch` (Tensor or list) so that: + - total length = per_rank * world_size + - per_rank is the smallest multiple of device_train_batch_size ≥ original_length/world_size + Then returns the slice for this `rank`. + + Note: this function assumes that each rank has the exact same set of data. + + Args: + batch (Union[torch.Tensor, list]): The batch of data to partition. + world_size (int): The total number of ranks. + rank (int): The rank of the current process. + device_train_batch_size (int): The batch size per device. + """ + # 1) figure out length and per-rank size + B = batch.size(0) if isinstance(batch, torch.Tensor) else len(batch) + + per_rank = math.ceil( + B / world_size / device_train_batch_size, + ) * device_train_batch_size + total = per_rank * world_size + pad_size = total - B + + # 2) pad if needed + if pad_size > 0: + if isinstance(batch, torch.Tensor): + idx = torch.arange(pad_size, device=batch.device) % B + batch = torch.cat([batch, batch[idx]], dim=0) + else: + batch = batch + [batch[i % B] for i in range(pad_size)] + + # 3) split out rank’s chunk + if isinstance(batch, torch.Tensor): + # reshape into [world_size, per_rank, ...] then select + new_shape = [world_size, per_rank] + list(batch.shape[1:]) + return batch.view(*new_shape)[rank] + + else: + start = rank * per_rank + end = start + per_rank + return batch[start:end] From 23364b53de243494d2d5e1044bef0d3ac95fe77b Mon Sep 17 00:00:00 2001 From: brandon Date: Thu, 26 Jun 2025 00:36:19 +0000 Subject: [PATCH 2/7] update --- compose_rl/utils/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/compose_rl/utils/utils.py b/compose_rl/utils/utils.py index d01d47a3..76d2c35c 100644 --- a/compose_rl/utils/utils.py +++ b/compose_rl/utils/utils.py @@ -1383,6 +1383,7 @@ def concat_resolved_outputs( Args: output_list (list[dict[str, Union[torch.Tensor, list]]]): A list of dictionaries containing tensors or lists. + pad_token_id (int): The token ID used for padding tensors that need to be padded. 
""" all_resolved_outputs = {} for key in output_list[0].keys(): From 9a10b841e8331a592277f292c2046a89ae72574f Mon Sep 17 00:00:00 2001 From: brandon Date: Thu, 26 Jun 2025 00:39:23 +0000 Subject: [PATCH 3/7] update --- compose_rl/utils/utils.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/compose_rl/utils/utils.py b/compose_rl/utils/utils.py index 76d2c35c..6bb10f6e 100644 --- a/compose_rl/utils/utils.py +++ b/compose_rl/utils/utils.py @@ -1278,24 +1278,11 @@ def filter_resolved_outputs( # For simplicity, we should always be getting the last generated reward prompt_rewards = rewards[mask][batch_tensor, cur_generated_lens - 1] - # print ("masked rewards is: ", rewards[mask], rewards[mask].shape) - # print ("prompt rewards shape is: ", prompt_rewards.shape) - # print ("rewards shape is: ", rewards.shape) - # Find the most common reward value and its percentage unique_values, counts = torch.unique(prompt_rewards, return_counts=True) max_count = torch.max(counts).item() max_percentage = max_count / n_samples - # print ("rewards are: ", prompt_rewards) - - # print ("max count is: ", max_count, "n samples is: ", n_samples) - # print ("max percentage is: ", max_percentage) - - # print ("n samples is: ", n_samples) - # print ("max count is: ", max_count) - # print ("prompt rewards size is: ", prompt_rewards.size()) - # Find which value is most common most_common_idx = torch.argmax(counts) most_common_value = unique_values[most_common_idx].item() @@ -1333,17 +1320,13 @@ def filter_resolved_outputs( for prompt_idx in prompts_to_filter: keep_mask[inverse_indices == prompt_idx] = False - # print ("key mask shape is: ", keep_mask.shape) - # get the integer indices of entries to keep keep_indices = keep_mask.nonzero(as_tuple=True)[0] # Apply filter to all outputs filtered_outputs = {} for key, value in outputs.items(): - # print ("key is: ", key) if isinstance(value, torch.Tensor): - # print('value shape is: ', value.shape) filtered_outputs[key] = value[keep_mask] elif isinstance(value, list) and len(value) == keep_mask.shape[0]: From e471d513e938643807d02b96c09205b19a04dbc0 Mon Sep 17 00:00:00 2001 From: brandon Date: Thu, 26 Jun 2025 00:54:37 +0000 Subject: [PATCH 4/7] update --- compose_rl/utils/utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/compose_rl/utils/utils.py b/compose_rl/utils/utils.py index 6bb10f6e..0d20ab3c 100644 --- a/compose_rl/utils/utils.py +++ b/compose_rl/utils/utils.py @@ -1300,7 +1300,7 @@ def filter_resolved_outputs( 'counts': counts.tolist(), } log.info( - f" Prompt {unique_id.item()}: filtering as ({max_percentage:.0%} of the generations with the reward: {most_common_value}) as the most common value", + f"Prompt {unique_id.item()}: filtering as ({max_percentage:.0%} of the generations with the reward: {most_common_value}) as the most common value", ) else: prompt_stats[unique_id.item()] = { @@ -1347,10 +1347,10 @@ def filter_resolved_outputs( # filtered_outputs['filter_stats'] = filter_stats - log.info(f" Kept: {filter_stats['kept_prompts']} prompts") - log.info(f" Filtered: {len(prompts_to_filter)} prompts") - log.info(f" Original samples: {len(prompt_id)}") - log.info(f" Remaining samples: {len(filtered_outputs['prompt_id'])}") + log.info(f"Kept: {filter_stats['kept_prompts']} prompts") + log.info(f"Filtered: {len(prompts_to_filter)} prompts") + log.info(f"Original samples: {len(prompt_id)}") + log.info(f"Remaining samples: {len(filtered_outputs['prompt_id'])}") return filtered_outputs From 
6fefdafb7cd976b8d5a8f4a2bc13024774d62ca2 Mon Sep 17 00:00:00 2001 From: brandon Date: Thu, 26 Jun 2025 01:04:50 +0000 Subject: [PATCH 5/7] update --- compose_rl/algorithms/online/model_methods.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compose_rl/algorithms/online/model_methods.py b/compose_rl/algorithms/online/model_methods.py index bdcec748..d68e5771 100644 --- a/compose_rl/algorithms/online/model_methods.py +++ b/compose_rl/algorithms/online/model_methods.py @@ -238,7 +238,7 @@ def policy_loss( logits=gen_logits, ) assert token_entropies.shape == batch['action_mask'].shape, ( - f'Token entropies shape {token_entropies.shape} does not match action mask shape {batch["action_mask"].shape}.', + f'Token entropies shape {token_entropies.shape} does not match action mask shape {batch['action_mask'].shape}.', ) seq_entropies = utils.get_sequence_entropies( token_entropies=token_entropies, From 938af5e170b1a41e55c0996069c810a19350ee24 Mon Sep 17 00:00:00 2001 From: brandon Date: Thu, 26 Jun 2025 01:13:18 +0000 Subject: [PATCH 6/7] update --- compose_rl/algorithms/online/model_methods.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compose_rl/algorithms/online/model_methods.py b/compose_rl/algorithms/online/model_methods.py index d68e5771..3528e515 100644 --- a/compose_rl/algorithms/online/model_methods.py +++ b/compose_rl/algorithms/online/model_methods.py @@ -238,7 +238,7 @@ def policy_loss( logits=gen_logits, ) assert token_entropies.shape == batch['action_mask'].shape, ( - f'Token entropies shape {token_entropies.shape} does not match action mask shape {batch['action_mask'].shape}.', + f"Token entropies shape {token_entropies.shape} does not match action mask shape {batch['action_mask'].shape}.", ) seq_entropies = utils.get_sequence_entropies( token_entropies=token_entropies, From 7f736863a08cc68f5f14ffa2070447644a3e265b Mon Sep 17 00:00:00 2001 From: brandon Date: Thu, 26 Jun 2025 22:07:02 +0000 Subject: [PATCH 7/7] quotes --- compose_rl/algorithms/online/callback.py | 16 ++++++++-------- compose_rl/utils/utils.py | 21 +++++++++++---------- 2 files changed, 19 insertions(+), 18 deletions(-) diff --git a/compose_rl/algorithms/online/callback.py b/compose_rl/algorithms/online/callback.py index 7a5eee61..5b19f1e5 100644 --- a/compose_rl/algorithms/online/callback.py +++ b/compose_rl/algorithms/online/callback.py @@ -522,7 +522,7 @@ def __init__( world_size = dist.get_world_size() num_gen_calls = self.global_iter_batch_size // self.device_generate_batch_size // world_size log.info( - f"Using {num_gen_calls} generation calls per iteration with HuggingFace generate.", + f'Using {num_gen_calls} generation calls per iteration with HuggingFace generate.', ) if num_gen_calls <= 0: raise ValueError(f'{num_gen_calls=} must be greater than 0') @@ -640,9 +640,9 @@ def iteration_start(self, state: State, logger: Logger): bs = cur_global_samples['prompt_id'].shape[0] - log.info(f"Current global batch size is {bs}.") + log.info(f'Current global batch size is {bs}.') log.info( - f"Current global iter batch size is {self.global_iter_batch_size}.", + f'Current global iter batch size is {self.global_iter_batch_size}.', ) if bs >= self.global_iter_batch_size: @@ -672,10 +672,10 @@ def iteration_start(self, state: State, logger: Logger): self.buffer.add(minibatch) log.info( - f"For iteration {self.iter_num}, we have {len(self.buffer)} samples in the buffer. Starting training.", + f'For iteration {self.iter_num}, we have {len(self.buffer)} samples in the buffer. 
Starting training.', ) log.info( - f"It took {num_env_interactions} environment interactions to fill the buffer.", + f'It took {num_env_interactions} environment interactions to fill the buffer.', ) # Making sure we correctly parsed the minibatches @@ -723,7 +723,7 @@ def _get_next_iter_prompts(self, state: State): """Gets the next iteration's batch of prompts.""" n_unique_batches = self.num_unique_prompts_per_iter // self.global_train_batch_size log.info( - f"Getting {n_unique_batches} unique batches of prompts for the current iteration.", + f'Getting {n_unique_batches} unique batches of prompts for the current iteration.', ) batches = [ @@ -887,13 +887,13 @@ def _interact_with_env(self, batch: dict[str, torch.Tensor]): if self.same_reward_filter_threshold is not None: log.info( - f"in reward thresholding, trying to filter with: {self.same_reward_filter_threshold}", + f'in reward thresholding, trying to filter with: {self.same_reward_filter_threshold}', ) start_time = time.time() all_gathered_outputs = dist.all_gather_object(resolved_outputs) log.info( - f"It took {time.time() - start_time} seconds to gather all resolved outputs.", + f'It took {time.time() - start_time} seconds to gather all resolved outputs.', ) all_resolved_outputs = concat_resolved_outputs( diff --git a/compose_rl/utils/utils.py b/compose_rl/utils/utils.py index 0d20ab3c..e72599bd 100644 --- a/compose_rl/utils/utils.py +++ b/compose_rl/utils/utils.py @@ -1254,9 +1254,9 @@ def filter_resolved_outputs( return_inverse=True, ) - log.info(f"\nTotal unique prompts: {len(unique_prompt_ids)}") + log.info(f'\nTotal unique prompts: {len(unique_prompt_ids)}') log.info( - f"Threshold: Filter if > {filter_threshold:.0%} of rewards are the same", + f'Threshold: Filter if > {filter_threshold:.0%} of rewards are the same', ) prompts_to_filter = [] @@ -1300,7 +1300,7 @@ def filter_resolved_outputs( 'counts': counts.tolist(), } log.info( - f"Prompt {unique_id.item()}: filtering as ({max_percentage:.0%} of the generations with the reward: {most_common_value}) as the most common value", + f'Prompt {unique_id.item()}: filtering as ({max_percentage:.0%} of the generations with the reward: {most_common_value}) as the most common value', ) else: prompt_stats[unique_id.item()] = { @@ -1345,12 +1345,13 @@ def filter_resolved_outputs( 'threshold': filter_threshold, } - # filtered_outputs['filter_stats'] = filter_stats + num_kept_prompts = filter_stats['kept_prompts'] + num_remaining_prompts = len(filtered_outputs['prompt_id']) - log.info(f"Kept: {filter_stats['kept_prompts']} prompts") - log.info(f"Filtered: {len(prompts_to_filter)} prompts") - log.info(f"Original samples: {len(prompt_id)}") - log.info(f"Remaining samples: {len(filtered_outputs['prompt_id'])}") + log.info(f'Kept: {num_kept_prompts} prompts') + log.info(f'Filtered: {len(prompts_to_filter)} prompts') + log.info(f'Original samples: {len(prompt_id)}') + log.info(f'Remaining samples: {num_remaining_prompts}') return filtered_outputs @@ -1379,8 +1380,8 @@ def concat_resolved_outputs( tensor_list = [v for v in raw_vals if isinstance(v, torch.Tensor)] assert len(tensor_list) == len(raw_vals), ( - f"Expected all values for key {key!r} to be torch.Tensor, " - f"but got {[type(v) for v in raw_vals]}" + f'Expected all values for key {key!r} to be torch.Tensor, ' + f'but got {[type(v) for v in raw_vals]}' ) padding_key = pad_token_id
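
For reference, a small usage sketch of the new partition_global_batch helper added to compose_rl/utils/utils.py and exported from compose_rl/utils/__init__.py in this series. This example is illustrative only and not part of the patch; the tensor sizes, world size, and rank values are made up to show the wrap-around padding behavior described in the helper's docstring.

    import torch
    from compose_rl.utils import partition_global_batch

    # Hypothetical global batch of 10 samples, identical on every rank.
    global_obs = torch.arange(10).unsqueeze(-1)  # shape [10, 1]

    # With world_size=2 and device_train_batch_size=4, per_rank becomes
    # ceil(10 / 2 / 4) * 4 = 8, so the batch is padded to 16 samples by
    # recycling entries from the start before it is split across ranks.
    rank0 = partition_global_batch(global_obs, world_size=2, rank=0, device_train_batch_size=4)
    rank1 = partition_global_batch(global_obs, world_size=2, rank=1, device_train_batch_size=4)

    assert rank0.shape == (8, 1) and rank1.shape == (8, 1)
    # rank1 holds samples [8, 9, 0, 1, 2, 3, 4, 5] because of the wrap-around padding.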