Commit 182c75e

Temp CL
ghstack-source-id: 1b7477c
Pull Request resolved: #106
1 parent b7b0166 commit 182c75e

File tree: 2 files changed (+96, -24 lines)
torchmultimodal/models/flava/flava_model.py

Lines changed: 62 additions & 14 deletions
@@ -37,8 +37,17 @@
 
 FLAVAOutput = namedtuple(
     "FLAVAOutput",
-    ["image", "image_masked", "text", "text_masked", "multimodal", "multimodal_masked"],
-    defaults=(None, None, None, None, None, None),
+    [
+        "image",
+        "image_masked",
+        "text",
+        "text_masked",
+        "multimodal",
+        "multimodal_masked",
+        "projected_image_embeddings",
+        "projected_text_embeddings",
+    ],
+    defaults=(None, None, None, None, None, None, None, None),
 )
 FLAVAOutput.__annotations__ = {
     "image": FLAVATransformerOutput,
@@ -124,6 +133,8 @@ def flava_model(
     multimodal_intermediate_activation: Callable[..., Tensor] = nn.functional.gelu,
     multimodal_attention_probs_dropout_prob: float = 0.0,
     multimodal_layer_norm_eps: float = 1e-12,
+    # projection
+    text_and_image_proj_size: int = 768,
     **kwargs: Any,
 ):
     image_encoder = flava_image_encoder(
@@ -169,12 +180,17 @@ def flava_model(
     image_to_mm_projection = nn.Linear(image_hidden_size, multimodal_hidden_size)
     text_to_mm_projection = nn.Linear(text_hidden_size, multimodal_hidden_size)
 
+    image_projection = nn.Linear(image_hidden_size, text_and_image_proj_size)
+    text_projection = nn.Linear(text_hidden_size, text_and_image_proj_size)
+
     return FLAVAModel(
         image_encoder=image_encoder,
         text_encoder=text_encoder,
         mm_encoder=mm_encoder,
         image_to_mm_projection=image_to_mm_projection,
         text_to_mm_projection=text_to_mm_projection,
+        text_projection=text_projection,
+        image_projection=image_projection,
     )
 
 
@@ -246,6 +262,8 @@ def __init__(
         mm_encoder: nn.Module,
         image_to_mm_projection: nn.Module,
         text_to_mm_projection: nn.Module,
+        text_projection: nn.Module,
+        image_projection: nn.Module,
         **kwargs: Any,
     ):
         super().__init__()
@@ -254,6 +272,8 @@ def __init__(
         self.mm_encoder = mm_encoder
         self.image_to_mm_projection = image_to_mm_projection
         self.text_to_mm_projection = text_to_mm_projection
+        self.text_projection = text_projection
+        self.image_projection = image_projection
 
     def forward(
         self,
@@ -272,18 +292,30 @@ def forward(
             else:
                 required_embedding = "text"
 
-        image_outputs = self._encode_data_to_embeddings(
+        image_encoding_out = self._encode_data_to_embeddings(
            image,
            required_embedding,
            ["image", "mm"],
-            self.encode_image,
+            partial(self.encode_image, projection=True),
        )
-        text_outputs = self._encode_data_to_embeddings(
+        if len(image_encoding_out) == 2:
+            image_outputs, projected_image_embeddings = image_encoding_out
+        else:
+            image_outputs = image_encoding_out
+            projected_image_embeddings = None
+
+        text_encoding_out = self._encode_data_to_embeddings(
            text,
            required_embedding,
            ["text", "mm"],
-            self.encode_text,
+            partial(self.encode_text, projection=True),
        )
+        if len(text_encoding_out) == 2:
+            text_outputs, projected_text_embeddings = text_encoding_out
+        else:
+            text_outputs = text_encoding_out
+            projected_text_embeddings = None
+
        image_masked_outputs = self._encode_data_to_embeddings(
            image,
            required_embedding,
@@ -329,26 +361,41 @@ def forward(
            text_masked=text_masked_outputs,
            multimodal=multimodal_outputs,
            multimodal_masked=multimodal_masked_outputs,
+            projected_image_embeddings=projected_image_embeddings,
+            projected_text_embeddings=projected_text_embeddings,
        )
 
    def encode_image(
-        self, image: Tensor, image_patches_mask: Optional[Tensor] = None
+        self,
+        image: Tensor,
+        image_patches_mask: Optional[Tensor] = None,
+        projection: bool = False,
    ) -> Optional[FLAVATransformerOutput]:
        if image_patches_mask is not None:
-            return self.image_encoder(image, image_patches_mask)
+            encoded_image = self.image_encoder(image, image_patches_mask)
        else:
-            return self.image_encoder(image)
+            encoded_image = self.image_encoder(image)
+        if projection:
+            projected_embeddings = self.image_projection(
+                encoded_image.last_hidden_state[:, 0, :]
+            )
+            return encoded_image, projected_embeddings
+        return encoded_image
 
    def encode_text(
-        self,
-        text: Tensor,
-        text_mask: Optional[Tensor] = None,
+        self, text: Tensor, text_mask: Optional[Tensor] = None, projection: bool = False
    ) -> Optional[FLAVATransformerOutput]:
        # TODO(asg): Give proper parameter names when implementing text encoder
-        return self.text_encoder(
+        encoded_text = self.text_encoder(
            input_ids=text,
            attention_mask=text_mask,
        )
+        if projection:
+            projected_embeddings = self.text_projection(
+                encoded_text.last_hidden_state[:, 0, :]
+            )
+            return encoded_text, projected_embeddings
+        return encoded_text
 
    def _encode_data_to_embeddings(
        self,
@@ -361,7 +408,6 @@ def _encode_data_to_embeddings(
 
        if data is not None and selected_head_encoder in encoder_options:
            output = encode_callable(data)
-
        return output
 
    def encode_mm(
@@ -450,6 +496,8 @@ def forward(
            itm_labels=itm_labels,
            mim_labels=image_labels,
            mlm_labels=mlm_labels,
+            projected_image_embeddings=flava_output.projected_image_embeddings,
+            projected_text_embeddings=flava_output.projected_text_embeddings,
        )
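
Note (not part of the commit): the change above threads optional projection heads through the model. encode_image and encode_text gain a projection flag and, when it is set, return a pair of (encoder output, projected embedding of the CLS token taken from last_hidden_state). The snippet below is a minimal, self-contained sketch of that return-shape pattern under assumed names (ProjectedEncoder, a toy linear "encoder"); it is illustration only, not the repository's code.

from typing import Tuple, Union

import torch
from torch import nn, Tensor


class ProjectedEncoder(nn.Module):
    # Stand-in for the pattern introduced above: an encoder whose CLS token
    # (position 0 of last_hidden_state) is optionally passed through a linear
    # projection head, mirroring encode_image / encode_text with projection=True.
    def __init__(self, hidden_size: int = 768, proj_size: int = 768) -> None:
        super().__init__()
        # Toy "encoder"; in the real model this is a transformer whose output
        # object carries a last_hidden_state field.
        self.encoder = nn.Linear(hidden_size, hidden_size)
        # Mirrors the image_projection / text_projection heads added in this commit.
        self.projection = nn.Linear(hidden_size, proj_size)

    def forward(
        self, x: Tensor, projection: bool = False
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        last_hidden_state = self.encoder(x)  # (batch, seq_len, hidden)
        if projection:
            # Project only the CLS token, as the model code above does.
            projected = self.projection(last_hidden_state[:, 0, :])
            return last_hidden_state, projected
        return last_hidden_state


if __name__ == "__main__":
    enc = ProjectedEncoder()
    tokens = torch.randn(2, 5, 768)        # (batch, seq_len, hidden)
    hidden, projected = enc(tokens, projection=True)
    print(hidden.shape, projected.shape)   # (2, 5, 768) and (2, 768)

Callers unpack a 2-tuple only when a projection was requested, which is why FLAVAModel.forward checks len(...) == 2 before filling projected_image_embeddings / projected_text_embeddings.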

torchmultimodal/modules/losses/flava.py

Lines changed: 34 additions & 10 deletions
@@ -249,22 +249,28 @@ def __init__(
        else:
            self.logit_scale = nn.Parameter(logit_scale * torch.ones([]))
 
-        self.image_projection = nn.Linear(image_embedding_size, projection_size)
-        self.text_projection = nn.Linear(text_embedding_size, projection_size)
-        self.image_embedding_index = image_embedding_index
-        self.text_embedding_index = text_embedding_index
+        # self.image_projection = nn.Linear(image_embedding_size, projection_size)
+        # self.text_projection = nn.Linear(text_embedding_size, projection_size)
+        # self.image_embedding_index = image_embedding_index
+        # self.text_embedding_index = text_embedding_index
 
    def forward(
        self,
        image_sequence: Tensor,
        text_sequence: Tensor,
        mask: Tensor,
    ):
-        text_embedding = nn.functional.normalize(
-            self.text_projection(text_sequence[:, self.text_embedding_index, :]), dim=-1
-        )
+        # text_embedding = nn.functional.normalize(
+        #     self.text_projection(text_sequence[:, self.text_embedding_index, :]), dim=-1
+        # )
+        # image_embedding = nn.functional.normalize(
+        #     self.image_projection(image_sequence[:, self.image_embedding_index, :]),
+        #     dim=-1,
+        # )
+
+        text_embedding = nn.functional.normalize(text_sequence, dim=-1)
        image_embedding = nn.functional.normalize(
-            self.image_projection(image_sequence[:, self.image_embedding_index, :]),
+            image_sequence,
            dim=-1,
        )
 
@@ -278,6 +284,7 @@ def forward(
            # Always true for FLAVA global contrastive loss
            backprop_in_gather=True,
        )
+        print(output.loss)
 
        return FLAVAGlobalContrastiveLossOutput(
            loss=output.loss,
@@ -376,6 +383,8 @@ def forward(
        itm_labels: Optional[Tensor] = None,
        mim_labels: Optional[Tensor] = None,
        mlm_labels: Optional[Tensor] = None,
+        projected_image_embeddings=None,
+        projected_text_embeddings=None,
    ) -> FLAVAPretrainingLossOutput:
        outputs = FLAVAPretrainingLossOutput()
        pos_mask = None
@@ -386,8 +395,8 @@ def forward(
            and self.contrastive_loss_weight > 0
        ):
            outputs.global_contrastive_output = self.contrastive_loss(
-                image_sequence,
-                text_sequence,
+                projected_image_embeddings,
+                projected_text_embeddings,
                pos_mask,
            )
            outputs.global_contrastive_output.loss *= self.contrastive_loss_weight
@@ -398,6 +407,21 @@ def forward(
        # Check multimodal_masked_sequence to make sure this is unimodal case
        # This specific case can though be backpropagated directly as MIM is independent of
        # text, but that is a research question :)
+        if (
+            image_sequence is not None
+            and text_sequence is not None
+            and self.contrastive_loss_weight > 0
+        ):
+            outputs.global_contrastive_output = self.contrastive_loss(
+                projected_image_embeddings,
+                projected_text_embeddings,
+                pos_mask,
+            )
+            outputs.global_contrastive_output.loss *= self.contrastive_loss_weight
+            outputs.losses.global_contrastive_loss = (
+                outputs.global_contrastive_output.loss
+            )
+
        if (
            image_masked_sequence is not None
            and self.mim_weight > 0
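
Note (not part of the commit): with the in-loss projection heads and embedding-index selection commented out, the contrastive path now receives embeddings that were already projected inside FLAVAModel and only L2-normalizes them. Below is a rough single-process sketch of that recipe; it is not the repository's loss implementation, and the log-space logit scale, similarity matrix, and symmetric cross-entropy are assumed from the standard CLIP-style formulation (contrastive_logits_sketch is a hypothetical name).

from typing import Tuple

import torch
import torch.nn.functional as F
from torch import nn, Tensor


def contrastive_logits_sketch(
    projected_image_embeddings: Tensor,  # (batch, proj_size), projected upstream by the model
    projected_text_embeddings: Tensor,   # (batch, proj_size)
    logit_scale: Tensor,                 # scalar learnable parameter, assumed stored in log space
) -> Tuple[Tensor, Tensor]:
    # The loss now only normalizes what it receives, since projection happens in the model.
    image_embedding = F.normalize(projected_image_embeddings, dim=-1)
    text_embedding = F.normalize(projected_text_embeddings, dim=-1)
    scale = logit_scale.exp()
    logits_per_image = scale * image_embedding @ text_embedding.t()
    logits_per_text = logits_per_image.t()
    return logits_per_image, logits_per_text


if __name__ == "__main__":
    img = torch.randn(4, 768)
    txt = torch.randn(4, 768)
    logit_scale = nn.Parameter(torch.ones([]) * 2.6592)  # ~log(1/0.07), the common CLIP-style init
    li, lt = contrastive_logits_sketch(img, txt, logit_scale)
    labels = torch.arange(4)                              # matched pairs sit on the diagonal
    loss = (F.cross_entropy(li, labels) + F.cross_entropy(lt, labels)) / 2
    print(loss.item())

This also explains the signature change above: FLAVAPretrainingLoss.forward now accepts projected_image_embeddings / projected_text_embeddings and passes them, rather than the raw sequences, to the contrastive loss.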
