From ea680a575e0d6e2980dbde63a61f92403d26caf5 Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Tue, 1 Jul 2025 11:47:41 -0400 Subject: [PATCH 01/20] add multimodal support for models --- .../algorithms/offline/model_methods.py | 42 ++++++++++++++----- compose_rl/algorithms/online/model_methods.py | 4 ++ 2 files changed, 36 insertions(+), 10 deletions(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 32aa3a01..d934f675 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -65,6 +65,10 @@ def pairwise_offline_forward( if pad_token_id is None: raise ValueError('Tokenizer must have a PAD token.') + is_multimodal = "pixel_values" in batch.keys() + if is_multimodal and use_attention_sequence_id: + raise NotImplementedError("Using Sequence ID is not implemented for VLMs") + # If we can use attention sequence ID, we use this logic branch. # This is determined by a value set in `train_dpo.py` if use_attention_sequence_id: @@ -102,18 +106,36 @@ def pairwise_offline_forward( pad_token_id=0, ) - batch_cat_inputs = torch.cat([chosen_inputs, rejected_inputs], dim=0) - batch_attn_mask = torch.cat( - [ - chosen_attention_mask, - rejected_attention_mask, - ], - dim=0, - ) + inputs = { + "input_ids": torch.cat([chosen_inputs, rejected_inputs], dim=0), + "attention_mask": torch.cat( + [ + chosen_attention_mask, + rejected_attention_mask, + ], + dim=0, + ), + } + + if is_multimodal: + chosen_token_type_ids, rejected_token_type_ids = extract_packed_chosen_rejected( + batch['token_type_ids'], + batch['chosen_len'], + batch['rejected_len'], + concat_seq_len, + pad_token_id=0, + ) + + # TODO: Ask if assuming same pixel inputs is ok? + multimodal_inputs = { + "token_type_ids": torch.cat([chosen_token_type_ids, rejected_token_type_ids], dim=0), + "pixel_values": torch.cat([batch['pixel_values'], batch['pixel_values']], dim=0), + } + + inputs.update(multimodal_inputs) output_logits = model( - batch_cat_inputs, - attention_mask=batch_attn_mask, + **inputs ).logits # Extract out the chosen and rejected logits along the batch dimension diff --git a/compose_rl/algorithms/online/model_methods.py b/compose_rl/algorithms/online/model_methods.py index bdcec748..76072d1d 100644 --- a/compose_rl/algorithms/online/model_methods.py +++ b/compose_rl/algorithms/online/model_methods.py @@ -118,6 +118,10 @@ def composer_online_rl_forward( model_forward_kwargs['action_mask'] = batch['action_mask'] model_forward_kwargs['max_gen_len'] = batch['max_gen_len'] + if "pixel_values" in batch.keys(): + model_forward_kwargs['token_type_ids'] = batch['token_type_ids'] + model_forward_kwargs['pixel_values'] = batch['pixel_values'] + actor_output = model(batch['obs'], **model_forward_kwargs) logits = actor_output.logits From a7e86cc89020c520ebfc981ef138655d7c56ad9c Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Tue, 1 Jul 2025 11:48:42 -0400 Subject: [PATCH 02/20] linting --- .../algorithms/offline/model_methods.py | 34 ++++++++++++------- compose_rl/algorithms/online/model_methods.py | 2 +- 2 files changed, 22 insertions(+), 14 deletions(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index d934f675..95bc1567 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -65,9 +65,11 @@ def pairwise_offline_forward( if pad_token_id is None: raise ValueError('Tokenizer must have a PAD token.') - is_multimodal = "pixel_values" in batch.keys() + is_multimodal = 'pixel_values' in batch.keys() if is_multimodal and use_attention_sequence_id: - raise NotImplementedError("Using Sequence ID is not implemented for VLMs") + raise NotImplementedError( + 'Using Sequence ID is not implemented for VLMs', + ) # If we can use attention sequence ID, we use this logic branch. # This is determined by a value set in `train_dpo.py` @@ -107,14 +109,16 @@ def pairwise_offline_forward( ) inputs = { - "input_ids": torch.cat([chosen_inputs, rejected_inputs], dim=0), - "attention_mask": torch.cat( - [ - chosen_attention_mask, - rejected_attention_mask, - ], - dim=0, - ), + 'input_ids': + torch.cat([chosen_inputs, rejected_inputs], dim=0), + 'attention_mask': + torch.cat( + [ + chosen_attention_mask, + rejected_attention_mask, + ], + dim=0, + ), } if is_multimodal: @@ -128,14 +132,18 @@ def pairwise_offline_forward( # TODO: Ask if assuming same pixel inputs is ok? multimodal_inputs = { - "token_type_ids": torch.cat([chosen_token_type_ids, rejected_token_type_ids], dim=0), - "pixel_values": torch.cat([batch['pixel_values'], batch['pixel_values']], dim=0), + 'token_type_ids': + torch.cat([chosen_token_type_ids, rejected_token_type_ids], + dim=0), + 'pixel_values': + torch.cat([batch['pixel_values'], batch['pixel_values']], + dim=0), } inputs.update(multimodal_inputs) output_logits = model( - **inputs + **inputs, ).logits # Extract out the chosen and rejected logits along the batch dimension diff --git a/compose_rl/algorithms/online/model_methods.py b/compose_rl/algorithms/online/model_methods.py index 76072d1d..f45abb62 100644 --- a/compose_rl/algorithms/online/model_methods.py +++ b/compose_rl/algorithms/online/model_methods.py @@ -118,7 +118,7 @@ def composer_online_rl_forward( model_forward_kwargs['action_mask'] = batch['action_mask'] model_forward_kwargs['max_gen_len'] = batch['max_gen_len'] - if "pixel_values" in batch.keys(): + if 'pixel_values' in batch.keys(): model_forward_kwargs['token_type_ids'] = batch['token_type_ids'] model_forward_kwargs['pixel_values'] = batch['pixel_values'] From 08952f7d978a4f41f9c7cb7e0f80e9d1c3f6f2f3 Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Tue, 1 Jul 2025 12:15:52 -0400 Subject: [PATCH 03/20] multimodal handling for gemma3 --- compose_rl/data/preference_data.py | 51 ++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/compose_rl/data/preference_data.py b/compose_rl/data/preference_data.py index c08d2757..6a4fa50b 100644 --- a/compose_rl/data/preference_data.py +++ b/compose_rl/data/preference_data.py @@ -56,6 +56,10 @@ def pairwise_preference_dataset_collate_fn( chosen_rewards = [] rejected_rewards = [] + # For VLMs + token_type_ids = [] + pixel_values = [] + for sample in data: chosen = sample['chosen'] rejected = sample['rejected'] @@ -63,6 +67,17 @@ def pairwise_preference_dataset_collate_fn( chosen_len = sample['chosen_len'] rejected_len = sample['rejected_len'] + is_multimodal = "pixel_values" in sample.keys() + if is_multimodal: + pixel_vals = sample['pixel_values'] + chosen_token_type_ids = sample['chosen_token_type_ids'] + rejected_token_type_ids = sample['rejected_token_type_ids'] + else: + pixel_vals = None + chosen_token_type_ids = None + rejected_token_type_ids = None + cat_token_type_ids = None + # Note: if we do any truncation, we force the last token to be EOS # https://github.com/mosaicml/RLHF/issues/101 @@ -75,6 +90,9 @@ def pairwise_preference_dataset_collate_fn( pad_len = max_seq_len - chosen_len - rejected_len cat_batch = torch.cat([chosen, rejected], dim=-1) + if is_multimodal: + cat_token_type_ids = torch.cat([chosen_token_type_ids, rejected_token_type_ids], dim=-1) + if pad_len < 0: # We should truncate chosen and rejected by the same amount truncate_len = abs(pad_len // 2) + 1 @@ -92,6 +110,15 @@ def pairwise_preference_dataset_collate_fn( rejected = rejected[:-truncate_len] rejected[-1] = tokenizer.eos_token_id # type: ignore + if is_multimodal: + chosen_token_type_ids = chosen_token_type_ids[:-truncate_len] + rejected_token_type_ids = rejected_token_type_ids[:-truncate_len] + + # NOTE: GEMMA specific: 0 == text token + chosen_token_type_ids[-1] = 0 + rejected_token_type_ids[-1] = 0 + cat_token_type_ids = torch.cat([chosen_token_type_ids, rejected_token_type_ids], dim=-1) + cat_batch = torch.cat([chosen, rejected], dim=-1) chosen_len = torch.tensor([len(chosen)]) @@ -108,6 +135,11 @@ def pairwise_preference_dataset_collate_fn( ], dim=-1, # type: ignore ) + if is_multimodal: + cat_token_type_ids = torch.cat([ + cat_token_type_ids, + torch.zeros(int(pad_len.item()), dtype=cat_token_type_ids.dtype), + ], dim=-1) attention_mask = torch.logical_not( torch.eq(cat_batch, tokenizer.pad_token_id), # type: ignore @@ -127,6 +159,10 @@ def pairwise_preference_dataset_collate_fn( chosen_rewards.append(sample['chosen_reward']) rejected_rewards.append(sample['rejected_reward']) + if is_multimodal: + token_type_ids.append(cat_token_type_ids) + pixel_values.append(pixel_vals) + input_ids = ref_collate_fn(input_ids)['input_ids'] attention_masks = torch.stack(attention_masks) sequence_id = torch.stack(sequence_id) @@ -147,6 +183,11 @@ def pairwise_preference_dataset_collate_fn( rejected_rewards = torch.stack(rejected_rewards) return_dict['chosen_reward'] = chosen_rewards return_dict['rejected_reward'] = rejected_rewards + + if is_multimodal: + return_dict['token_type_ids'] = token_type_ids + return_dict['pixel_values'] = pixel_values + return return_dict @@ -263,6 +304,16 @@ def __getitem__(self, idx: int) -> dict[str, Any]: rejected_reward = torch.Tensor([sample['rejected_reward']]) return_dict['chosen_reward'] = chosen_reward return_dict['rejected_reward'] = rejected_reward + + if 'pixel_values' in sample: + pixel_values = self._read_binary_tokenized_sample(sample['pixel_values'], 'pixel_values') + chosen_token_type_ids = self._read_binary_tokenized_sample(sample['chosen_token_type_ids'], 'chosen_token_type_ids') + rejected_token_type_ids = self._read_binary_tokenized_sample((sample['rejected_token_type_ids']), 'rejected_token_type_ids') + + return_dict['pixel_values'] = pixel_values + return_dict['chosen_token_type_ids'] = chosen_token_type_ids + return_dict['rejected_token_type_ids'] = rejected_token_type_ids + return return_dict def find_prompt_length(self, seq_1: torch.Tensor, seq_2: torch.Tensor): From bcc273b625be2f0cc2c1ce2793513700f4914c89 Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Tue, 1 Jul 2025 12:16:33 -0400 Subject: [PATCH 04/20] lint --- compose_rl/data/preference_data.py | 42 +++++++++++++++++++++++------- 1 file changed, 32 insertions(+), 10 deletions(-) diff --git a/compose_rl/data/preference_data.py b/compose_rl/data/preference_data.py index 6a4fa50b..990df1fb 100644 --- a/compose_rl/data/preference_data.py +++ b/compose_rl/data/preference_data.py @@ -67,7 +67,7 @@ def pairwise_preference_dataset_collate_fn( chosen_len = sample['chosen_len'] rejected_len = sample['rejected_len'] - is_multimodal = "pixel_values" in sample.keys() + is_multimodal = 'pixel_values' in sample.keys() if is_multimodal: pixel_vals = sample['pixel_values'] chosen_token_type_ids = sample['chosen_token_type_ids'] @@ -91,7 +91,11 @@ def pairwise_preference_dataset_collate_fn( cat_batch = torch.cat([chosen, rejected], dim=-1) if is_multimodal: - cat_token_type_ids = torch.cat([chosen_token_type_ids, rejected_token_type_ids], dim=-1) + cat_token_type_ids = torch.cat([ + chosen_token_type_ids, + rejected_token_type_ids, + ], + dim=-1) if pad_len < 0: # We should truncate chosen and rejected by the same amount @@ -112,12 +116,17 @@ def pairwise_preference_dataset_collate_fn( if is_multimodal: chosen_token_type_ids = chosen_token_type_ids[:-truncate_len] - rejected_token_type_ids = rejected_token_type_ids[:-truncate_len] + rejected_token_type_ids = rejected_token_type_ids[:-truncate_len + ] # NOTE: GEMMA specific: 0 == text token chosen_token_type_ids[-1] = 0 rejected_token_type_ids[-1] = 0 - cat_token_type_ids = torch.cat([chosen_token_type_ids, rejected_token_type_ids], dim=-1) + cat_token_type_ids = torch.cat([ + chosen_token_type_ids, + rejected_token_type_ids, + ], + dim=-1) cat_batch = torch.cat([chosen, rejected], dim=-1) @@ -138,8 +147,12 @@ def pairwise_preference_dataset_collate_fn( if is_multimodal: cat_token_type_ids = torch.cat([ cat_token_type_ids, - torch.zeros(int(pad_len.item()), dtype=cat_token_type_ids.dtype), - ], dim=-1) + torch.zeros( + int(pad_len.item()), + dtype=cat_token_type_ids.dtype, + ), + ], + dim=-1) attention_mask = torch.logical_not( torch.eq(cat_batch, tokenizer.pad_token_id), # type: ignore @@ -184,7 +197,7 @@ def pairwise_preference_dataset_collate_fn( return_dict['chosen_reward'] = chosen_rewards return_dict['rejected_reward'] = rejected_rewards - if is_multimodal: + if is_multimodal: return_dict['token_type_ids'] = token_type_ids return_dict['pixel_values'] = pixel_values @@ -306,9 +319,18 @@ def __getitem__(self, idx: int) -> dict[str, Any]: return_dict['rejected_reward'] = rejected_reward if 'pixel_values' in sample: - pixel_values = self._read_binary_tokenized_sample(sample['pixel_values'], 'pixel_values') - chosen_token_type_ids = self._read_binary_tokenized_sample(sample['chosen_token_type_ids'], 'chosen_token_type_ids') - rejected_token_type_ids = self._read_binary_tokenized_sample((sample['rejected_token_type_ids']), 'rejected_token_type_ids') + pixel_values = self._read_binary_tokenized_sample( + sample['pixel_values'], + 'pixel_values', + ) + chosen_token_type_ids = self._read_binary_tokenized_sample( + sample['chosen_token_type_ids'], + 'chosen_token_type_ids', + ) + rejected_token_type_ids = self._read_binary_tokenized_sample( + (sample['rejected_token_type_ids']), + 'rejected_token_type_ids', + ) return_dict['pixel_values'] = pixel_values return_dict['chosen_token_type_ids'] = chosen_token_type_ids From 634e1513036c4ee891fb42819d6f2e8b5f54e740 Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Thu, 3 Jul 2025 14:57:06 -0400 Subject: [PATCH 05/20] fix multimodal preference loading --- compose_rl/data/preference_data.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/compose_rl/data/preference_data.py b/compose_rl/data/preference_data.py index 990df1fb..9abe3e75 100644 --- a/compose_rl/data/preference_data.py +++ b/compose_rl/data/preference_data.py @@ -320,15 +320,15 @@ def __getitem__(self, idx: int) -> dict[str, Any]: if 'pixel_values' in sample: pixel_values = self._read_binary_tokenized_sample( - sample['pixel_values'], + sample, 'pixel_values', ) chosen_token_type_ids = self._read_binary_tokenized_sample( - sample['chosen_token_type_ids'], + sample, 'chosen_token_type_ids', ) rejected_token_type_ids = self._read_binary_tokenized_sample( - (sample['rejected_token_type_ids']), + sample, 'rejected_token_type_ids', ) From d35ed90b19832961e5304d26eb3b52d7fffa3b92 Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Thu, 3 Jul 2025 15:01:19 -0400 Subject: [PATCH 06/20] fix collator --- compose_rl/data/preference_data.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/compose_rl/data/preference_data.py b/compose_rl/data/preference_data.py index 9abe3e75..d669f346 100644 --- a/compose_rl/data/preference_data.py +++ b/compose_rl/data/preference_data.py @@ -198,6 +198,8 @@ def pairwise_preference_dataset_collate_fn( return_dict['rejected_reward'] = rejected_rewards if is_multimodal: + token_type_ids = torch.stack(token_type_ids) + pixel_values = torch.stack(pixel_values) return_dict['token_type_ids'] = token_type_ids return_dict['pixel_values'] = pixel_values From 8bea0e3c7d7d5fc8e9b66bb1b09529df7fc097e3 Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Thu, 3 Jul 2025 15:15:31 -0400 Subject: [PATCH 07/20] debug --- compose_rl/algorithms/offline/model_methods.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 95bc1567..4d7c56cc 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -140,6 +140,10 @@ def pairwise_offline_forward( dim=0), } + print("MULTIMODAL INPUTS") + for k, v in multimodal_inptus.itmes(): + print(f"{k}: {v.shape}") + inputs.update(multimodal_inputs) output_logits = model( From b305fded93838af269baee93ad9cc444497b4974 Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Thu, 3 Jul 2025 15:18:36 -0400 Subject: [PATCH 08/20] debug --- compose_rl/algorithms/offline/model_methods.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 4d7c56cc..00b74f73 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -141,7 +141,7 @@ def pairwise_offline_forward( } print("MULTIMODAL INPUTS") - for k, v in multimodal_inptus.itmes(): + for k, v in multimodal_inputs.items(): print(f"{k}: {v.shape}") inputs.update(multimodal_inputs) From b631cccd4ed58cc2a029a8017157fb37a7c3dc7f Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Thu, 3 Jul 2025 15:29:34 -0400 Subject: [PATCH 09/20] debug --- compose_rl/data/preference_data.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/compose_rl/data/preference_data.py b/compose_rl/data/preference_data.py index d669f346..b157e184 100644 --- a/compose_rl/data/preference_data.py +++ b/compose_rl/data/preference_data.py @@ -200,6 +200,8 @@ def pairwise_preference_dataset_collate_fn( if is_multimodal: token_type_ids = torch.stack(token_type_ids) pixel_values = torch.stack(pixel_values) + print('HIIIIIII') + print(pixel_values[0].shape) return_dict['token_type_ids'] = token_type_ids return_dict['pixel_values'] = pixel_values @@ -269,6 +271,10 @@ def __init__(self, max_seq_len: int, **kwargs: dict[str, Any]): def _read_binary_tokenized_sample(self, sample: dict[str, Any], key: str): self.num_read += 1 temp_sample = torch.from_numpy(np.frombuffer(sample[key])) + if key == 'pixel_values': + print('I AM INSIDE READ BINARY') + print(temp_sample.shape) + print(len(temp_sample)) if len(temp_sample) > self.max_seq_len: log.info(f'Truncating sample: {self.num_truncated} {self.num_read}') self.num_truncated += 1 From fabb4a82f14a6cca39ef21ae368d85329124102b Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Thu, 3 Jul 2025 16:23:05 -0400 Subject: [PATCH 10/20] change pixel values from being bytes to ndarray or pil --- compose_rl/data/preference_data.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/compose_rl/data/preference_data.py b/compose_rl/data/preference_data.py index b157e184..d08c53dc 100644 --- a/compose_rl/data/preference_data.py +++ b/compose_rl/data/preference_data.py @@ -11,6 +11,9 @@ from streaming import StreamingDataset from transformers import DataCollatorForLanguageModeling, PreTrainedTokenizer +from PIL import Image +from torchvision import transforms + log = logging.getLogger(__name__) @@ -200,8 +203,6 @@ def pairwise_preference_dataset_collate_fn( if is_multimodal: token_type_ids = torch.stack(token_type_ids) pixel_values = torch.stack(pixel_values) - print('HIIIIIII') - print(pixel_values[0].shape) return_dict['token_type_ids'] = token_type_ids return_dict['pixel_values'] = pixel_values @@ -271,10 +272,6 @@ def __init__(self, max_seq_len: int, **kwargs: dict[str, Any]): def _read_binary_tokenized_sample(self, sample: dict[str, Any], key: str): self.num_read += 1 temp_sample = torch.from_numpy(np.frombuffer(sample[key])) - if key == 'pixel_values': - print('I AM INSIDE READ BINARY') - print(temp_sample.shape) - print(len(temp_sample)) if len(temp_sample) > self.max_seq_len: log.info(f'Truncating sample: {self.num_truncated} {self.num_read}') self.num_truncated += 1 @@ -327,10 +324,17 @@ def __getitem__(self, idx: int) -> dict[str, Any]: return_dict['rejected_reward'] = rejected_reward if 'pixel_values' in sample: - pixel_values = self._read_binary_tokenized_sample( - sample, - 'pixel_values', - ) + if isinstance(sample['pixel_values'], np.ndarray): + pixel_values = torch.Tensor(sample['pixel_values']) + elif isinstance(sample['pixel_values'], Image): + pil_to_tensor_transform = transforms.PILToTensor() + pixel_values = pil_to_tensor_transform(sample['pixel_values']) + else: + pixel_values_type = type(sample['pixel_values']) + raise ValueError( + f'Expect pixel values to be numpy.ndarray or PIL.Image type, but got {pixel_values_type}', + ) + chosen_token_type_ids = self._read_binary_tokenized_sample( sample, 'chosen_token_type_ids', From e0e015ae260134f968988445ac981f2f9b9fe6bd Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Tue, 8 Jul 2025 14:13:21 -0400 Subject: [PATCH 11/20] support ndarray typing --- compose_rl/data/preference_data.py | 73 ++++++++++++++++++++++-------- 1 file changed, 53 insertions(+), 20 deletions(-) diff --git a/compose_rl/data/preference_data.py b/compose_rl/data/preference_data.py index d08c53dc..6d698ebc 100644 --- a/compose_rl/data/preference_data.py +++ b/compose_rl/data/preference_data.py @@ -293,21 +293,44 @@ def __getitem__(self, idx: int) -> dict[str, Any]: idx (int): the index where we fetch the data in the StreamingDataset. """ sample = super().__getitem__(idx) + # Handle prompt if available - if 'prompt' in sample: + if isinstance(sample['chosen'], bytes): # Prepend the prompt to the chosen and rejected responses - sample['chosen'] = sample['prompt'] + sample['chosen'] - sample['rejected'] = sample['prompt'] + sample['rejected'] - chosen = self._read_binary_tokenized_sample(sample, 'chosen') - rejected = self._read_binary_tokenized_sample(sample, 'rejected') - - if 'prompt' in sample: - prompt = self._read_binary_tokenized_sample(sample, 'prompt') - prompt_len = len(prompt) + if 'prompt' in sample: + sample['chosen'] = sample['prompt'] + sample['chosen'] + sample['rejected'] = sample['prompt'] + sample['rejected'] + chosen = self._read_binary_tokenized_sample(sample, 'chosen') + rejected = self._read_binary_tokenized_sample(sample, 'rejected') + + if 'prompt' in sample: + prompt = self._read_binary_tokenized_sample(sample, 'prompt') + prompt_len = len(prompt) + else: + # Only use prefix matching version of prompt_len when + # 'prompt' is not directly given in the sample + prompt_len = self.find_prompt_length(chosen, rejected) + + elif isinstance(sample['chosen'], np.ndarray): + if 'prompt' in sample: + sample['chosen'] = np.concatenate([sample['prompt'], sample['chosen']]) + sample['rejected'] = np.concatenate([sample['prompt'], sample['rejected']]) + + chosen = sample['chosen'][:self.max_seq_len].tolist().copy() + rejected = sample['rejected'][:self.max_seq_len].tolist().copy() + + if 'prompt' in sample: + prompt_len = len(sample['prompt']) + else: + # Only use prefix matching version of prompt_len when + # 'prompt' is not directly given in the sample + prompt_len = self.find_prompt_length(chosen, rejected) else: - # Only use prefix matching version of prompt_len when - # 'prompt' is not directly given in the sample - prompt_len = self.find_prompt_length(chosen, rejected) + token_type = type(sample['chosen']) + raise ValueError( + f'Expect prompt and response to be bytes or numpy.ndarray type, but got {token_type}', + ) + chosen_len, rejected_len = len(chosen), len(rejected) return_dict = { 'chosen': chosen, @@ -335,14 +358,24 @@ def __getitem__(self, idx: int) -> dict[str, Any]: f'Expect pixel values to be numpy.ndarray or PIL.Image type, but got {pixel_values_type}', ) - chosen_token_type_ids = self._read_binary_tokenized_sample( - sample, - 'chosen_token_type_ids', - ) - rejected_token_type_ids = self._read_binary_tokenized_sample( - sample, - 'rejected_token_type_ids', - ) + if isinstance(sample['chosen_token_type_ids'], bytes): + chosen_token_type_ids = self._read_binary_tokenized_sample( + sample, + 'chosen_token_type_ids', + ) + rejected_token_type_ids = self._read_binary_tokenized_sample( + sample, + 'rejected_token_type_ids', + ) + elif isinstance(sample['chosen_token_type_ids'], np.ndarray): + chosen_token_type_ids = sample['chosen_token_type_ids'][:self.max_seq_len].tolist().copy() + rejected_token_type_ids = sample['rejected_token_type_ids'][:self.max_seq_len].tolist().copy() + else: + token_type = type(sample['chosen_token_type_ids']) + raise ValueError( + f'Expect token_type_ids to be numpy.ndarray or bytes, but got {token_type}', + ) + return_dict['pixel_values'] = pixel_values return_dict['chosen_token_type_ids'] = chosen_token_type_ids From 2c1c0d4c421e10ae51fc11a81638449548febba1 Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Tue, 8 Jul 2025 14:17:06 -0400 Subject: [PATCH 12/20] PIL image support --- compose_rl/data/preference_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compose_rl/data/preference_data.py b/compose_rl/data/preference_data.py index 6d698ebc..155ddd51 100644 --- a/compose_rl/data/preference_data.py +++ b/compose_rl/data/preference_data.py @@ -349,7 +349,7 @@ def __getitem__(self, idx: int) -> dict[str, Any]: if 'pixel_values' in sample: if isinstance(sample['pixel_values'], np.ndarray): pixel_values = torch.Tensor(sample['pixel_values']) - elif isinstance(sample['pixel_values'], Image): + elif isinstance(sample['pixel_values'], Image.Image): pil_to_tensor_transform = transforms.PILToTensor() pixel_values = pil_to_tensor_transform(sample['pixel_values']) else: From 15ca2c6577ffc03d6ccc19bf19938e9f2309975c Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Tue, 8 Jul 2025 14:20:52 -0400 Subject: [PATCH 13/20] numpy support bug fix --- compose_rl/data/preference_data.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/compose_rl/data/preference_data.py b/compose_rl/data/preference_data.py index 155ddd51..90d014b9 100644 --- a/compose_rl/data/preference_data.py +++ b/compose_rl/data/preference_data.py @@ -316,8 +316,8 @@ def __getitem__(self, idx: int) -> dict[str, Any]: sample['chosen'] = np.concatenate([sample['prompt'], sample['chosen']]) sample['rejected'] = np.concatenate([sample['prompt'], sample['rejected']]) - chosen = sample['chosen'][:self.max_seq_len].tolist().copy() - rejected = sample['rejected'][:self.max_seq_len].tolist().copy() + chosen = torch.from_numpy(sample['chosen'][:self.max_seq_len]) + rejected = torch.from_numpy(sample['rejected'][:self.max_seq_len]) if 'prompt' in sample: prompt_len = len(sample['prompt']) @@ -368,8 +368,8 @@ def __getitem__(self, idx: int) -> dict[str, Any]: 'rejected_token_type_ids', ) elif isinstance(sample['chosen_token_type_ids'], np.ndarray): - chosen_token_type_ids = sample['chosen_token_type_ids'][:self.max_seq_len].tolist().copy() - rejected_token_type_ids = sample['rejected_token_type_ids'][:self.max_seq_len].tolist().copy() + chosen_token_type_ids = torch.from_numpy(sample['chosen_token_type_ids'][:self.max_seq_len]) + rejected_token_type_ids = torch.from_numpy(sample['rejected_token_type_ids'][:self.max_seq_len]) else: token_type = type(sample['chosen_token_type_ids']) raise ValueError( From 375d2576b1d65f40ab65f002a8d416f176cb52ef Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Tue, 8 Jul 2025 14:25:50 -0400 Subject: [PATCH 14/20] pixel values into lists --- compose_rl/algorithms/offline/model_methods.py | 4 +--- compose_rl/data/preference_data.py | 1 - 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 00b74f73..f927617f 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -135,9 +135,7 @@ def pairwise_offline_forward( 'token_type_ids': torch.cat([chosen_token_type_ids, rejected_token_type_ids], dim=0), - 'pixel_values': - torch.cat([batch['pixel_values'], batch['pixel_values']], - dim=0), + 'pixel_values': batch['pixel_values'] * 2, # double the list } print("MULTIMODAL INPUTS") diff --git a/compose_rl/data/preference_data.py b/compose_rl/data/preference_data.py index 90d014b9..08dca8bb 100644 --- a/compose_rl/data/preference_data.py +++ b/compose_rl/data/preference_data.py @@ -202,7 +202,6 @@ def pairwise_preference_dataset_collate_fn( if is_multimodal: token_type_ids = torch.stack(token_type_ids) - pixel_values = torch.stack(pixel_values) return_dict['token_type_ids'] = token_type_ids return_dict['pixel_values'] = pixel_values From ebf43faaa3e9a072a62c8e3a7839a04ecee18862 Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Tue, 8 Jul 2025 14:29:32 -0400 Subject: [PATCH 15/20] logging fix --- compose_rl/algorithms/offline/model_methods.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index f927617f..f05b7bde 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -140,7 +140,10 @@ def pairwise_offline_forward( print("MULTIMODAL INPUTS") for k, v in multimodal_inputs.items(): - print(f"{k}: {v.shape}") + if isinstance(v, torch.tensor): + print(f"{k}: {v.shape}") + else: + print(f"{k}: {len(v)}") inputs.update(multimodal_inputs) From daae699af73e053c4944822ff8ff3362f6a80200 Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Tue, 8 Jul 2025 14:32:00 -0400 Subject: [PATCH 16/20] fix --- compose_rl/algorithms/offline/model_methods.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index f05b7bde..272fba8f 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -140,7 +140,7 @@ def pairwise_offline_forward( print("MULTIMODAL INPUTS") for k, v in multimodal_inputs.items(): - if isinstance(v, torch.tensor): + if isinstance(v, torch.Tensor): print(f"{k}: {v.shape}") else: print(f"{k}: {len(v)}") From 82e2f1d0256fd279c35b9975cbe0ca70ec15216f Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Tue, 8 Jul 2025 14:40:45 -0400 Subject: [PATCH 17/20] change back to tensor --- compose_rl/algorithms/offline/model_methods.py | 9 ++++----- compose_rl/data/preference_data.py | 1 + 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 272fba8f..00b74f73 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -135,15 +135,14 @@ def pairwise_offline_forward( 'token_type_ids': torch.cat([chosen_token_type_ids, rejected_token_type_ids], dim=0), - 'pixel_values': batch['pixel_values'] * 2, # double the list + 'pixel_values': + torch.cat([batch['pixel_values'], batch['pixel_values']], + dim=0), } print("MULTIMODAL INPUTS") for k, v in multimodal_inputs.items(): - if isinstance(v, torch.Tensor): - print(f"{k}: {v.shape}") - else: - print(f"{k}: {len(v)}") + print(f"{k}: {v.shape}") inputs.update(multimodal_inputs) diff --git a/compose_rl/data/preference_data.py b/compose_rl/data/preference_data.py index 08dca8bb..90d014b9 100644 --- a/compose_rl/data/preference_data.py +++ b/compose_rl/data/preference_data.py @@ -202,6 +202,7 @@ def pairwise_preference_dataset_collate_fn( if is_multimodal: token_type_ids = torch.stack(token_type_ids) + pixel_values = torch.stack(pixel_values) return_dict['token_type_ids'] = token_type_ids return_dict['pixel_values'] = pixel_values From dbeec36aaa0f298142ed247b5a40e1587e92fc40 Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Tue, 8 Jul 2025 15:26:38 -0400 Subject: [PATCH 18/20] nd array for pixel_values --- compose_rl/data/preference_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compose_rl/data/preference_data.py b/compose_rl/data/preference_data.py index 90d014b9..f1cab307 100644 --- a/compose_rl/data/preference_data.py +++ b/compose_rl/data/preference_data.py @@ -348,7 +348,7 @@ def __getitem__(self, idx: int) -> dict[str, Any]: if 'pixel_values' in sample: if isinstance(sample['pixel_values'], np.ndarray): - pixel_values = torch.Tensor(sample['pixel_values']) + pixel_values = torch.from_numpy(sample['pixel_values']) elif isinstance(sample['pixel_values'], Image.Image): pil_to_tensor_transform = transforms.PILToTensor() pixel_values = pil_to_tensor_transform(sample['pixel_values']) From e521e728a0e3a55da711257c1b679a05cbe3457e Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Tue, 8 Jul 2025 15:41:18 -0400 Subject: [PATCH 19/20] fix --- .../algorithms/offline/model_methods.py | 4 +- compose_rl/data/preference_data.py | 56 +++++++++++-------- 2 files changed, 36 insertions(+), 24 deletions(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index 00b74f73..abcc7fe9 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -140,9 +140,9 @@ def pairwise_offline_forward( dim=0), } - print("MULTIMODAL INPUTS") + print('MULTIMODAL INPUTS') for k, v in multimodal_inputs.items(): - print(f"{k}: {v.shape}") + print(f'{k}: {v.shape}') inputs.update(multimodal_inputs) diff --git a/compose_rl/data/preference_data.py b/compose_rl/data/preference_data.py index f1cab307..71fc2bd9 100644 --- a/compose_rl/data/preference_data.py +++ b/compose_rl/data/preference_data.py @@ -8,11 +8,10 @@ import numpy as np import torch -from streaming import StreamingDataset -from transformers import DataCollatorForLanguageModeling, PreTrainedTokenizer - from PIL import Image +from streaming import StreamingDataset from torchvision import transforms +from transformers import DataCollatorForLanguageModeling, PreTrainedTokenizer log = logging.getLogger(__name__) @@ -118,9 +117,11 @@ def pairwise_preference_dataset_collate_fn( rejected[-1] = tokenizer.eos_token_id # type: ignore if is_multimodal: - chosen_token_type_ids = chosen_token_type_ids[:-truncate_len] - rejected_token_type_ids = rejected_token_type_ids[:-truncate_len - ] + chosen_token_type_ids = chosen_token_type_ids[: + -truncate_len # type: ignore + ] + rejected_token_type_ids = rejected_token_type_ids[: # type: ignore + -truncate_len] # NOTE: GEMMA specific: 0 == text token chosen_token_type_ids[-1] = 0 @@ -148,14 +149,16 @@ def pairwise_preference_dataset_collate_fn( dim=-1, # type: ignore ) if is_multimodal: - cat_token_type_ids = torch.cat([ - cat_token_type_ids, - torch.zeros( - int(pad_len.item()), - dtype=cat_token_type_ids.dtype, - ), - ], - dim=-1) + cat_token_type_ids = torch.cat( + [ + cat_token_type_ids, # type: ignore + torch.zeros( + int(pad_len.item()), + dtype=cat_token_type_ids.dtype, # type: ignore + ), + ], + dim=-1, + ) attention_mask = torch.logical_not( torch.eq(cat_batch, tokenizer.pad_token_id), # type: ignore @@ -176,7 +179,7 @@ def pairwise_preference_dataset_collate_fn( rejected_rewards.append(sample['rejected_reward']) if is_multimodal: - token_type_ids.append(cat_token_type_ids) + token_type_ids.append(cat_token_type_ids) # type: ignore pixel_values.append(pixel_vals) input_ids = ref_collate_fn(input_ids)['input_ids'] @@ -200,7 +203,7 @@ def pairwise_preference_dataset_collate_fn( return_dict['chosen_reward'] = chosen_rewards return_dict['rejected_reward'] = rejected_rewards - if is_multimodal: + if is_multimodal: # type: ignore token_type_ids = torch.stack(token_type_ids) pixel_values = torch.stack(pixel_values) return_dict['token_type_ids'] = token_type_ids @@ -293,7 +296,7 @@ def __getitem__(self, idx: int) -> dict[str, Any]: idx (int): the index where we fetch the data in the StreamingDataset. """ sample = super().__getitem__(idx) - + # Handle prompt if available if isinstance(sample['chosen'], bytes): # Prepend the prompt to the chosen and rejected responses @@ -313,8 +316,14 @@ def __getitem__(self, idx: int) -> dict[str, Any]: elif isinstance(sample['chosen'], np.ndarray): if 'prompt' in sample: - sample['chosen'] = np.concatenate([sample['prompt'], sample['chosen']]) - sample['rejected'] = np.concatenate([sample['prompt'], sample['rejected']]) + sample['chosen'] = np.concatenate([ + sample['prompt'], + sample['chosen'], + ]) + sample['rejected'] = np.concatenate([ + sample['prompt'], + sample['rejected'], + ]) chosen = torch.from_numpy(sample['chosen'][:self.max_seq_len]) rejected = torch.from_numpy(sample['rejected'][:self.max_seq_len]) @@ -368,15 +377,18 @@ def __getitem__(self, idx: int) -> dict[str, Any]: 'rejected_token_type_ids', ) elif isinstance(sample['chosen_token_type_ids'], np.ndarray): - chosen_token_type_ids = torch.from_numpy(sample['chosen_token_type_ids'][:self.max_seq_len]) - rejected_token_type_ids = torch.from_numpy(sample['rejected_token_type_ids'][:self.max_seq_len]) + chosen_token_type_ids = torch.from_numpy( + sample['chosen_token_type_ids'][:self.max_seq_len], + ) + rejected_token_type_ids = torch.from_numpy( + sample['rejected_token_type_ids'][:self.max_seq_len], + ) else: token_type = type(sample['chosen_token_type_ids']) raise ValueError( f'Expect token_type_ids to be numpy.ndarray or bytes, but got {token_type}', ) - return_dict['pixel_values'] = pixel_values return_dict['chosen_token_type_ids'] = chosen_token_type_ids return_dict['rejected_token_type_ids'] = rejected_token_type_ids From 83d1e3bef5b42d2bb909ffd5b452705ad0a3fefd Mon Sep 17 00:00:00 2001 From: Jonathan Chang Date: Tue, 8 Jul 2025 15:52:05 -0400 Subject: [PATCH 20/20] fix --- compose_rl/algorithms/offline/model_methods.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py index abcc7fe9..95bc1567 100644 --- a/compose_rl/algorithms/offline/model_methods.py +++ b/compose_rl/algorithms/offline/model_methods.py @@ -140,10 +140,6 @@ def pairwise_offline_forward( dim=0), } - print('MULTIMODAL INPUTS') - for k, v in multimodal_inputs.items(): - print(f'{k}: {v.shape}') - inputs.update(multimodal_inputs) output_logits = model(