@@ -37,8 +37,17 @@
 
 FLAVAOutput = namedtuple(
     "FLAVAOutput",
-    ["image", "image_masked", "text", "text_masked", "multimodal", "multimodal_masked"],
-    defaults=(None, None, None, None, None, None),
+    [
+        "image",
+        "image_masked",
+        "text",
+        "text_masked",
+        "multimodal",
+        "multimodal_masked",
+        "projected_image_embeddings",
+        "projected_text_embeddings",
+    ],
+    defaults=(None, None, None, None, None, None, None, None),
 )
 FLAVAOutput.__annotations__ = {
     "image": FLAVATransformerOutput,
@@ -51,7 +60,7 @@
 
 
 FLAVA_FOR_PRETRAINED_MAPPING = {
-    "flava_full": "https://huggingface.co/aps/flava_full_pretrained_encoders_torchmm/resolve/main/pytorch_model.bin",
+    "flava_full": "https://download.pytorch.org/models/multimodal/flava/flava_for_pretraining.pt"
 }
 
 
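The `flava_full` checkpoint moves from a personal Hugging Face namespace to the official `download.pytorch.org` bucket. A minimal sketch of fetching it by hand, assuming the exported state dict matches the pretraining wrapper built later in this file (the key layout itself is not shown in this diff):

```python
# Sketch only: fetch the relocated checkpoint through torch.hub's cache.
import torch

state_dict = torch.hub.load_state_dict_from_url(
    FLAVA_FOR_PRETRAINED_MAPPING["flava_full"], map_location="cpu"
)
```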
@@ -124,6 +133,8 @@ def flava_model(
     multimodal_intermediate_activation: Callable[..., Tensor] = nn.functional.gelu,
     multimodal_attention_probs_dropout_prob: float = 0.0,
     multimodal_layer_norm_eps: float = 1e-12,
+    # projection
+    text_and_image_proj_size: int = 768,
     **kwargs: Any,
 ):
     image_encoder = flava_image_encoder(
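The new `text_and_image_proj_size` argument sets a single shared output width for both unimodal projection heads, so projected image and text embeddings land in the same space and can be compared directly. A hypothetical call, relying on the builder's other defaults:

```python
# Hypothetical: shrink the shared image/text projection space to 512 dims.
model = flava_model(text_and_image_proj_size=512)
```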
@@ -169,12 +180,17 @@ def flava_model(
     image_to_mm_projection = nn.Linear(image_hidden_size, multimodal_hidden_size)
     text_to_mm_projection = nn.Linear(text_hidden_size, multimodal_hidden_size)
 
+    image_projection = nn.Linear(image_hidden_size, text_and_image_proj_size)
+    text_projection = nn.Linear(text_hidden_size, text_and_image_proj_size)
+
     return FLAVAModel(
         image_encoder=image_encoder,
         text_encoder=text_encoder,
         mm_encoder=mm_encoder,
         image_to_mm_projection=image_to_mm_projection,
         text_to_mm_projection=text_to_mm_projection,
+        text_projection=text_projection,
+        image_projection=image_projection,
     )
 
 
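Both new heads are plain linear layers from each encoder's hidden size into the shared space; downstream, a global contrastive loss can score image/text pairs by cosine similarity of the projected [CLS] embeddings. An illustrative computation, where `image_cls`, `text_cls`, and `temperature` are placeholders rather than names from this module:

```python
# Illustration: contrastive logits from the two projection heads.
import torch.nn.functional as F

img = F.normalize(image_projection(image_cls), dim=-1)  # (B, proj_size)
txt = F.normalize(text_projection(text_cls), dim=-1)    # (B, proj_size)
logits = img @ txt.t() / temperature                    # (B, B) pairwise scores
```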
@@ -253,6 +269,8 @@ def __init__(
         mm_encoder: nn.Module,
         image_to_mm_projection: nn.Module,
         text_to_mm_projection: nn.Module,
+        text_projection: nn.Module,
+        image_projection: nn.Module,
         **kwargs: Any,
     ):
         super().__init__()
@@ -261,6 +279,8 @@ def __init__(
         self.mm_encoder = mm_encoder
         self.image_to_mm_projection = image_to_mm_projection
         self.text_to_mm_projection = text_to_mm_projection
+        self.text_projection = text_projection
+        self.image_projection = image_projection
 
     def forward(
         self,
@@ -279,18 +299,30 @@ def forward(
         else:
             required_embedding = "text"
 
-        image_outputs = self._encode_data_to_embeddings(
+        image_encoding_out = self._encode_data_to_embeddings(
             image,
             required_embedding,
             ["image", "mm"],
-            self.encode_image,
+            partial(self.encode_image, projection=True),
         )
-        text_outputs = self._encode_data_to_embeddings(
+        if len(image_encoding_out) == 2:
+            image_outputs, projected_image_embeddings = image_encoding_out
+        else:
+            image_outputs = image_encoding_out
+            projected_image_embeddings = None
+
+        text_encoding_out = self._encode_data_to_embeddings(
             text,
             required_embedding,
             ["text", "mm"],
-            self.encode_text,
+            partial(self.encode_text, projection=True),
         )
+        if len(text_encoding_out) == 2:
+            text_outputs, projected_text_embeddings = text_encoding_out
+        else:
+            text_outputs = text_encoding_out
+            projected_text_embeddings = None
+
         image_masked_outputs = self._encode_data_to_embeddings(
             image,
             required_embedding,
@@ -336,26 +368,41 @@ def forward(
             text_masked=text_masked_outputs,
             multimodal=multimodal_outputs,
             multimodal_masked=multimodal_masked_outputs,
+            projected_image_embeddings=projected_image_embeddings,
+            projected_text_embeddings=projected_text_embeddings,
         )
 
     def encode_image(
-        self, image: Tensor, image_patches_mask: Optional[Tensor] = None
+        self,
+        image: Tensor,
+        image_patches_mask: Optional[Tensor] = None,
+        projection: bool = False,
     ) -> Optional[FLAVATransformerOutput]:
         if image_patches_mask is not None:
-            return self.image_encoder(image, image_patches_mask)
+            encoded_image = self.image_encoder(image, image_patches_mask)
         else:
-            return self.image_encoder(image)
+            encoded_image = self.image_encoder(image)
+        if projection:
+            projected_embeddings = self.image_projection(
+                encoded_image.last_hidden_state[:, 0, :]
+            )
+            return encoded_image, projected_embeddings
+        return encoded_image
 
     def encode_text(
-        self,
-        text: Tensor,
-        text_mask: Optional[Tensor] = None,
+        self, text: Tensor, text_mask: Optional[Tensor] = None, projection: bool = False
     ) -> Optional[FLAVATransformerOutput]:
         # TODO(asg): Give proper parameter names when implementing text encoder
-        return self.text_encoder(
+        encoded_text = self.text_encoder(
             input_ids=text,
             attention_mask=text_mask,
         )
+        if projection:
+            projected_embeddings = self.text_projection(
+                encoded_text.last_hidden_state[:, 0, :]
+            )
+            return encoded_text, projected_embeddings
+        return encoded_text
 
     def _encode_data_to_embeddings(
         self,
@@ -368,7 +415,6 @@ def _encode_data_to_embeddings(
 
         if data is not None and selected_head_encoder in encoder_options:
             output = encode_callable(data)
-
         return output
 
     def encode_mm(
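Note the shape contract `forward()` relies on here: `_encode_data_to_embeddings` passes through whatever the callable returns, so with `projection=True` that is a two-element `(output, projected_embeddings)` tuple, while the bare encoder output is a NamedTuple with a different arity. The `len(...) == 2` checks above key off exactly that. A toy illustration with stand-in names (none of these are module code):

```python
# Why len() can disambiguate the two return shapes, assuming the encoder
# output is a NamedTuple with more than two fields.
from typing import NamedTuple, Optional

class FakeOutput(NamedTuple):  # stand-in for FLAVATransformerOutput
    last_hidden_state: Optional[object] = None
    pooler_output: Optional[object] = None
    hidden_states: Optional[object] = None

plain = FakeOutput()                # projection=False -> len(plain) == 3
paired = (plain, "projected")       # projection=True  -> len(paired) == 2
assert len(paired) == 2 and len(plain) != 2
```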
@@ -403,19 +449,17 @@ def encode_image(
         image: Tensor,
         cls_index: int = 0,
     ):
-        transformer_output = self.model.encode_image(image)
-        embeddings = transformer_output.last_hidden_state
-        return self.loss.contrastive_loss.image_projection(embeddings[:, cls_index, :])
+        _, encoded_image = self.model.encode_image(image, projection=True)
+        return encoded_image
 
     def encode_text(
         self,
         text: Tensor,
         text_mask: Optional[Tensor] = None,
         cls_index: int = 0,
     ):
-        transformer_output = self.model.encode_text(text, text_mask)
-        embeddings = transformer_output.last_hidden_state
-        return self.loss.contrastive_loss.text_projection(embeddings[:, cls_index, :])
+        _, encoded_text = self.model.encode_text(text, text_mask, projection=True)
+        return encoded_text
 
     # TODO: Add options to enable losses selectively
     def forward(
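The projection heads previously lived on the contrastive loss (`self.loss.contrastive_loss.*_projection`); with this change they belong to the model itself, and these wrappers now return already-projected embeddings. One side effect worth flagging: `cls_index` is still accepted but no longer consulted, since the model hard-codes token 0. A hypothetical zero-shot scoring snippet, where `pretrainer`, `images`, and `input_ids` are placeholders:

```python
# Sketch: projected embeddings from both towers are directly comparable.
import torch.nn.functional as F

img_emb = pretrainer.encode_image(images)      # (B, proj_size), already projected
txt_emb = pretrainer.encode_text(input_ids)    # (B, proj_size), already projected
scores = F.normalize(img_emb, dim=-1) @ F.normalize(txt_emb, dim=-1).t()
```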
@@ -457,6 +501,8 @@ def forward(
             itm_labels=itm_labels,
             mim_labels=image_labels,
             mlm_labels=mlm_labels,
+            projected_image_embeddings=flava_output.projected_image_embeddings,
+            projected_text_embeddings=flava_output.projected_text_embeddings,
         )
 
 