Commit 679f359

ankitade authored and facebook-github-bot committed
Make projections part of the core model (#106)
Summary:
Pull Request resolved: #106

Move projections from the contrastive loss to the core model. This allows users to use the model (instead of the pretraining model) for zero-shot inference. Also moved to using the translated checkpoint.

Test plan:
1. pytest
2. python -m flava.train config=flava/configs/pretraining/debug.yaml
3. python -m flava.finetune config=flava/configs/finetuning/qnli.yaml

Test Plan: Imported from OSS

Reviewed By: ebsmothers

Differential Revision: D37481127

Pulled By: ankitade

fbshipit-source-id: 71a639b867db61aef8ae317cdfb8b87e82e075c5
1 parent 5e3121d commit 679f359
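
The intent of this change is that zero-shot scoring can be done with the core model alone. Below is a minimal sketch of that usage, assuming the flava_model builder and the projection=True flag introduced in this diff; the import path, no-argument constructor call, and input shapes are illustrative assumptions, not part of the commit.

    # Hypothetical zero-shot sketch enabled by moving projections into the core model.
    import torch
    from torchmultimodal.models.flava.flava_model import flava_model  # assumed import path

    model = flava_model()  # assumes the defaults are sufficient for a toy run
    model.eval()

    image = torch.randn(1, 3, 224, 224)      # dummy image batch
    text = torch.randint(0, 30522, (1, 77))  # dummy token ids; vocab size is illustrative

    with torch.no_grad():
        _, projected_image_embeddings = model.encode_image(image, projection=True)
        _, projected_text_embeddings = model.encode_text(text, projection=True)

    image_emb = torch.nn.functional.normalize(projected_image_embeddings, dim=-1)
    text_emb = torch.nn.functional.normalize(projected_text_embeddings, dim=-1)
    similarity = image_emb @ text_emb.t()     # cosine similarity for zero-shot matching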

File tree

test/models/flava/test_flava.py
torchmultimodal/models/flava/flava_model.py
torchmultimodal/modules/losses/flava.py

3 files changed: +101 −39 lines changed


test/models/flava/test_flava.py

Lines changed: 2 additions & 0 deletions
@@ -167,6 +167,8 @@ def setUp(self):
             mm_encoder=mm_encoder,
             image_to_mm_projection=image_to_mm_projection,
             text_to_mm_projection=text_to_mm_projection,
+            text_projection=nn.Identity(),
+            image_projection=nn.Identity(),
         )

     def _assert_empty(self, field):
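
The test passes nn.Identity() for the new projection arguments, so the "projected" embeddings in the test are just the unprojected CLS hidden states and the existing assertions keep holding. A tiny sketch of that pass-through behavior, with dummy shapes:

    import torch
    import torch.nn as nn

    projection = nn.Identity()
    cls_hidden = torch.randn(2, 768)  # dummy (batch, hidden_size) CLS embeddings
    assert torch.equal(projection(cls_hidden), cls_hidden)  # projection is a no-op here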

torchmultimodal/models/flava/flava_model.py

Lines changed: 89 additions & 26 deletions
@@ -37,8 +37,17 @@

 FLAVAOutput = namedtuple(
     "FLAVAOutput",
-    ["image", "image_masked", "text", "text_masked", "multimodal", "multimodal_masked"],
-    defaults=(None, None, None, None, None, None),
+    [
+        "image",
+        "image_masked",
+        "text",
+        "text_masked",
+        "multimodal",
+        "multimodal_masked",
+        "projected_image_embeddings",
+        "projected_text_embeddings",
+    ],
+    defaults=(None, None, None, None, None, None, None, None),
 )
 FLAVAOutput.__annotations__ = {
     "image": FLAVATransformerOutput,

@@ -53,7 +62,7 @@
 FLAVA_FOR_PRETRAINED_MAPPING = {
     # This will no longer load with the updated model, but keeping here just in case
     # "flava_full": "https://huggingface.co/aps/flava_full_pretrained_encoders_torchmm/resolve/main/pytorch_model.bin",
-    "flava_full": "https://download.pytorch.org/models/multimodal/flava/flava_self_attn.bin",
+    "flava_full": "https://download.pytorch.org/models/multimodal/flava/flava_for_pretraining_projection.pt",
 }


@@ -101,6 +110,8 @@ def __init__(
         mm_encoder: nn.Module,
         image_to_mm_projection: nn.Module,
         text_to_mm_projection: nn.Module,
+        text_projection: nn.Module,
+        image_projection: nn.Module,
         **kwargs: Any,
     ) -> None:
         super().__init__()

@@ -109,6 +120,8 @@ def __init__(
         self.mm_encoder = mm_encoder
         self.image_to_mm_projection = image_to_mm_projection
         self.text_to_mm_projection = text_to_mm_projection
+        self.text_projection = text_projection
+        self.image_projection = image_projection

     def forward(
         self,

@@ -127,30 +140,50 @@ def forward(
             else:
                 required_embedding = "text"

-        image_outputs = self._encode_data_to_embeddings(
+        image_encoding_out = self._encode_data_to_embeddings(
             image,
             required_embedding,
             ["image", "mm"],
-            self.encode_image,
+            partial(self.encode_image, projection=True),
         )
-        text_outputs = self._encode_data_to_embeddings(
+        if len(image_encoding_out) == 2:
+            image_outputs, projected_image_embeddings = (
+                image_encoding_out[0],
+                image_encoding_out[1],
+            )
+        else:
+            image_outputs = image_encoding_out
+            projected_image_embeddings = None
+
+        text_encoding_out = self._encode_data_to_embeddings(
             text,
             required_embedding,
             ["text", "mm"],
-            self.encode_text,
+            partial(self.encode_text, projection=True),
         )
+        if len(text_encoding_out) == 2:
+            text_outputs, projected_text_embeddings = (
+                text_encoding_out[0],
+                text_encoding_out[1],
+            )
+        else:
+            text_outputs = text_encoding_out
+            projected_text_embeddings = None
+
         image_masked_outputs = self._encode_data_to_embeddings(
             image,
             required_embedding,
             ["image", "mm"],
             partial(self.encode_image, image_patches_mask=image_patches_mask),
         )
+        assert type(image_masked_outputs) == FLAVATransformerOutput
         text_masked_outputs = self._encode_data_to_embeddings(
             text_masked,
             required_embedding,
             ["text", "mm"],
             self.encode_text,
         )
+        assert type(text_masked_outputs) == FLAVATransformerOutput

         multimodal_outputs = FLAVATransformerOutput()
         multimodal_masked_outputs = FLAVATransformerOutput()

@@ -184,39 +217,60 @@ def forward(
             text_masked=text_masked_outputs,
             multimodal=multimodal_outputs,
             multimodal_masked=multimodal_masked_outputs,
+            projected_image_embeddings=projected_image_embeddings,
+            projected_text_embeddings=projected_text_embeddings,
         )

     def encode_image(
-        self, image: Tensor, image_patches_mask: Optional[Tensor] = None
-    ) -> Optional[FLAVATransformerOutput]:
+        self,
+        image: Tensor,
+        image_patches_mask: Optional[Tensor] = None,
+        projection: bool = False,
+    ) -> Union[Tuple[FLAVATransformerOutput, Tensor], Optional[FLAVATransformerOutput]]:
         if image_patches_mask is not None:
-            return self.image_encoder(image, image_patches_mask)
+            encoded_image = self.image_encoder(image, image_patches_mask)
         else:
-            return self.image_encoder(image)
+            encoded_image = self.image_encoder(image)
+        if projection:
+            projected_embeddings = self.image_projection(
+                encoded_image.last_hidden_state[:, 0, :]
+            )
+            return encoded_image, projected_embeddings
+        return encoded_image

     def encode_text(
-        self,
-        text: Tensor,
-        text_mask: Optional[Tensor] = None,
-    ) -> Optional[FLAVATransformerOutput]:
+        self, text: Tensor, text_mask: Optional[Tensor] = None, projection: bool = False
+    ) -> Union[Tuple[FLAVATransformerOutput, Tensor], Optional[FLAVATransformerOutput]]:
         # TODO(asg): Give proper parameter names when implementing text encoder
-        return self.text_encoder(
+        encoded_text = self.text_encoder(
             input_ids=text,
             attention_mask=text_mask,
         )
+        if projection:
+            projected_embeddings = self.text_projection(
+                encoded_text.last_hidden_state[:, 0, :]
+            )
+            return encoded_text, projected_embeddings
+        return encoded_text

     def _encode_data_to_embeddings(
         self,
         data: Optional[Tensor],
         selected_head_encoder: EMBEDDING_OPTIONS,
         encoder_options: List[EMBEDDING_OPTIONS],
-        encode_callable: Callable[..., FLAVATransformerOutput],
-    ) -> Optional[FLAVATransformerOutput]:
-        output = FLAVATransformerOutput()
+        encode_callable: Callable[
+            ...,
+            Union[
+                Tuple[FLAVATransformerOutput, Tensor], Optional[FLAVATransformerOutput]
+            ],
+        ],
+    ) -> Union[Tuple[FLAVATransformerOutput, Tensor], Optional[FLAVATransformerOutput]]:
+        output: Union[
+            Tuple[FLAVATransformerOutput, Tensor], FLAVATransformerOutput
+        ] = FLAVATransformerOutput()

         if data is not None and selected_head_encoder in encoder_options:
             output = encode_callable(data)
-
         return output

     def encode_mm(

@@ -253,19 +307,19 @@ def encode_image(
         image: Tensor,
         cls_index: int = 0,
     ) -> Tensor:
-        transformer_output = self.model.encode_image(image)
-        embeddings = transformer_output.last_hidden_state
-        return self.loss.contrastive_loss.image_projection(embeddings[:, cls_index, :])
+        encoded_result = self.model.encode_image(image, projection=True)
+        encoded_image = encoded_result[1]
+        return encoded_image

     def encode_text(
         self,
         text: Tensor,
         text_mask: Optional[Tensor] = None,
         cls_index: int = 0,
     ) -> Tensor:
-        transformer_output = self.model.encode_text(text, text_mask)
-        embeddings = transformer_output.last_hidden_state
-        return self.loss.contrastive_loss.text_projection(embeddings[:, cls_index, :])
+        encoded_result = self.model.encode_text(text, text_mask, projection=True)
+        encoded_text = encoded_result[1]
+        return encoded_text

     # TODO: Add options to enable losses selectively
     def forward(

@@ -307,6 +361,8 @@ def forward(
             itm_labels=itm_labels,
             mim_labels=image_labels,
             mlm_labels=mlm_labels,
+            projected_image_embeddings=flava_output.projected_image_embeddings,
+            projected_text_embeddings=flava_output.projected_text_embeddings,
         )


@@ -394,6 +450,8 @@ def flava_model(
     multimodal_intermediate_activation: Callable[..., Tensor] = nn.functional.gelu,
     multimodal_attention_probs_dropout_prob: float = 0.0,
     multimodal_layer_norm_eps: float = 1e-12,
+    # projection
+    text_and_image_proj_size: int = 768,
     **kwargs: Any,
 ) -> FLAVAModel:
     image_encoder = flava_image_encoder(

@@ -439,12 +497,17 @@ def flava_model(
     image_to_mm_projection = nn.Linear(image_hidden_size, multimodal_hidden_size)
     text_to_mm_projection = nn.Linear(text_hidden_size, multimodal_hidden_size)

+    image_projection = nn.Linear(image_hidden_size, text_and_image_proj_size)
+    text_projection = nn.Linear(text_hidden_size, text_and_image_proj_size)
+
     return FLAVAModel(
         image_encoder=image_encoder,
         text_encoder=text_encoder,
         mm_encoder=mm_encoder,
         image_to_mm_projection=image_to_mm_projection,
         text_to_mm_projection=text_to_mm_projection,
+        text_projection=text_projection,
+        image_projection=image_projection,
     )


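The key addition in this file is the projection branch inside encode_image and encode_text: the CLS token's hidden state is passed through a learned linear layer. A standalone sketch of that step, using dummy tensors and the 768-dimensional default of text_and_image_proj_size as illustrative assumptions:

    import torch
    import torch.nn as nn

    hidden_size, proj_size = 768, 768        # matches text_and_image_proj_size default
    image_projection = nn.Linear(hidden_size, proj_size)

    last_hidden_state = torch.randn(4, 197, hidden_size)  # dummy (batch, seq_len, hidden)
    projected = image_projection(last_hidden_state[:, 0, :])  # CLS token at index 0
    print(projected.shape)                   # torch.Size([4, 768])
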
torchmultimodal/modules/losses/flava.py

Lines changed: 10 additions & 13 deletions
@@ -253,22 +253,16 @@ def __init__(
         else:
             self.logit_scale = nn.Parameter(logit_scale * torch.ones([]))

-        self.image_projection = nn.Linear(image_embedding_size, projection_size)
-        self.text_projection = nn.Linear(text_embedding_size, projection_size)
-        self.image_embedding_index = image_embedding_index
-        self.text_embedding_index = text_embedding_index
-
     def forward(
         self,
         image_sequence: Tensor,
         text_sequence: Tensor,
         mask: Tensor,
     ) -> FLAVAGlobalContrastiveLossOutput:
-        text_embedding = nn.functional.normalize(
-            self.text_projection(text_sequence[:, self.text_embedding_index, :]), dim=-1
-        )
+
+        text_embedding = nn.functional.normalize(text_sequence, dim=-1)
         image_embedding = nn.functional.normalize(
-            self.image_projection(image_sequence[:, self.image_embedding_index, :]),
+            image_sequence,
             dim=-1,
         )

@@ -380,13 +374,16 @@ def forward(
         itm_labels: Optional[Tensor] = None,
         mim_labels: Optional[Tensor] = None,
         mlm_labels: Optional[Tensor] = None,
+        projected_image_embeddings: Optional[Tensor] = None,
+        projected_text_embeddings: Optional[Tensor] = None,
     ) -> FLAVAPretrainingLossOutput:
         outputs = FLAVAPretrainingLossOutput()
         pos_mask = None

         # Check multimodal_masked_sequence to make sure this is unimodal case
         # This specific case can though be backpropagated directly as MIM is independent of
         # text, but that is a research question :)
+
         if (
             image_masked_sequence is not None
             and self.mim_weight > 0

@@ -466,13 +463,13 @@ def forward(
             outputs.losses.mmm_image_loss = outputs.mmm_image_output.loss

         if (
-            image_sequence is not None
-            and text_sequence is not None
+            projected_image_embeddings is not None
+            and projected_text_embeddings is not None
             and self.contrastive_loss_weight > 0
         ):
             outputs.global_contrastive_output = self.contrastive_loss(
-                image_sequence,
-                text_sequence,
+                projected_image_embeddings,
+                projected_text_embeddings,
                 pos_mask,
             )
             outputs.global_contrastive_output.loss *= self.contrastive_loss_weight
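
With projections now living in the model, the global contrastive loss receives already-projected embeddings and only normalizes them before computing logits. A rough sketch of that simplified path; the batch size, embedding size, and logit scale below are illustrative placeholders for the learned values in the real loss:

    import torch
    import torch.nn.functional as F

    projected_image_embeddings = torch.randn(8, 768)  # dummy per-sample image projections
    projected_text_embeddings = torch.randn(8, 768)   # dummy per-sample text projections

    image_embedding = F.normalize(projected_image_embeddings, dim=-1)
    text_embedding = F.normalize(projected_text_embeddings, dim=-1)

    logit_scale = torch.ones([]).exp()                # stand-in for the learned scale
    logits_per_image = logit_scale * image_embedding @ text_embedding.t()
    logits_per_text = logits_per_image.t()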
