Commit 599bc61

[FLAVA] Change ordering on contrastive loss initialization
ghstack-source-id: 8614e1e
Pull Request resolved: #105
1 parent 298fe21 commit 599bc61
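
The hard-coded expected losses in test_flava.py change below because the test models are now built after the random input tensors and switched to eval() mode, and because the factory functions construct their submodules in a different order. Each nn.Module draws its initial parameters from PyTorch's global RNG at construction time, so under a fixed seed a different construction order produces different initial weights and therefore different loss values. A minimal sketch of that effect follows; the layer names are illustrative stand-ins, not the actual FLAVA components:

import torch
from torch import nn


def build_layers(contrastive_first: bool):
    # Both branches create the same two layers; only the construction order
    # differs. Each nn.Linear consumes draws from the global RNG when it
    # initializes its weights.
    torch.manual_seed(0)
    if contrastive_first:
        contrastive_head = nn.Linear(4, 4)  # stand-in module
        codebook_proj = nn.Linear(4, 4)     # stand-in module
    else:
        codebook_proj = nn.Linear(4, 4)
        contrastive_head = nn.Linear(4, 4)
    return contrastive_head, codebook_proj


head_a, _ = build_layers(contrastive_first=True)
head_b, _ = build_layers(contrastive_first=False)

# Same seed, same layer definition, different construction order ->
# different initial weights, so any loss computed from them (and any
# expected value hard-coded in a test) shifts as well.
print(torch.equal(head_a.weight, head_b.weight))  # False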

File tree (3 files changed: +30 / -28 lines)

test/models/flava/test_flava.py
torchmultimodal/models/flava/flava_model.py
torchmultimodal/modules/losses/flava.py


test/models/flava/test_flava.py

Lines changed: 13 additions & 9 deletions
@@ -27,27 +27,30 @@ def setUp(self):

     @torch.no_grad()
     def test_forward_classification(self):
-        flava = flava_model_for_classification(NUM_CLASSES, pretrained_model_key=None)
         text = torch.randint(0, 30500, (2, 77), dtype=torch.long)
         image = torch.rand((2, 3, 224, 224))

         labels = torch.randint(0, 2, (2,), dtype=torch.long)

+        flava = flava_model_for_classification(NUM_CLASSES, pretrained_model_key=None)
+        flava.eval()
+
         # Test multimodal scenario
+
         output = flava(image, text, "mm", labels)
-        self.assertAlmostEqual(output.loss.item(), 0.9724, places=4)
+        self.assertAlmostEqual(output.loss.item(), 0.7180, places=4)

         # Test unimodal image scenario
         output = flava(image, text, "image", labels)
-        self.assertAlmostEqual(output.loss.item(), 0.5453, places=4)
+        self.assertAlmostEqual(output.loss.item(), 0.7020, places=4)

         # Test unimodal text scenario
         output = flava(image, text, "text", labels)
-        self.assertAlmostEqual(output.loss.item(), 0.7074, places=4)
+        self.assertAlmostEqual(output.loss.item(), 0.6663, places=4)

     @torch.no_grad()
     def test_forward_pretraining(self):
-        flava = flava_model_for_pretraining()
+
         text = torch.randint(0, 30500, (2, 77), dtype=torch.long)
         image = torch.rand((2, 3, 224, 224))
         image_for_codebook = torch.rand(2, 3, 112, 112)

@@ -58,7 +61,8 @@ def test_forward_pretraining(self):
         mlm_labels[:, :] = -1
         mlm_labels[:, 1:3] = text[:, 1:3]
         itm_labels = torch.tensor((0, 1), dtype=torch.long)
-
+        flava = flava_model_for_pretraining()
+        flava.eval()
         output = flava(
             image=image,
             text=text,

@@ -79,7 +83,7 @@ def test_forward_pretraining(self):
             sum(
                 value if value is not None else 0 for value in output.losses.values()
             ).item(),
-            20.4199,
+            21.4791,
             places=4,
         )

@@ -103,7 +107,7 @@ def test_forward_pretraining(self):
             sum(
                 value if value is not None else 0 for value in output.losses.values()
             ).item(),
-            9.3403,
+            8.9674,
             places=4,
         )

@@ -128,7 +132,7 @@ def test_forward_pretraining(self):
             sum(
                 value if value is not None else 0 for value in output.losses.values()
             ).item(),
-            10.8777,
+            10.0305,
             places=4,
         )
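
To verify the updated expectations locally, these tests can be run directly with pytest (assuming the repository and its test dependencies are installed), for example:

pytest test/models/flava/test_flava.py -k "test_forward_classification or test_forward_pretraining"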

torchmultimodal/models/flava/flava_model.py

Lines changed: 2 additions & 4 deletions
@@ -185,9 +185,8 @@ def flava_model_for_pretraining(
     # TODO: Add parameters for loss here
 ):
     model = flava_model(**flava_model_kwargs)
-
-    codebook = DalleVAEEncoder(image_size=codebook_image_size)
     losses = FLAVAPretrainingLoss()
+    codebook = DalleVAEEncoder(image_size=codebook_image_size)

     flava = FLAVAForPreTraining(
         model=model,

@@ -212,7 +211,6 @@ def flava_model_for_classification(
     pretrained_model_key: Optional[str] = "flava_full",
     **flava_model_kwargs: Any,
 ):
-    model = flava_model(**flava_model_kwargs)
     classifier = MLP(
         in_dim=classifier_in_dim,
         out_dim=num_classes,

@@ -221,7 +219,7 @@ def flava_model_for_classification(
         activation=classifier_activation,
         normalization=classifier_normalization,
     )
-
+    model = flava_model(**flava_model_kwargs)
     if loss_fn is None:
         loss_fn = nn.CrossEntropyLoss()

torchmultimodal/modules/losses/flava.py

Lines changed: 15 additions & 15 deletions
@@ -380,6 +380,21 @@ def forward(
         outputs = FLAVAPretrainingLossOutput()
         pos_mask = None

+        if (
+            image_sequence is not None
+            and text_sequence is not None
+            and self.contrastive_loss_weight > 0
+        ):
+            outputs.global_contrastive_output = self.contrastive_loss(
+                image_sequence,
+                text_sequence,
+                pos_mask,
+            )
+            outputs.global_contrastive_output.loss *= self.contrastive_loss_weight
+            outputs.losses.global_contrastive_loss = (
+                outputs.global_contrastive_output.loss
+            )
+
         # Check multimodal_masked_sequence to make sure this is unimodal case
         # This specific case can though be backpropagated directly as MIM is independent of
         # text, but that is a research question :)

@@ -461,19 +476,4 @@ def forward(
             outputs.mmm_image_output.loss *= self.mmm_image_loss_weight
             outputs.losses.mmm_image_loss = outputs.mmm_image_output.loss

-        if (
-            image_sequence is not None
-            and text_sequence is not None
-            and self.contrastive_loss_weight > 0
-        ):
-            outputs.global_contrastive_output = self.contrastive_loss(
-                image_sequence,
-                text_sequence,
-                pos_mask,
-            )
-            outputs.global_contrastive_output.loss *= self.contrastive_loss_weight
-            outputs.losses.global_contrastive_loss = (
-                outputs.global_contrastive_output.loss
-            )
-
         return outputs
