diff --git a/compose_rl/algorithms/offline/model_methods.py b/compose_rl/algorithms/offline/model_methods.py
index 4da46019..6f6ace20 100644
--- a/compose_rl/algorithms/offline/model_methods.py
+++ b/compose_rl/algorithms/offline/model_methods.py
@@ -67,6 +67,12 @@ def pairwise_offline_forward(
     if pad_token_id is None:
         raise ValueError('Tokenizer must have a PAD token.')

+    is_multimodal = 'pixel_values' in batch.keys()
+    if is_multimodal and use_attention_sequence_id:
+        raise NotImplementedError(
+            'Using Sequence ID is not implemented for VLMs',
+        )
+
     # If we can use attention sequence ID, we use this logic branch.
     # This is determined by a value set in `train_dpo.py`
     if use_attention_sequence_id:
@@ -104,18 +110,42 @@ def pairwise_offline_forward(
             pad_token_id=0,
         )

-        batch_cat_inputs = torch.cat([chosen_inputs, rejected_inputs], dim=0)
-        batch_attn_mask = torch.cat(
-            [
-                chosen_attention_mask,
-                rejected_attention_mask,
-            ],
-            dim=0,
-        )
+        inputs = {
+            'input_ids':
+                torch.cat([chosen_inputs, rejected_inputs], dim=0),
+            'attention_mask':
+                torch.cat(
+                    [
+                        chosen_attention_mask,
+                        rejected_attention_mask,
+                    ],
+                    dim=0,
+                ),
+        }
+
+        if is_multimodal:
+            chosen_token_type_ids, rejected_token_type_ids = extract_packed_chosen_rejected(
+                batch['token_type_ids'],
+                batch['chosen_len'],
+                batch['rejected_len'],
+                concat_seq_len,
+                pad_token_id=0,
+            )
+
+            # TODO: Confirm that assuming the same pixel inputs for chosen and rejected is OK.
+            multimodal_inputs = {
+                'token_type_ids':
+                    torch.cat([chosen_token_type_ids, rejected_token_type_ids],
+                              dim=0),
+                'pixel_values':
+                    torch.cat([batch['pixel_values'], batch['pixel_values']],
+                              dim=0),
+            }
+
+            inputs.update(multimodal_inputs)

         output_logits = model(
-            batch_cat_inputs,
-            attention_mask=batch_attn_mask,
+            **inputs,
         ).logits

     # Extract out the chosen and rejected logits along the batch dimension
diff --git a/compose_rl/algorithms/online/model_methods.py b/compose_rl/algorithms/online/model_methods.py
index acd00f4e..e86463b8 100644
--- a/compose_rl/algorithms/online/model_methods.py
+++ b/compose_rl/algorithms/online/model_methods.py
@@ -124,6 +124,10 @@ def composer_online_rl_forward(
         model_forward_kwargs['action_mask'] = batch['action_mask']
         model_forward_kwargs['max_gen_len'] = batch['max_gen_len']

+    if 'pixel_values' in batch.keys():
+        model_forward_kwargs['token_type_ids'] = batch['token_type_ids']
+        model_forward_kwargs['pixel_values'] = batch['pixel_values']
+
     actor_output = model(batch['obs'], **model_forward_kwargs)

     logits = actor_output.logits
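For context, a minimal self-contained sketch (plain torch, not part of this patch) of the input layout the pairwise forward now builds when `pixel_values` is present: chosen and rejected sequences are concatenated along the batch dimension, `token_type_ids` are concatenated the same way, and the same `pixel_values` are duplicated for both halves of each pair. All shapes below are arbitrary toy values.

```python
import torch

batch_size, seq_len = 2, 8

chosen_inputs = torch.randint(0, 100, (batch_size, seq_len))
rejected_inputs = torch.randint(0, 100, (batch_size, seq_len))
chosen_token_type_ids = torch.zeros(batch_size, seq_len, dtype=torch.long)
rejected_token_type_ids = torch.zeros(batch_size, seq_len, dtype=torch.long)
pixel_values = torch.zeros(batch_size, 3, 14, 14)  # one image per preference pair

inputs = {
    # Chosen and rejected sequences are stacked along the batch dimension,
    # so the model runs 2 * batch_size sequences in a single forward pass.
    'input_ids': torch.cat([chosen_inputs, rejected_inputs], dim=0),
    'token_type_ids': torch.cat(
        [chosen_token_type_ids, rejected_token_type_ids],
        dim=0,
    ),
    # The same image is assumed for both halves of each pair, so pixel_values
    # are simply duplicated to match the doubled batch.
    'pixel_values': torch.cat([pixel_values, pixel_values], dim=0),
}

for name, tensor in inputs.items():
    print(name, tuple(tensor.shape))
```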
diff --git a/compose_rl/data/preference_data.py b/compose_rl/data/preference_data.py
index c08d2757..71fc2bd9 100644
--- a/compose_rl/data/preference_data.py
+++ b/compose_rl/data/preference_data.py
@@ -8,7 +8,9 @@

 import numpy as np
 import torch
+from PIL import Image
 from streaming import StreamingDataset
+from torchvision import transforms
 from transformers import DataCollatorForLanguageModeling, PreTrainedTokenizer

 log = logging.getLogger(__name__)
@@ -56,6 +58,10 @@ def pairwise_preference_dataset_collate_fn(
     chosen_rewards = []
     rejected_rewards = []

+    # For VLMs
+    token_type_ids = []
+    pixel_values = []
+
     for sample in data:
         chosen = sample['chosen']
         rejected = sample['rejected']
@@ -63,6 +69,17 @@ def pairwise_preference_dataset_collate_fn(
         chosen_len = sample['chosen_len']
         rejected_len = sample['rejected_len']

+        is_multimodal = 'pixel_values' in sample.keys()
+        if is_multimodal:
+            pixel_vals = sample['pixel_values']
+            chosen_token_type_ids = sample['chosen_token_type_ids']
+            rejected_token_type_ids = sample['rejected_token_type_ids']
+        else:
+            pixel_vals = None
+            chosen_token_type_ids = None
+            rejected_token_type_ids = None
+            cat_token_type_ids = None
+
         # Note: if we do any truncation, we force the last token to be EOS
         # https://github.com/mosaicml/RLHF/issues/101

@@ -75,6 +92,13 @@ def pairwise_preference_dataset_collate_fn(
         pad_len = max_seq_len - chosen_len - rejected_len
         cat_batch = torch.cat([chosen, rejected], dim=-1)

+        if is_multimodal:
+            cat_token_type_ids = torch.cat([
+                chosen_token_type_ids,
+                rejected_token_type_ids,
+            ],
+                                           dim=-1)
+
         if pad_len < 0:
             # We should truncate chosen and rejected by the same amount
             truncate_len = abs(pad_len // 2) + 1
@@ -92,6 +116,22 @@ def pairwise_preference_dataset_collate_fn(
             rejected = rejected[:-truncate_len]
             rejected[-1] = tokenizer.eos_token_id  # type: ignore

+            if is_multimodal:
+                chosen_token_type_ids = chosen_token_type_ids[:
+                                                              -truncate_len  # type: ignore
+                                                             ]
+                rejected_token_type_ids = rejected_token_type_ids[:  # type: ignore
+                                                                  -truncate_len]
+
+                # NOTE: GEMMA specific: 0 == text token
+                chosen_token_type_ids[-1] = 0
+                rejected_token_type_ids[-1] = 0
+                cat_token_type_ids = torch.cat([
+                    chosen_token_type_ids,
+                    rejected_token_type_ids,
+                ],
+                                               dim=-1)
+
             cat_batch = torch.cat([chosen, rejected], dim=-1)

             chosen_len = torch.tensor([len(chosen)])
@@ -108,6 +148,17 @@ def pairwise_preference_dataset_collate_fn(
                 ],
                 dim=-1,  # type: ignore
             )
+            if is_multimodal:
+                cat_token_type_ids = torch.cat(
+                    [
+                        cat_token_type_ids,  # type: ignore
+                        torch.zeros(
+                            int(pad_len.item()),
+                            dtype=cat_token_type_ids.dtype,  # type: ignore
+                        ),
+                    ],
+                    dim=-1,
+                )

         attention_mask = torch.logical_not(
             torch.eq(cat_batch, tokenizer.pad_token_id),  # type: ignore
@@ -127,6 +178,10 @@ def pairwise_preference_dataset_collate_fn(
             chosen_rewards.append(sample['chosen_reward'])
             rejected_rewards.append(sample['rejected_reward'])

+        if is_multimodal:
+            token_type_ids.append(cat_token_type_ids)  # type: ignore
+            pixel_values.append(pixel_vals)
+
     input_ids = ref_collate_fn(input_ids)['input_ids']
     attention_masks = torch.stack(attention_masks)
     sequence_id = torch.stack(sequence_id)
@@ -147,6 +202,13 @@ def pairwise_preference_dataset_collate_fn(
         rejected_rewards = torch.stack(rejected_rewards)
         return_dict['chosen_reward'] = chosen_rewards
         return_dict['rejected_reward'] = rejected_rewards
+
+    if is_multimodal:  # type: ignore
+        token_type_ids = torch.stack(token_type_ids)
+        pixel_values = torch.stack(pixel_values)
+        return_dict['token_type_ids'] = token_type_ids
+        return_dict['pixel_values'] = pixel_values
+
     return return_dict
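For reference, a minimal sketch of the collation path for a single multimodal sample, in plain torch (illustrative only, not the library's collator). Field names follow the collator above, and the token-type values use the Gemma convention noted in the diff (0 == text token): chosen and rejected token_type_ids are packed into one sequence, padding positions get type 0, and the per-sample tensors are stacked into the batch.

```python
import torch

max_seq_len = 12
pad_token_id = 0

chosen = torch.tensor([7, 8, 9, 2])
rejected = torch.tensor([7, 8, 10, 11, 2])
chosen_token_type_ids = torch.tensor([1, 0, 0, 0])
rejected_token_type_ids = torch.tensor([1, 0, 0, 0, 0])
pixel_vals = torch.zeros(3, 14, 14)

# Chosen and rejected are packed into a single sequence, then padded.
cat_batch = torch.cat([chosen, rejected], dim=-1)
cat_token_type_ids = torch.cat(
    [chosen_token_type_ids, rejected_token_type_ids], dim=-1)

pad_len = max_seq_len - cat_batch.numel()
cat_batch = torch.cat(
    [cat_batch, torch.full((pad_len,), pad_token_id, dtype=cat_batch.dtype)])
# Padding positions are treated as text tokens (type 0), as in the patch.
cat_token_type_ids = torch.cat(
    [cat_token_type_ids, torch.zeros(pad_len, dtype=cat_token_type_ids.dtype)])

# The per-sample tensors are then stacked into batch-level tensors.
batch = {
    'input_ids': torch.stack([cat_batch]),
    'token_type_ids': torch.stack([cat_token_type_ids]),
    'pixel_values': torch.stack([pixel_vals]),
}
print({k: tuple(v.shape) for k, v in batch.items()})
```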
""" sample = super().__getitem__(idx) + # Handle prompt if available - if 'prompt' in sample: + if isinstance(sample['chosen'], bytes): # Prepend the prompt to the chosen and rejected responses - sample['chosen'] = sample['prompt'] + sample['chosen'] - sample['rejected'] = sample['prompt'] + sample['rejected'] - chosen = self._read_binary_tokenized_sample(sample, 'chosen') - rejected = self._read_binary_tokenized_sample(sample, 'rejected') - - if 'prompt' in sample: - prompt = self._read_binary_tokenized_sample(sample, 'prompt') - prompt_len = len(prompt) + if 'prompt' in sample: + sample['chosen'] = sample['prompt'] + sample['chosen'] + sample['rejected'] = sample['prompt'] + sample['rejected'] + chosen = self._read_binary_tokenized_sample(sample, 'chosen') + rejected = self._read_binary_tokenized_sample(sample, 'rejected') + + if 'prompt' in sample: + prompt = self._read_binary_tokenized_sample(sample, 'prompt') + prompt_len = len(prompt) + else: + # Only use prefix matching version of prompt_len when + # 'prompt' is not directly given in the sample + prompt_len = self.find_prompt_length(chosen, rejected) + + elif isinstance(sample['chosen'], np.ndarray): + if 'prompt' in sample: + sample['chosen'] = np.concatenate([ + sample['prompt'], + sample['chosen'], + ]) + sample['rejected'] = np.concatenate([ + sample['prompt'], + sample['rejected'], + ]) + + chosen = torch.from_numpy(sample['chosen'][:self.max_seq_len]) + rejected = torch.from_numpy(sample['rejected'][:self.max_seq_len]) + + if 'prompt' in sample: + prompt_len = len(sample['prompt']) + else: + # Only use prefix matching version of prompt_len when + # 'prompt' is not directly given in the sample + prompt_len = self.find_prompt_length(chosen, rejected) else: - # Only use prefix matching version of prompt_len when - # 'prompt' is not directly given in the sample - prompt_len = self.find_prompt_length(chosen, rejected) + token_type = type(sample['chosen']) + raise ValueError( + f'Expect prompt and response to be bytes or numpy.ndarray type, but got {token_type}', + ) + chosen_len, rejected_len = len(chosen), len(rejected) return_dict = { 'chosen': chosen, @@ -263,6 +354,45 @@ def __getitem__(self, idx: int) -> dict[str, Any]: rejected_reward = torch.Tensor([sample['rejected_reward']]) return_dict['chosen_reward'] = chosen_reward return_dict['rejected_reward'] = rejected_reward + + if 'pixel_values' in sample: + if isinstance(sample['pixel_values'], np.ndarray): + pixel_values = torch.from_numpy(sample['pixel_values']) + elif isinstance(sample['pixel_values'], Image.Image): + pil_to_tensor_transform = transforms.PILToTensor() + pixel_values = pil_to_tensor_transform(sample['pixel_values']) + else: + pixel_values_type = type(sample['pixel_values']) + raise ValueError( + f'Expect pixel values to be numpy.ndarray or PIL.Image type, but got {pixel_values_type}', + ) + + if isinstance(sample['chosen_token_type_ids'], bytes): + chosen_token_type_ids = self._read_binary_tokenized_sample( + sample, + 'chosen_token_type_ids', + ) + rejected_token_type_ids = self._read_binary_tokenized_sample( + sample, + 'rejected_token_type_ids', + ) + elif isinstance(sample['chosen_token_type_ids'], np.ndarray): + chosen_token_type_ids = torch.from_numpy( + sample['chosen_token_type_ids'][:self.max_seq_len], + ) + rejected_token_type_ids = torch.from_numpy( + sample['rejected_token_type_ids'][:self.max_seq_len], + ) + else: + token_type = type(sample['chosen_token_type_ids']) + raise ValueError( + f'Expect token_type_ids to be numpy.ndarray or bytes, but 
@@ -263,6 +354,45 @@ def __getitem__(self, idx: int) -> dict[str, Any]:
             rejected_reward = torch.Tensor([sample['rejected_reward']])
             return_dict['chosen_reward'] = chosen_reward
             return_dict['rejected_reward'] = rejected_reward
+
+        if 'pixel_values' in sample:
+            if isinstance(sample['pixel_values'], np.ndarray):
+                pixel_values = torch.from_numpy(sample['pixel_values'])
+            elif isinstance(sample['pixel_values'], Image.Image):
+                pil_to_tensor_transform = transforms.PILToTensor()
+                pixel_values = pil_to_tensor_transform(sample['pixel_values'])
+            else:
+                pixel_values_type = type(sample['pixel_values'])
+                raise ValueError(
+                    f'Expected pixel values to be numpy.ndarray or PIL.Image type, but got {pixel_values_type}',
+                )
+
+            if isinstance(sample['chosen_token_type_ids'], bytes):
+                chosen_token_type_ids = self._read_binary_tokenized_sample(
+                    sample,
+                    'chosen_token_type_ids',
+                )
+                rejected_token_type_ids = self._read_binary_tokenized_sample(
+                    sample,
+                    'rejected_token_type_ids',
+                )
+            elif isinstance(sample['chosen_token_type_ids'], np.ndarray):
+                chosen_token_type_ids = torch.from_numpy(
+                    sample['chosen_token_type_ids'][:self.max_seq_len],
+                )
+                rejected_token_type_ids = torch.from_numpy(
+                    sample['rejected_token_type_ids'][:self.max_seq_len],
+                )
+            else:
+                token_type = type(sample['chosen_token_type_ids'])
+                raise ValueError(
+                    f'Expected token_type_ids to be numpy.ndarray or bytes, but got {token_type}',
+                )
+
+            return_dict['pixel_values'] = pixel_values
+            return_dict['chosen_token_type_ids'] = chosen_token_type_ids
+            return_dict['rejected_token_type_ids'] = rejected_token_type_ids
+
         return return_dict

     def find_prompt_length(self, seq_1: torch.Tensor, seq_2: torch.Tensor):
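Finally, a short sketch of the pixel-value conversion added above (a hypothetical standalone helper, not part of compose_rl): numpy arrays are wrapped with `torch.from_numpy`, while PIL images go through `torchvision.transforms.PILToTensor`, so `pixel_values` always ends up as a `torch.Tensor`.

```python
import numpy as np
import torch
from PIL import Image
from torchvision import transforms


def to_pixel_values(value):
    """Convert a stored image into a tensor, mirroring the dataset branch above."""
    if isinstance(value, np.ndarray):
        return torch.from_numpy(value)
    if isinstance(value, Image.Image):
        # PILToTensor yields a uint8 tensor of shape (C, H, W).
        return transforms.PILToTensor()(value)
    raise ValueError(
        f'Expected pixel values to be numpy.ndarray or PIL.Image, but got {type(value)}',
    )


print(to_pixel_values(np.zeros((3, 14, 14), dtype=np.float32)).shape)
print(to_pixel_values(Image.new('RGB', (14, 14))).shape)
```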