50 changes: 40 additions & 10 deletions compose_rl/algorithms/offline/model_methods.py
@@ -67,6 +67,12 @@ def pairwise_offline_forward(
if pad_token_id is None:
raise ValueError('Tokenizer must have a PAD token.')

is_multimodal = 'pixel_values' in batch.keys()
if is_multimodal and use_attention_sequence_id:
raise NotImplementedError(
'Using Sequence ID is not implemented for VLMs',
)

# If we can use attention sequence ID, we use this logic branch.
# This is determined by a value set in `train_dpo.py`
if use_attention_sequence_id:
@@ -104,18 +110,42 @@ def pairwise_offline_forward(
pad_token_id=0,
)

batch_cat_inputs = torch.cat([chosen_inputs, rejected_inputs], dim=0)
batch_attn_mask = torch.cat(
[
chosen_attention_mask,
rejected_attention_mask,
],
dim=0,
)
inputs = {
'input_ids':
torch.cat([chosen_inputs, rejected_inputs], dim=0),
'attention_mask':
torch.cat(
[
chosen_attention_mask,
rejected_attention_mask,
],
dim=0,
),
}

if is_multimodal:
chosen_token_type_ids, rejected_token_type_ids = extract_packed_chosen_rejected(
batch['token_type_ids'],
batch['chosen_len'],
batch['rejected_len'],
concat_seq_len,
pad_token_id=0,
)

# TODO: Confirm whether assuming identical pixel inputs for chosen and rejected is acceptable.
multimodal_inputs = {
'token_type_ids':
torch.cat([chosen_token_type_ids, rejected_token_type_ids],
dim=0),
'pixel_values':
torch.cat([batch['pixel_values'], batch['pixel_values']],
dim=0),
}

inputs.update(multimodal_inputs)

output_logits = model(
batch_cat_inputs,
attention_mask=batch_attn_mask,
**inputs,
).logits

# Extract out the chosen and rejected logits along the batch dimension
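The hunk above reuses the same `pixel_values` for the chosen and rejected halves (see the TODO), so the image tensor is concatenated with itself to match the doubled batch dimension of `input_ids`. A minimal sketch of that shape bookkeeping; the tensor sizes are illustrative assumptions, not values from this PR:

```python
import torch

batch_size, channels, height, width = 2, 3, 224, 224
seq_len = 8

chosen_inputs = torch.randint(0, 100, (batch_size, seq_len))
rejected_inputs = torch.randint(0, 100, (batch_size, seq_len))
pixel_values = torch.randn(batch_size, channels, height, width)

# Chosen and rejected are stacked along the batch dimension...
input_ids = torch.cat([chosen_inputs, rejected_inputs], dim=0)       # (2B, T)
# ...so the shared image tensor is repeated to keep batch sizes aligned.
pixel_values_cat = torch.cat([pixel_values, pixel_values], dim=0)    # (2B, C, H, W)

assert input_ids.shape[0] == pixel_values_cat.shape[0] == 2 * batch_size
```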
4 changes: 4 additions & 0 deletions compose_rl/algorithms/online/model_methods.py
@@ -124,6 +124,10 @@ def composer_online_rl_forward(
model_forward_kwargs['action_mask'] = batch['action_mask']
model_forward_kwargs['max_gen_len'] = batch['max_gen_len']

if 'pixel_values' in batch.keys():
model_forward_kwargs['token_type_ids'] = batch['token_type_ids']
model_forward_kwargs['pixel_values'] = batch['pixel_values']

actor_output = model(batch['obs'], **model_forward_kwargs)

logits = actor_output.logits
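A toy sketch of the pattern above, where vision inputs are forwarded only when the batch carries them so text-only batches keep the original call; `model` and the batch keys here are stand-ins rather than the actual ComposeRL objects:

```python
import torch

def forward_with_optional_vision(model, batch: dict) -> torch.Tensor:
    kwargs = {}
    # Only thread the multimodal tensors through when they exist,
    # leaving the text-only call signature untouched otherwise.
    if 'pixel_values' in batch:
        kwargs['token_type_ids'] = batch['token_type_ids']
        kwargs['pixel_values'] = batch['pixel_values']
    return model(batch['obs'], **kwargs)
```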
154 changes: 142 additions & 12 deletions compose_rl/data/preference_data.py
@@ -8,7 +8,9 @@

import numpy as np
import torch
from PIL import Image
from streaming import StreamingDataset
from torchvision import transforms
from transformers import DataCollatorForLanguageModeling, PreTrainedTokenizer

log = logging.getLogger(__name__)
@@ -56,13 +58,28 @@ def pairwise_preference_dataset_collate_fn(
chosen_rewards = []
rejected_rewards = []

# For VLMs
token_type_ids = []
pixel_values = []

for sample in data:
chosen = sample['chosen']
rejected = sample['rejected']
prompt_len = sample['prompt_len']
chosen_len = sample['chosen_len']
rejected_len = sample['rejected_len']

is_multimodal = 'pixel_values' in sample.keys()
if is_multimodal:
pixel_vals = sample['pixel_values']
chosen_token_type_ids = sample['chosen_token_type_ids']
rejected_token_type_ids = sample['rejected_token_type_ids']
else:
pixel_vals = None
chosen_token_type_ids = None
rejected_token_type_ids = None
cat_token_type_ids = None

# Note: if we do any truncation, we force the last token to be EOS
# https://github.com/mosaicml/RLHF/issues/101

@@ -75,6 +92,13 @@
pad_len = max_seq_len - chosen_len - rejected_len
cat_batch = torch.cat([chosen, rejected], dim=-1)

if is_multimodal:
cat_token_type_ids = torch.cat([
chosen_token_type_ids,
rejected_token_type_ids,
],
dim=-1)

if pad_len < 0:
# We should truncate chosen and rejected by the same amount
truncate_len = abs(pad_len // 2) + 1
Expand All @@ -92,6 +116,22 @@ def pairwise_preference_dataset_collate_fn(
rejected = rejected[:-truncate_len]
rejected[-1] = tokenizer.eos_token_id # type: ignore

if is_multimodal:
chosen_token_type_ids = chosen_token_type_ids[:-truncate_len]  # type: ignore
rejected_token_type_ids = rejected_token_type_ids[:-truncate_len]  # type: ignore

# NOTE: GEMMA specific: 0 == text token
chosen_token_type_ids[-1] = 0
rejected_token_type_ids[-1] = 0
cat_token_type_ids = torch.cat([
chosen_token_type_ids,
rejected_token_type_ids,
],
dim=-1)

cat_batch = torch.cat([chosen, rejected], dim=-1)

chosen_len = torch.tensor([len(chosen)])
@@ -108,6 +148,17 @@
],
dim=-1, # type: ignore
)
if is_multimodal:
cat_token_type_ids = torch.cat(
[
cat_token_type_ids, # type: ignore
torch.zeros(
int(pad_len.item()),
dtype=cat_token_type_ids.dtype, # type: ignore
),
],
dim=-1,
)

attention_mask = torch.logical_not(
torch.eq(cat_batch, tokenizer.pad_token_id), # type: ignore
@@ -127,6 +178,10 @@
chosen_rewards.append(sample['chosen_reward'])
rejected_rewards.append(sample['rejected_reward'])

if is_multimodal:
token_type_ids.append(cat_token_type_ids) # type: ignore
pixel_values.append(pixel_vals)

input_ids = ref_collate_fn(input_ids)['input_ids']
attention_masks = torch.stack(attention_masks)
sequence_id = torch.stack(sequence_id)
@@ -147,6 +202,13 @@
rejected_rewards = torch.stack(rejected_rewards)
return_dict['chosen_reward'] = chosen_rewards
return_dict['rejected_reward'] = rejected_rewards

if is_multimodal: # type: ignore
token_type_ids = torch.stack(token_type_ids)
pixel_values = torch.stack(pixel_values)
return_dict['token_type_ids'] = token_type_ids
return_dict['pixel_values'] = pixel_values

return return_dict
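
In the collate path above, the concatenated `token_type_ids` are padded with 0 (the text-token id in the Gemma convention noted in the code) so they stay aligned with the padded `input_ids`. A minimal sketch of that padding step, with made-up lengths and values:

```python
import torch

max_seq_len = 16
# 0 marks text tokens, 1 marks image tokens (Gemma-style convention).
chosen_token_type_ids = torch.tensor([0, 0, 1, 1, 0, 0])
rejected_token_type_ids = torch.tensor([0, 0, 1, 1, 0])

cat_token_type_ids = torch.cat(
    [chosen_token_type_ids, rejected_token_type_ids],
    dim=-1,
)
pad_len = max_seq_len - cat_token_type_ids.shape[-1]

# Pad with 0 (text) so the ids line up with the pad tokens appended to input_ids.
cat_token_type_ids = torch.cat(
    [cat_token_type_ids, torch.zeros(pad_len, dtype=cat_token_type_ids.dtype)],
    dim=-1,
)

assert cat_token_type_ids.shape[-1] == max_seq_len
```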


@@ -234,21 +296,50 @@ def __getitem__(self, idx: int) -> dict[str, Any]:
idx (int): the index where we fetch the data in the StreamingDataset.
"""
sample = super().__getitem__(idx)

# Handle prompt if available
if 'prompt' in sample:
if isinstance(sample['chosen'], bytes):
# Prepend the prompt to the chosen and rejected responses
sample['chosen'] = sample['prompt'] + sample['chosen']
sample['rejected'] = sample['prompt'] + sample['rejected']
chosen = self._read_binary_tokenized_sample(sample, 'chosen')
rejected = self._read_binary_tokenized_sample(sample, 'rejected')

if 'prompt' in sample:
prompt = self._read_binary_tokenized_sample(sample, 'prompt')
prompt_len = len(prompt)
if 'prompt' in sample:
sample['chosen'] = sample['prompt'] + sample['chosen']
sample['rejected'] = sample['prompt'] + sample['rejected']
chosen = self._read_binary_tokenized_sample(sample, 'chosen')
rejected = self._read_binary_tokenized_sample(sample, 'rejected')

if 'prompt' in sample:
prompt = self._read_binary_tokenized_sample(sample, 'prompt')
prompt_len = len(prompt)
else:
# Only use prefix matching version of prompt_len when
# 'prompt' is not directly given in the sample
prompt_len = self.find_prompt_length(chosen, rejected)

elif isinstance(sample['chosen'], np.ndarray):
if 'prompt' in sample:
sample['chosen'] = np.concatenate([
sample['prompt'],
sample['chosen'],
])
sample['rejected'] = np.concatenate([
sample['prompt'],
sample['rejected'],
])

chosen = torch.from_numpy(sample['chosen'][:self.max_seq_len])
rejected = torch.from_numpy(sample['rejected'][:self.max_seq_len])

if 'prompt' in sample:
prompt_len = len(sample['prompt'])
else:
# Only use prefix matching version of prompt_len when
# 'prompt' is not directly given in the sample
prompt_len = self.find_prompt_length(chosen, rejected)
else:
# Only use prefix matching version of prompt_len when
# 'prompt' is not directly given in the sample
prompt_len = self.find_prompt_length(chosen, rejected)
token_type = type(sample['chosen'])
raise ValueError(
f'Expect prompt and response to be bytes or numpy.ndarray type, but got {token_type}',
)

chosen_len, rejected_len = len(chosen), len(rejected)
return_dict = {
'chosen': chosen,
@@ -263,6 +354,45 @@ def __getitem__(self, idx: int) -> dict[str, Any]:
rejected_reward = torch.Tensor([sample['rejected_reward']])
return_dict['chosen_reward'] = chosen_reward
return_dict['rejected_reward'] = rejected_reward

if 'pixel_values' in sample:
if isinstance(sample['pixel_values'], np.ndarray):
pixel_values = torch.from_numpy(sample['pixel_values'])
elif isinstance(sample['pixel_values'], Image.Image):
pil_to_tensor_transform = transforms.PILToTensor()
pixel_values = pil_to_tensor_transform(sample['pixel_values'])
else:
pixel_values_type = type(sample['pixel_values'])
raise ValueError(
f'Expect pixel values to be numpy.ndarray or PIL.Image type, but got {pixel_values_type}',
)

if isinstance(sample['chosen_token_type_ids'], bytes):
chosen_token_type_ids = self._read_binary_tokenized_sample(
sample,
'chosen_token_type_ids',
)
rejected_token_type_ids = self._read_binary_tokenized_sample(
sample,
'rejected_token_type_ids',
)
elif isinstance(sample['chosen_token_type_ids'], np.ndarray):
chosen_token_type_ids = torch.from_numpy(
sample['chosen_token_type_ids'][:self.max_seq_len],
)
rejected_token_type_ids = torch.from_numpy(
sample['rejected_token_type_ids'][:self.max_seq_len],
)
else:
token_type = type(sample['chosen_token_type_ids'])
raise ValueError(
f'Expect token_type_ids to be numpy.ndarray or bytes, but got {token_type}',
)

return_dict['pixel_values'] = pixel_values
return_dict['chosen_token_type_ids'] = chosen_token_type_ids
return_dict['rejected_token_type_ids'] = rejected_token_type_ids

return return_dict

def find_prompt_length(self, seq_1: torch.Tensor, seq_2: torch.Tensor):
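For reference, a hypothetical per-sample layout that the multimodal branch of `__getitem__` above would accept. The keys mirror the code, but the dtypes, shapes, and lengths are placeholders rather than a spec:

```python
import numpy as np

sample = {
    # Tokenized text; bytes-encoded token buffers are also accepted.
    'chosen': np.array([5, 6, 7, 8], dtype=np.int64),
    'rejected': np.array([5, 6, 9], dtype=np.int64),
    'prompt': np.array([1, 2, 3], dtype=np.int64),  # optional
    # Image input; a PIL.Image is also accepted and converted via PILToTensor().
    'pixel_values': np.zeros((3, 224, 224), dtype=np.float32),
    # Per-token type ids for each response (0 == text in the Gemma convention).
    'chosen_token_type_ids': np.array([0, 0, 1, 1], dtype=np.int64),
    'rejected_token_type_ids': np.array([0, 0, 1], dtype=np.int64),
}
```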