@@ -37,8 +37,17 @@
 
 FLAVAOutput = namedtuple(
     "FLAVAOutput",
-    ["image", "image_masked", "text", "text_masked", "multimodal", "multimodal_masked"],
-    defaults=(None, None, None, None, None, None),
+    [
+        "image",
+        "image_masked",
+        "text",
+        "text_masked",
+        "multimodal",
+        "multimodal_masked",
+        "projected_image_embeddings",
+        "projected_text_embeddings",
+    ],
+    defaults=(None, None, None, None, None, None, None, None),
 )
 FLAVAOutput.__annotations__ = {
     "image": FLAVATransformerOutput,
@@ -51,7 +60,7 @@
 
 
 FLAVA_FOR_PRETRAINED_MAPPING = {
-    "flava_full": "https://huggingface.co/aps/flava_full_pretrained_encoders_torchmm/resolve/main/pytorch_model.bin",
+    "flava_full": "https://download.pytorch.org/models/multimodal/flava/flava_for_pretraining.pt"
 }
 
 
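For reviewers, a minimal self-contained sketch of the updated output contract: the two new fields default to `None`, so existing positional constructions of `FLAVAOutput` keep working (annotations omitted for brevity).

```python
from collections import namedtuple

# Mirrors the updated FLAVAOutput from this diff.
FLAVAOutput = namedtuple(
    "FLAVAOutput",
    [
        "image", "image_masked", "text", "text_masked",
        "multimodal", "multimodal_masked",
        "projected_image_embeddings", "projected_text_embeddings",
    ],
    defaults=(None,) * 8,
)

out = FLAVAOutput()
assert out.projected_image_embeddings is None  # new fields are optional
```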
@@ -124,6 +133,8 @@ def flava_model(
     multimodal_intermediate_activation: Callable[..., Tensor] = nn.functional.gelu,
     multimodal_attention_probs_dropout_prob: float = 0.0,
     multimodal_layer_norm_eps: float = 1e-12,
+    # projection
+    text_and_image_proj_size: int = 768,
     **kwargs: Any,
 ):
     image_encoder = flava_image_encoder(
@@ -169,12 +180,17 @@ def flava_model(
     image_to_mm_projection = nn.Linear(image_hidden_size, multimodal_hidden_size)
     text_to_mm_projection = nn.Linear(text_hidden_size, multimodal_hidden_size)
 
+    image_projection = nn.Linear(image_hidden_size, text_and_image_proj_size)
+    text_projection = nn.Linear(text_hidden_size, text_and_image_proj_size)
+
     return FLAVAModel(
         image_encoder=image_encoder,
         text_encoder=text_encoder,
         mm_encoder=mm_encoder,
         image_to_mm_projection=image_to_mm_projection,
         text_to_mm_projection=text_to_mm_projection,
+        text_projection=text_projection,
+        image_projection=image_projection,
     )
 
 
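A hedged usage sketch of the new builder knob (parameter name from this diff; the attribute names follow the constructor arguments above):

```python
# text_and_image_proj_size sets the output width of the unimodal heads that
# project CLS embeddings into the shared image-text space.
model = flava_model(text_and_image_proj_size=512)
# model.image_projection is nn.Linear(image_hidden_size, 512)
# model.text_projection  is nn.Linear(text_hidden_size, 512)
```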
@@ -254,6 +270,8 @@ def __init__(
         mm_encoder: nn.Module,
         image_to_mm_projection: nn.Module,
         text_to_mm_projection: nn.Module,
+        text_projection: nn.Module,
+        image_projection: nn.Module,
         **kwargs: Any,
     ):
         super().__init__()
@@ -262,6 +280,8 @@ def __init__(
         self.mm_encoder = mm_encoder
         self.image_to_mm_projection = image_to_mm_projection
         self.text_to_mm_projection = text_to_mm_projection
+        self.text_projection = text_projection
+        self.image_projection = image_projection
 
     def forward(
         self,
@@ -280,18 +300,30 @@ def forward(
         else:
             required_embedding = "text"
 
-        image_outputs = self._encode_data_to_embeddings(
+        image_encoding_out = self._encode_data_to_embeddings(
             image,
             required_embedding,
             ["image", "mm"],
-            self.encode_image,
+            partial(self.encode_image, projection=True),
         )
-        text_outputs = self._encode_data_to_embeddings(
+        if len(image_encoding_out) == 2:
+            image_outputs, projected_image_embeddings = image_encoding_out
+        else:
+            image_outputs = image_encoding_out
+            projected_image_embeddings = None
+
+        text_encoding_out = self._encode_data_to_embeddings(
             text,
             required_embedding,
             ["text", "mm"],
-            self.encode_text,
+            partial(self.encode_text, projection=True),
         )
+        if len(text_encoding_out) == 2:
+            text_outputs, projected_text_embeddings = text_encoding_out
+        else:
+            text_outputs = text_encoding_out
+            projected_text_embeddings = None
+
         image_masked_outputs = self._encode_data_to_embeddings(
             image,
             required_embedding,
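Net effect of this hunk, as a hedged sketch (assumes `forward` accepts `image`/`text` keywords and standard FLAVA input shapes, which are not shown in this hunk): when a modality is present, the output now also carries its CLS embedding projected into the shared space, directly usable for a contrastive similarity.

```python
import torch

image = torch.randn(2, 3, 224, 224)      # assumed input shape
text = torch.randint(0, 30000, (2, 77))  # assumed token ids
out = model(image=image, text=text)
if out.projected_image_embeddings is not None and out.projected_text_embeddings is not None:
    # (batch, batch) similarity matrix in the shared projection space
    sim = out.projected_image_embeddings @ out.projected_text_embeddings.t()
```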
@@ -337,26 +369,41 @@ def forward(
             text_masked=text_masked_outputs,
             multimodal=multimodal_outputs,
             multimodal_masked=multimodal_masked_outputs,
+            projected_image_embeddings=projected_image_embeddings,
+            projected_text_embeddings=projected_text_embeddings,
         )
 
     def encode_image(
-        self, image: Tensor, image_patches_mask: Optional[Tensor] = None
+        self,
+        image: Tensor,
+        image_patches_mask: Optional[Tensor] = None,
+        projection: bool = False,
     ) -> Optional[FLAVATransformerOutput]:
         if image_patches_mask is not None:
-            return self.image_encoder(image, image_patches_mask)
+            encoded_image = self.image_encoder(image, image_patches_mask)
         else:
-            return self.image_encoder(image)
+            encoded_image = self.image_encoder(image)
+        if projection:
+            projected_embeddings = self.image_projection(
+                encoded_image.last_hidden_state[:, 0, :]
+            )
+            return encoded_image, projected_embeddings
+        return encoded_image
 
     def encode_text(
-        self,
-        text: Tensor,
-        text_mask: Optional[Tensor] = None,
+        self, text: Tensor, text_mask: Optional[Tensor] = None, projection: bool = False
     ) -> Optional[FLAVATransformerOutput]:
         # TODO(asg): Give proper parameter names when implementing text encoder
-        return self.text_encoder(
+        encoded_text = self.text_encoder(
             input_ids=text,
             attention_mask=text_mask,
         )
+        if projection:
+            projected_embeddings = self.text_projection(
+                encoded_text.last_hidden_state[:, 0, :]
+            )
+            return encoded_text, projected_embeddings
+        return encoded_text
 
     def _encode_data_to_embeddings(
         self,
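The return contract of `encode_image`/`encode_text` is now conditional; a sketch of both shapes (note the annotation still reads `Optional[FLAVATransformerOutput]`, while `projection=True` actually yields a 2-tuple):

```python
encoded = model.encode_image(image)                        # FLAVATransformerOutput, as before
encoded, projected = model.encode_image(image, projection=True)
# `projected` is image_projection applied to the CLS token, i.e.
# encoded.last_hidden_state[:, 0, :] -> (batch, text_and_image_proj_size)
```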
@@ -369,7 +416,6 @@ def _encode_data_to_embeddings(
 
         if data is not None and selected_head_encoder in encoder_options:
             output = encode_callable(data)
-
         return output
 
     def encode_mm(
@@ -404,19 +450,17 @@ def encode_image(
         image: Tensor,
         cls_index: int = 0,
     ):
-        transformer_output = self.model.encode_image(image)
-        embeddings = transformer_output.last_hidden_state
-        return self.loss.contrastive_loss.image_projection(embeddings[:, cls_index, :])
+        _, encoded_image = self.model.encode_image(image, projection=True)
+        return encoded_image
 
     def encode_text(
         self,
         text: Tensor,
         text_mask: Optional[Tensor] = None,
         cls_index: int = 0,
     ):
-        transformer_output = self.model.encode_text(text, text_mask)
-        embeddings = transformer_output.last_hidden_state
-        return self.loss.contrastive_loss.text_projection(embeddings[:, cls_index, :])
+        _, encoded_text = self.model.encode_text(text, text_mask, projection=True)
+        return encoded_text
 
     # TODO: Add options to enable losses selectively
     def forward(
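The pretraining wrapper now delegates projection to the underlying `FLAVAModel` instead of reaching into `self.loss.contrastive_loss`. A hedged sketch of the resulting call pattern (instance name hypothetical; note `cls_index` is no longer consulted, since the delegate hard-codes the CLS position at index 0):

```python
# pretrain_model: a FLAVAForPreTraining instance (hypothetical name)
img_emb = pretrain_model.encode_image(image)           # (batch, text_and_image_proj_size)
txt_emb = pretrain_model.encode_text(text, text_mask)  # same projected space
```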
@@ -458,6 +502,8 @@ def forward(
             itm_labels=itm_labels,
             mim_labels=image_labels,
             mlm_labels=mlm_labels,
+            projected_image_embeddings=flava_output.projected_image_embeddings,
+            projected_text_embeddings=flava_output.projected_text_embeddings,
         )
 
 
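For context, a generic sketch of how projected embeddings typically feed a CLIP-style contrastive objective; this is an illustration only, not the internals of the loss module this diff passes them to:

```python
import torch
import torch.nn.functional as F

def contrastive_sketch(
    img_proj: torch.Tensor, txt_proj: torch.Tensor, temperature: float = 0.07
) -> torch.Tensor:
    # Normalize both projections, build a (batch, batch) logit matrix, and use
    # the matching diagonal as targets; symmetric cross-entropy over rows/columns.
    img = F.normalize(img_proj, dim=-1)
    txt = F.normalize(txt_proj, dim=-1)
    logits = img @ txt.t() / temperature
    targets = torch.arange(logits.size(0), device=logits.device)
    return (F.cross_entropy(logits, targets) + F.cross_entropy(logits.t(), targets)) / 2
```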