Commit c801910

[FLAVA] Move projections from contrastive loss to model
ghstack-source-id: f8b9173 Pull Request resolved: #106
1 parent 5ec627a commit c801910
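The gist of this commit: the nn.Linear projection heads used for the global contrastive objective move out of the contrastive loss module and into the model. flava_model() now constructs them (sized by the new text_and_image_proj_size argument) and hands them to FLAVAModel, which applies them to the CLS embedding when asked. A compressed sketch of the new ownership (paraphrased from the diff below, not a verbatim excerpt):

    # the model now owns the projection heads, built in flava_model()
    image_projection = nn.Linear(image_hidden_size, text_and_image_proj_size)
    text_projection = nn.Linear(text_hidden_size, text_and_image_proj_size)
    # FLAVAModel applies them to the CLS token when encode_image/encode_text
    # are called with projection=True; the loss only normalizes the result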

3 files changed: +79 -34 lines changed

test/models/flava/test_flava.py

Lines changed: 2 additions & 0 deletions
@@ -167,6 +167,8 @@ def setUp(self):
             mm_encoder=mm_encoder,
             image_to_mm_projection=image_to_mm_projection,
             text_to_mm_projection=text_to_mm_projection,
+            text_projection=nn.Identity(),
+            image_projection=nn.Identity(),
         )

     def _assert_empty(self, field):
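The test only needs the two new required constructor arguments, so it stubs them with nn.Identity(), a pass-through module that leaves the CLS embeddings unchanged. A tiny self-contained illustration of that stub (the tensor here is a hypothetical CLS embedding, not taken from the test file):

    import torch
    from torch import nn

    proj = nn.Identity()                 # stand-in projection, as used in the test
    cls_embedding = torch.randn(2, 768)  # hypothetical batch of CLS embeddings
    assert torch.equal(proj(cls_embedding), cls_embedding)  # Identity returns its input unchanged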

torchmultimodal/models/flava/flava_model.py

Lines changed: 67 additions & 21 deletions
@@ -37,8 +37,17 @@

 FLAVAOutput = namedtuple(
     "FLAVAOutput",
-    ["image", "image_masked", "text", "text_masked", "multimodal", "multimodal_masked"],
-    defaults=(None, None, None, None, None, None),
+    [
+        "image",
+        "image_masked",
+        "text",
+        "text_masked",
+        "multimodal",
+        "multimodal_masked",
+        "projected_image_embeddings",
+        "projected_text_embeddings",
+    ],
+    defaults=(None, None, None, None, None, None, None, None),
 )
 FLAVAOutput.__annotations__ = {
     "image": FLAVATransformerOutput,
@@ -51,7 +60,7 @@


 FLAVA_FOR_PRETRAINED_MAPPING = {
-    "flava_full": "https://huggingface.co/aps/flava_full_pretrained_encoders_torchmm/resolve/main/pytorch_model.bin",
+    "flava_full": "https://download.pytorch.org/models/multimodal/flava/flava_for_pretraining.pt"
 }


@@ -124,6 +133,8 @@ def flava_model(
     multimodal_intermediate_activation: Callable[..., Tensor] = nn.functional.gelu,
     multimodal_attention_probs_dropout_prob: float = 0.0,
     multimodal_layer_norm_eps: float = 1e-12,
+    # projection
+    text_and_image_proj_size: int = 768,
     **kwargs: Any,
 ):
     image_encoder = flava_image_encoder(
@@ -169,12 +180,17 @@ def flava_model(
     image_to_mm_projection = nn.Linear(image_hidden_size, multimodal_hidden_size)
     text_to_mm_projection = nn.Linear(text_hidden_size, multimodal_hidden_size)

+    image_projection = nn.Linear(image_hidden_size, text_and_image_proj_size)
+    text_projection = nn.Linear(text_hidden_size, text_and_image_proj_size)
+
     return FLAVAModel(
         image_encoder=image_encoder,
         text_encoder=text_encoder,
         mm_encoder=mm_encoder,
         image_to_mm_projection=image_to_mm_projection,
         text_to_mm_projection=text_to_mm_projection,
+        text_projection=text_projection,
+        image_projection=image_projection,
     )


@@ -254,6 +270,8 @@ def __init__(
         mm_encoder: nn.Module,
         image_to_mm_projection: nn.Module,
         text_to_mm_projection: nn.Module,
+        text_projection: nn.Module,
+        image_projection: nn.Module,
         **kwargs: Any,
     ):
         super().__init__()
@@ -262,6 +280,8 @@ def __init__(
         self.mm_encoder = mm_encoder
         self.image_to_mm_projection = image_to_mm_projection
         self.text_to_mm_projection = text_to_mm_projection
+        self.text_projection = text_projection
+        self.image_projection = image_projection

     def forward(
         self,
@@ -280,18 +300,30 @@ def forward(
         else:
             required_embedding = "text"

-        image_outputs = self._encode_data_to_embeddings(
+        image_encoding_out = self._encode_data_to_embeddings(
             image,
             required_embedding,
             ["image", "mm"],
-            self.encode_image,
+            partial(self.encode_image, projection=True),
         )
-        text_outputs = self._encode_data_to_embeddings(
+        if len(image_encoding_out) == 2:
+            image_outputs, projected_image_embeddings = image_encoding_out
+        else:
+            image_outputs = image_encoding_out
+            projected_image_embeddings = None
+
+        text_encoding_out = self._encode_data_to_embeddings(
             text,
             required_embedding,
             ["text", "mm"],
-            self.encode_text,
+            partial(self.encode_text, projection=True),
         )
+        if len(text_encoding_out) == 2:
+            text_outputs, projected_text_embeddings = text_encoding_out
+        else:
+            text_outputs = text_encoding_out
+            projected_text_embeddings = None
+
         image_masked_outputs = self._encode_data_to_embeddings(
             image,
             required_embedding,
@@ -337,26 +369,41 @@ def forward(
             text_masked=text_masked_outputs,
             multimodal=multimodal_outputs,
             multimodal_masked=multimodal_masked_outputs,
+            projected_image_embeddings=projected_image_embeddings,
+            projected_text_embeddings=projected_text_embeddings,
         )

     def encode_image(
-        self, image: Tensor, image_patches_mask: Optional[Tensor] = None
+        self,
+        image: Tensor,
+        image_patches_mask: Optional[Tensor] = None,
+        projection: bool = False,
     ) -> Optional[FLAVATransformerOutput]:
         if image_patches_mask is not None:
-            return self.image_encoder(image, image_patches_mask)
+            encoded_image = self.image_encoder(image, image_patches_mask)
         else:
-            return self.image_encoder(image)
+            encoded_image = self.image_encoder(image)
+        if projection:
+            projected_embeddings = self.image_projection(
+                encoded_image.last_hidden_state[:, 0, :]
+            )
+            return encoded_image, projected_embeddings
+        return encoded_image

     def encode_text(
-        self,
-        text: Tensor,
-        text_mask: Optional[Tensor] = None,
+        self, text: Tensor, text_mask: Optional[Tensor] = None, projection: bool = False
     ) -> Optional[FLAVATransformerOutput]:
         # TODO(asg): Give proper parameter names when implementing text encoder
-        return self.text_encoder(
+        encoded_text = self.text_encoder(
             input_ids=text,
             attention_mask=text_mask,
         )
+        if projection:
+            projected_embeddings = self.text_projection(
+                encoded_text.last_hidden_state[:, 0, :]
+            )
+            return encoded_text, projected_embeddings
+        return encoded_text

     def _encode_data_to_embeddings(
         self,
@@ -369,7 +416,6 @@ def _encode_data_to_embeddings(

         if data is not None and selected_head_encoder in encoder_options:
             output = encode_callable(data)
-
         return output

     def encode_mm(
@@ -404,19 +450,17 @@ def encode_image(
         image: Tensor,
         cls_index: int = 0,
     ):
-        transformer_output = self.model.encode_image(image)
-        embeddings = transformer_output.last_hidden_state
-        return self.loss.contrastive_loss.image_projection(embeddings[:, cls_index, :])
+        _, encoded_image = self.model.encode_image(image, projection=True)
+        return encoded_image

     def encode_text(
         self,
         text: Tensor,
         text_mask: Optional[Tensor] = None,
         cls_index: int = 0,
     ):
-        transformer_output = self.model.encode_text(text, text_mask)
-        embeddings = transformer_output.last_hidden_state
-        return self.loss.contrastive_loss.text_projection(embeddings[:, cls_index, :])
+        _, encoded_text = self.model.encode_text(text, text_mask, projection=True)
+        return encoded_text

     # TODO: Add options to enable losses selectively
     def forward(
@@ -458,6 +502,8 @@ def forward(
             itm_labels=itm_labels,
             mim_labels=image_labels,
             mlm_labels=mlm_labels,
+            projected_image_embeddings=flava_output.projected_image_embeddings,
+            projected_text_embeddings=flava_output.projected_text_embeddings,
         )


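With the model owning the projections, callers can get the contrastive embeddings straight from FLAVAModel. A minimal usage sketch (assumes a model built by flava_model(); the input shapes and vocab size below are illustrative placeholders, not values from the repo):

    import torch

    image = torch.randn(2, 3, 224, 224)      # hypothetical image batch
    text = torch.randint(0, 30000, (2, 77))  # hypothetical token ids

    # projection=True makes encode_* also return the projected CLS embedding
    encoded_image, projected_image = model.encode_image(image, projection=True)
    encoded_text, projected_text = model.encode_text(text, projection=True)

    # the same tensors surface on the forward output as
    # FLAVAOutput.projected_image_embeddings / projected_text_embeddings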
torchmultimodal/modules/losses/flava.py

Lines changed: 10 additions & 13 deletions
@@ -249,22 +249,16 @@ def __init__(
         else:
             self.logit_scale = nn.Parameter(logit_scale * torch.ones([]))

-        self.image_projection = nn.Linear(image_embedding_size, projection_size)
-        self.text_projection = nn.Linear(text_embedding_size, projection_size)
-        self.image_embedding_index = image_embedding_index
-        self.text_embedding_index = text_embedding_index
-
     def forward(
         self,
         image_sequence: Tensor,
         text_sequence: Tensor,
         mask: Tensor,
     ):
-        text_embedding = nn.functional.normalize(
-            self.text_projection(text_sequence[:, self.text_embedding_index, :]), dim=-1
-        )
+
+        text_embedding = nn.functional.normalize(text_sequence, dim=-1)
         image_embedding = nn.functional.normalize(
-            self.image_projection(image_sequence[:, self.image_embedding_index, :]),
+            image_sequence,
             dim=-1,
         )

@@ -376,18 +370,20 @@ def forward(
         itm_labels: Optional[Tensor] = None,
         mim_labels: Optional[Tensor] = None,
         mlm_labels: Optional[Tensor] = None,
+        projected_image_embeddings: Optional[Tensor] = None,
+        projected_text_embeddings: Optional[Tensor] = None,
     ) -> FLAVAPretrainingLossOutput:
         outputs = FLAVAPretrainingLossOutput()
         pos_mask = None

         if (
-            image_sequence is not None
-            and text_sequence is not None
+            projected_image_embeddings is not None
+            and projected_text_embeddings is not None
             and self.contrastive_loss_weight > 0
         ):
             outputs.global_contrastive_output = self.contrastive_loss(
-                image_sequence,
-                text_sequence,
+                projected_image_embeddings,
+                projected_text_embeddings,
                 pos_mask,
             )
             outputs.global_contrastive_output.loss *= self.contrastive_loss_weight
@@ -398,6 +394,7 @@ def forward(
         # Check multimodal_masked_sequence to make sure this is unimodal case
         # This specific case can though be backpropagated directly as MIM is independent of
         # text, but that is a research question :)
+
         if (
             image_masked_sequence is not None
             and self.mim_weight > 0
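On the loss side, the pretraining loss now takes projected_image_embeddings / projected_text_embeddings explicitly, and the contrastive term only normalizes them; the global contrastive output is computed only when both are present and contrastive_loss_weight > 0. A small runnable sketch of that normalization step (placeholder tensors; in practice these come from encode_image / encode_text with projection=True):

    import torch
    from torch import nn

    projected_image_embeddings = torch.randn(2, 768)  # placeholder projected CLS embeddings
    projected_text_embeddings = torch.randn(2, 768)

    # the contrastive loss no longer projects; it just normalizes what it receives
    image_embedding = nn.functional.normalize(projected_image_embeddings, dim=-1)
    text_embedding = nn.functional.normalize(projected_text_embeddings, dim=-1)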
