Commit 6fac11d

[FLAVA] Move projections from contrastive loss to model

ghstack-source-id: e6b230c
Pull Request resolved: #106

1 parent 599bc61, commit 6fac11d
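
To make the change concrete before the diffs: below is a hedged, self-contained sketch (toy code, not the library's real classes or hidden sizes) of the pattern this commit introduces, in which the model rather than the contrastive loss owns the unimodal projection layer, and encode_image can optionally return the projected CLS embedding alongside the encoder output.

import torch
from torch import nn


class ToyModelWithProjection(nn.Module):
    """Toy stand-in for FLAVAModel: the projection layer lives on the model."""

    def __init__(self, hidden_size: int = 768, proj_size: int = 256) -> None:
        super().__init__()
        self.image_encoder = nn.Linear(hidden_size, hidden_size)  # stand-in encoder
        self.image_projection = nn.Linear(hidden_size, proj_size)

    def encode_image(self, image: torch.Tensor, projection: bool = False):
        encoded = self.image_encoder(image)  # (batch, seq_len, hidden_size)
        if projection:
            # Project the CLS token (index 0), mirroring the new encode_image below.
            return encoded, self.image_projection(encoded[:, 0, :])
        return encoded


model = ToyModelWithProjection()
encoded, projected = model.encode_image(torch.randn(2, 5, 768), projection=True)
print(encoded.shape, projected.shape)  # torch.Size([2, 5, 768]) torch.Size([2, 256])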

File tree: 3 files changed (+79, -34 lines)


test/models/flava/test_flava.py

Lines changed: 2 additions & 0 deletions

@@ -167,6 +167,8 @@ def setUp(self):
             mm_encoder=mm_encoder,
             image_to_mm_projection=image_to_mm_projection,
             text_to_mm_projection=text_to_mm_projection,
+            text_projection=nn.Identity(),
+            image_projection=nn.Identity(),
         )

     def _assert_empty(self, field):
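
A note on the test change above: the tests satisfy the two new constructor arguments with nn.Identity() so that the projection step is a no-op and the existing assertions still hold. A minimal, self-contained illustration of that behavior:

import torch
from torch import nn

projection = nn.Identity()
embeddings = torch.randn(2, 768)
# nn.Identity returns its input unchanged, so a model wired with it behaves
# exactly as it did before the projection arguments were introduced.
assert torch.equal(projection(embeddings), embeddings)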

torchmultimodal/models/flava/flava_model.py

Lines changed: 67 additions & 21 deletions
@@ -37,8 +37,17 @@

 FLAVAOutput = namedtuple(
     "FLAVAOutput",
-    ["image", "image_masked", "text", "text_masked", "multimodal", "multimodal_masked"],
-    defaults=(None, None, None, None, None, None),
+    [
+        "image",
+        "image_masked",
+        "text",
+        "text_masked",
+        "multimodal",
+        "multimodal_masked",
+        "projected_image_embeddings",
+        "projected_text_embeddings",
+    ],
+    defaults=(None, None, None, None, None, None, None, None),
 )
 FLAVAOutput.__annotations__ = {
     "image": FLAVATransformerOutput,
@@ -51,7 +60,7 @@


 FLAVA_FOR_PRETRAINED_MAPPING = {
-    "flava_full": "https://huggingface.co/aps/flava_full_pretrained_encoders_torchmm/resolve/main/pytorch_model.bin",
+    "flava_full": "https://download.pytorch.org/models/multimodal/flava/flava_for_pretraining.pt"
 }


@@ -124,6 +133,8 @@ def flava_model(
     multimodal_intermediate_activation: Callable[..., Tensor] = nn.functional.gelu,
     multimodal_attention_probs_dropout_prob: float = 0.0,
     multimodal_layer_norm_eps: float = 1e-12,
+    # projection
+    text_and_image_proj_size: int = 768,
     **kwargs: Any,
 ):
     image_encoder = flava_image_encoder(
@@ -169,12 +180,17 @@ def flava_model(
     image_to_mm_projection = nn.Linear(image_hidden_size, multimodal_hidden_size)
     text_to_mm_projection = nn.Linear(text_hidden_size, multimodal_hidden_size)

+    image_projection = nn.Linear(image_hidden_size, text_and_image_proj_size)
+    text_projection = nn.Linear(text_hidden_size, text_and_image_proj_size)
+
     return FLAVAModel(
         image_encoder=image_encoder,
         text_encoder=text_encoder,
         mm_encoder=mm_encoder,
         image_to_mm_projection=image_to_mm_projection,
         text_to_mm_projection=text_to_mm_projection,
+        text_projection=text_projection,
+        image_projection=image_projection,
     )


@@ -253,6 +269,8 @@ def __init__(
         mm_encoder: nn.Module,
         image_to_mm_projection: nn.Module,
         text_to_mm_projection: nn.Module,
+        text_projection: nn.Module,
+        image_projection: nn.Module,
         **kwargs: Any,
     ):
         super().__init__()
@@ -261,6 +279,8 @@ def __init__(
         self.mm_encoder = mm_encoder
         self.image_to_mm_projection = image_to_mm_projection
         self.text_to_mm_projection = text_to_mm_projection
+        self.text_projection = text_projection
+        self.image_projection = image_projection

     def forward(
         self,
@@ -279,18 +299,30 @@ def forward(
         else:
             required_embedding = "text"

-        image_outputs = self._encode_data_to_embeddings(
+        image_encoding_out = self._encode_data_to_embeddings(
             image,
             required_embedding,
             ["image", "mm"],
-            self.encode_image,
+            partial(self.encode_image, projection=True),
         )
-        text_outputs = self._encode_data_to_embeddings(
+        if len(image_encoding_out) == 2:
+            image_outputs, projected_image_embeddings = image_encoding_out
+        else:
+            image_outputs = image_encoding_out
+            projected_image_embeddings = None
+
+        text_encoding_out = self._encode_data_to_embeddings(
             text,
             required_embedding,
             ["text", "mm"],
-            self.encode_text,
+            partial(self.encode_text, projection=True),
         )
+        if len(text_encoding_out) == 2:
+            text_outputs, projected_text_embeddings = text_encoding_out
+        else:
+            text_outputs = text_encoding_out
+            projected_text_embeddings = None
+
         image_masked_outputs = self._encode_data_to_embeddings(
             image,
             required_embedding,
@@ -336,26 +368,41 @@ def forward(
             text_masked=text_masked_outputs,
             multimodal=multimodal_outputs,
             multimodal_masked=multimodal_masked_outputs,
+            projected_image_embeddings=projected_image_embeddings,
+            projected_text_embeddings=projected_text_embeddings,
         )

     def encode_image(
-        self, image: Tensor, image_patches_mask: Optional[Tensor] = None
+        self,
+        image: Tensor,
+        image_patches_mask: Optional[Tensor] = None,
+        projection: bool = False,
     ) -> Optional[FLAVATransformerOutput]:
         if image_patches_mask is not None:
-            return self.image_encoder(image, image_patches_mask)
+            encoded_image = self.image_encoder(image, image_patches_mask)
         else:
-            return self.image_encoder(image)
+            encoded_image = self.image_encoder(image)
+        if projection:
+            projected_embeddings = self.image_projection(
+                encoded_image.last_hidden_state[:, 0, :]
+            )
+            return encoded_image, projected_embeddings
+        return encoded_image

     def encode_text(
-        self,
-        text: Tensor,
-        text_mask: Optional[Tensor] = None,
+        self, text: Tensor, text_mask: Optional[Tensor] = None, projection: bool = False
     ) -> Optional[FLAVATransformerOutput]:
         # TODO(asg): Give proper parameter names when implementing text encoder
-        return self.text_encoder(
+        encoded_text = self.text_encoder(
             input_ids=text,
             attention_mask=text_mask,
         )
+        if projection:
+            projected_embeddings = self.text_projection(
+                encoded_text.last_hidden_state[:, 0, :]
+            )
+            return encoded_text, projected_embeddings
+        return encoded_text

     def _encode_data_to_embeddings(
         self,
@@ -368,7 +415,6 @@ def _encode_data_to_embeddings(

         if data is not None and selected_head_encoder in encoder_options:
             output = encode_callable(data)
-
         return output

     def encode_mm(
@@ -403,19 +449,17 @@ def encode_image(
         image: Tensor,
         cls_index: int = 0,
     ):
-        transformer_output = self.model.encode_image(image)
-        embeddings = transformer_output.last_hidden_state
-        return self.loss.contrastive_loss.image_projection(embeddings[:, cls_index, :])
+        _, encoded_image = self.model.encode_image(image, projection=True)
+        return encoded_image

     def encode_text(
         self,
         text: Tensor,
         text_mask: Optional[Tensor] = None,
         cls_index: int = 0,
     ):
-        transformer_output = self.model.encode_text(text, text_mask)
-        embeddings = transformer_output.last_hidden_state
-        return self.loss.contrastive_loss.text_projection(embeddings[:, cls_index, :])
+        _, encoded_text = self.model.encode_text(text, text_mask, projection=True)
+        return encoded_text

     # TODO: Add options to enable losses selectively
     def forward(
@@ -457,6 +501,8 @@ def forward(
             itm_labels=itm_labels,
             mim_labels=image_labels,
             mlm_labels=mlm_labels,
+            projected_image_embeddings=flava_output.projected_image_embeddings,
+            projected_text_embeddings=flava_output.projected_text_embeddings,
         )

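The forward changes above rely on the bound encoder call returning a plain 2-tuple when projection=True and a multi-field encoder output otherwise, which is what the len(...) == 2 check distinguishes. A hedged, standalone sketch of that dispatch pattern (EncoderOutput and the toy encode_image below are stand-ins, not the library's types):

from collections import namedtuple
from functools import partial

import torch

# Stand-in for the encoder output; the real FLAVATransformerOutput has different fields.
EncoderOutput = namedtuple(
    "EncoderOutput", ["last_hidden_state", "pooler_output", "hidden_states"]
)


def encode_image(image: torch.Tensor, projection: bool = False):
    encoded = EncoderOutput(last_hidden_state=image, pooler_output=None, hidden_states=None)
    if projection:
        projected = image[:, 0, :]  # placeholder for image_projection(CLS token)
        return encoded, projected  # a 2-tuple, so len(...) == 2
    return encoded  # a 3-field namedtuple, so len(...) == 3


encode = partial(encode_image, projection=True)  # as in partial(self.encode_image, ...)
out = encode(torch.randn(2, 5, 8))
if len(out) == 2:
    image_outputs, projected_image_embeddings = out
else:
    image_outputs, projected_image_embeddings = out, None
print(type(image_outputs).__name__, projected_image_embeddings.shape)
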
torchmultimodal/modules/losses/flava.py

Lines changed: 10 additions & 13 deletions
@@ -249,22 +249,16 @@ def __init__(
         else:
             self.logit_scale = nn.Parameter(logit_scale * torch.ones([]))

-        self.image_projection = nn.Linear(image_embedding_size, projection_size)
-        self.text_projection = nn.Linear(text_embedding_size, projection_size)
-        self.image_embedding_index = image_embedding_index
-        self.text_embedding_index = text_embedding_index
-
     def forward(
         self,
         image_sequence: Tensor,
         text_sequence: Tensor,
         mask: Tensor,
     ):
-        text_embedding = nn.functional.normalize(
-            self.text_projection(text_sequence[:, self.text_embedding_index, :]), dim=-1
-        )
+
+        text_embedding = nn.functional.normalize(text_sequence, dim=-1)
         image_embedding = nn.functional.normalize(
-            self.image_projection(image_sequence[:, self.image_embedding_index, :]),
+            image_sequence,
             dim=-1,
         )

@@ -376,18 +370,20 @@ def forward(
         itm_labels: Optional[Tensor] = None,
         mim_labels: Optional[Tensor] = None,
         mlm_labels: Optional[Tensor] = None,
+        projected_image_embeddings: Optional[Tensor] = None,
+        projected_text_embeddings: Optional[Tensor] = None,
     ) -> FLAVAPretrainingLossOutput:
         outputs = FLAVAPretrainingLossOutput()
         pos_mask = None

         if (
-            image_sequence is not None
-            and text_sequence is not None
+            projected_image_embeddings is not None
+            and projected_text_embeddings is not None
             and self.contrastive_loss_weight > 0
         ):
             outputs.global_contrastive_output = self.contrastive_loss(
-                image_sequence,
-                text_sequence,
+                projected_image_embeddings,
+                projected_text_embeddings,
                 pos_mask,
             )
             outputs.global_contrastive_output.loss *= self.contrastive_loss_weight
@@ -398,6 +394,7 @@ def forward(
         # Check multimodal_masked_sequence to make sure this is unimodal case
         # This specific case can though be backpropagated directly as MIM is independent of
         # text, but that is a research question :)
+
         if (
             image_masked_sequence is not None
             and self.mim_weight > 0
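
With the projections removed from the loss, the contrastive path now expects embeddings that were already projected inside the model and only normalizes them, as in the diff above. A hedged sketch of that contract (the logit computation at the end is illustrative, not copied from the library):

import torch
from torch import nn

batch_size, proj_size = 4, 768
projected_image_embeddings = torch.randn(batch_size, proj_size)
projected_text_embeddings = torch.randn(batch_size, proj_size)

# The loss only normalizes the pre-projected embeddings...
image_embedding = nn.functional.normalize(projected_image_embeddings, dim=-1)
text_embedding = nn.functional.normalize(projected_text_embeddings, dim=-1)

# ...and (illustratively) scores every image against every text.
logit_scale = torch.ones([]).exp()  # placeholder for the learned logit_scale
logits_per_image = logit_scale * image_embedding @ text_embedding.t()
print(logits_per_image.shape)  # torch.Size([4, 4])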
