 
 FLAVAOutput = namedtuple(
     "FLAVAOutput",
-    ["image", "image_masked", "text", "text_masked", "multimodal", "multimodal_masked"],
-    defaults=(None, None, None, None, None, None),
+    [
+        "image",
+        "image_masked",
+        "text",
+        "text_masked",
+        "multimodal",
+        "multimodal_masked",
+        "projected_image_embeddings",
+        "projected_text_embeddings",
+    ],
+    defaults=(None, None, None, None, None, None, None, None),
 )
 FLAVAOutput.__annotations__ = {
     "image": FLAVATransformerOutput,
@@ -124,6 +133,8 @@ def flava_model(
     multimodal_intermediate_activation: Callable[..., Tensor] = nn.functional.gelu,
     multimodal_attention_probs_dropout_prob: float = 0.0,
     multimodal_layer_norm_eps: float = 1e-12,
+    # projection
+    text_and_image_proj_size: int = 768,
     **kwargs: Any,
 ):
     image_encoder = flava_image_encoder(
@@ -169,12 +180,17 @@ def flava_model(
     image_to_mm_projection = nn.Linear(image_hidden_size, multimodal_hidden_size)
     text_to_mm_projection = nn.Linear(text_hidden_size, multimodal_hidden_size)
 
+    image_projection = nn.Linear(image_hidden_size, text_and_image_proj_size)
+    text_projection = nn.Linear(text_hidden_size, text_and_image_proj_size)
+
     return FLAVAModel(
         image_encoder=image_encoder,
         text_encoder=text_encoder,
         mm_encoder=mm_encoder,
         image_to_mm_projection=image_to_mm_projection,
         text_to_mm_projection=text_to_mm_projection,
+        text_projection=text_projection,
+        image_projection=image_projection,
     )
 
 
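The new heads map each encoder's hidden size into a shared `text_and_image_proj_size`-dimensional space, so image and text embeddings become directly comparable for a contrastive objective. A sketch under assumed sizes (the real values come from the `flava_model()` arguments above):

```python
import torch
from torch import nn

# Assumed sizes for illustration; not taken from the diff.
image_hidden_size, text_hidden_size, proj_size = 1024, 768, 768

image_projection = nn.Linear(image_hidden_size, proj_size)
text_projection = nn.Linear(text_hidden_size, proj_size)

# Both modalities land in the same space and can be compared directly.
img = image_projection(torch.randn(2, image_hidden_size))
txt = text_projection(torch.randn(2, text_hidden_size))
assert img.shape == txt.shape == (2, proj_size)
```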
@@ -246,6 +262,8 @@ def __init__(
         mm_encoder: nn.Module,
         image_to_mm_projection: nn.Module,
         text_to_mm_projection: nn.Module,
+        text_projection: nn.Module,
+        image_projection: nn.Module,
         **kwargs: Any,
     ):
         super().__init__()
@@ -254,6 +272,8 @@ def __init__(
         self.mm_encoder = mm_encoder
         self.image_to_mm_projection = image_to_mm_projection
         self.text_to_mm_projection = text_to_mm_projection
+        self.text_projection = text_projection
+        self.image_projection = image_projection
 
     def forward(
         self,
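Plain attribute assignment is all the registration the new heads need: `nn.Module` records them as submodules, so their weights show up in `parameters()` and the state dict automatically. A tiny sketch:

```python
from torch import nn

class Tiny(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # Registered automatically as a submodule by attribute assignment.
        self.text_projection = nn.Linear(4, 2)

print(sum(p.numel() for p in Tiny().parameters()))  # 4 * 2 weights + 2 biases = 10
```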
@@ -272,18 +292,30 @@ def forward(
         else:
             required_embedding = "text"
 
-        image_outputs = self._encode_data_to_embeddings(
+        image_encoding_out = self._encode_data_to_embeddings(
             image,
             required_embedding,
             ["image", "mm"],
-            self.encode_image,
+            partial(self.encode_image, projection=True),
         )
-        text_outputs = self._encode_data_to_embeddings(
+        if len(image_encoding_out) == 2:
+            image_outputs, projected_image_embeddings = image_encoding_out
+        else:
+            image_outputs = image_encoding_out
+            projected_image_embeddings = None
+
+        text_encoding_out = self._encode_data_to_embeddings(
             text,
             required_embedding,
             ["text", "mm"],
-            self.encode_text,
+            partial(self.encode_text, projection=True),
         )
+        if len(text_encoding_out) == 2:
+            text_outputs, projected_text_embeddings = text_encoding_out
+        else:
+            text_outputs = text_encoding_out
+            projected_text_embeddings = None
+
         image_masked_outputs = self._encode_data_to_embeddings(
             image,
             required_embedding,
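Wrapping the encoders with `functools.partial(..., projection=True)` keeps `_encode_data_to_embeddings` unchanged; the caller then dispatches on return arity, which works because the projected variant returns a 2-tuple while the plain variant returns an output namedtuple whose field count differs from two. A standalone sketch of the pattern, with hypothetical names:

```python
from collections import namedtuple
from functools import partial

# Stand-in for FLAVATransformerOutput; a field count != 2 is what makes
# the len(...) == 2 test below unambiguous.
EncoderOut = namedtuple("EncoderOut", ["last_hidden_state", "pooler_output", "attentions"])

def encode(data, projection=False):
    out = EncoderOut(data, None, None)
    if projection:
        return out, f"projected({data})"  # 2-tuple when projecting
    return out                            # bare namedtuple otherwise

result = partial(encode, projection=True)("img")
if len(result) == 2:
    outputs, projected = result
else:
    outputs, projected = result, None

print(projected)  # projected(img)
```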
@@ -329,26 +361,41 @@ def forward(
             text_masked=text_masked_outputs,
             multimodal=multimodal_outputs,
             multimodal_masked=multimodal_masked_outputs,
+            projected_image_embeddings=projected_image_embeddings,
+            projected_text_embeddings=projected_text_embeddings,
         )
 
     def encode_image(
-        self, image: Tensor, image_patches_mask: Optional[Tensor] = None
+        self,
+        image: Tensor,
+        image_patches_mask: Optional[Tensor] = None,
+        projection: bool = False,
     ) -> Optional[FLAVATransformerOutput]:
         if image_patches_mask is not None:
-            return self.image_encoder(image, image_patches_mask)
+            encoded_image = self.image_encoder(image, image_patches_mask)
         else:
-            return self.image_encoder(image)
+            encoded_image = self.image_encoder(image)
+        if projection:
+            projected_embeddings = self.image_projection(
+                encoded_image.last_hidden_state[:, 0, :]
+            )
+            return encoded_image, projected_embeddings
+        return encoded_image
 
     def encode_text(
-        self,
-        text: Tensor,
-        text_mask: Optional[Tensor] = None,
+        self, text: Tensor, text_mask: Optional[Tensor] = None, projection: bool = False
     ) -> Optional[FLAVATransformerOutput]:
         # TODO(asg): Give proper parameter names when implementing text encoder
-        return self.text_encoder(
+        encoded_text = self.text_encoder(
             input_ids=text,
             attention_mask=text_mask,
         )
+        if projection:
+            projected_embeddings = self.text_projection(
+                encoded_text.last_hidden_state[:, 0, :]
+            )
+            return encoded_text, projected_embeddings
+        return encoded_text
 
     def _encode_data_to_embeddings(
         self,
@@ -361,7 +408,6 @@ def _encode_data_to_embeddings(
 
         if data is not None and selected_head_encoder in encoder_options:
             output = encode_callable(data)
-
         return output
 
     def encode_mm(
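Both `encode_image` and `encode_text` above project only the first position of `last_hidden_state` (`[:, 0, :]`), i.e. the CLS token, rather than the full sequence. A shape sketch under assumed dimensions:

```python
import torch
from torch import nn

batch, seq_len, hidden, proj = 2, 5, 8, 4  # assumed sizes
last_hidden_state = torch.randn(batch, seq_len, hidden)

projection = nn.Linear(hidden, proj)
cls_embedding = last_hidden_state[:, 0, :]  # (batch, hidden): CLS token only
projected = projection(cls_embedding)       # (batch, proj)
assert projected.shape == (batch, proj)
```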
@@ -450,6 +496,8 @@ def forward(
             itm_labels=itm_labels,
             mim_labels=image_labels,
             mlm_labels=mlm_labels,
+            projected_image_embeddings=flava_output.projected_image_embeddings,
+            projected_text_embeddings=flava_output.projected_text_embeddings,
         )
 
 
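Taken together, callers can now request contrastive-ready embeddings directly from the encoders. A hedged usage sketch (method names are from this diff; shapes and inputs are assumptions):

```python
import torch

model = flava_model()  # defaults, including text_and_image_proj_size=768

image = torch.randn(1, 3, 224, 224)      # assumed image shape
text = torch.randint(0, 30522, (1, 77))  # assumed token ids and length

encoded_image, projected_image = model.encode_image(image, projection=True)
encoded_text, projected_text = model.encode_text(text, projection=True)

# Cosine similarity between the projected CLS embeddings.
sim = torch.nn.functional.cosine_similarity(projected_image, projected_text)
```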