@@ -37,8 +37,17 @@
 
 FLAVAOutput = namedtuple(
     "FLAVAOutput",
-    ["image", "image_masked", "text", "text_masked", "multimodal", "multimodal_masked"],
-    defaults=(None, None, None, None, None, None),
+    [
+        "image",
+        "image_masked",
+        "text",
+        "text_masked",
+        "multimodal",
+        "multimodal_masked",
+        "projected_image_embeddings",
+        "projected_text_embeddings",
+    ],
+    defaults=(None, None, None, None, None, None, None, None),
 )
 FLAVAOutput.__annotations__ = {
     "image": FLAVATransformerOutput,
@@ -53,7 +62,7 @@
 FLAVA_FOR_PRETRAINED_MAPPING = {
     # This will no longer load with the updated model, but keeping here just in case
     # "flava_full": "https://huggingface.co/aps/flava_full_pretrained_encoders_torchmm/resolve/main/pytorch_model.bin",
-    "flava_full": "https://download.pytorch.org/models/multimodal/flava/flava_self_attn.bin",
+    "flava_full": "https://download.pytorch.org/models/multimodal/flava/flava_for_pretraining_projection.pt",
 }
 
 
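Every field of the widened FLAVAOutput defaults to None, so call sites that construct it with only the original six fields keep working. A minimal standalone sketch of that behavior (the tensor below is a placeholder, not a real encoder output):

import torch
from collections import namedtuple

FLAVAOutput = namedtuple(
    "FLAVAOutput",
    [
        "image", "image_masked", "text", "text_masked", "multimodal",
        "multimodal_masked", "projected_image_embeddings", "projected_text_embeddings",
    ],
    defaults=(None,) * 8,
)

# Unset fields fall back to None; the two new fields ride along at the end.
out = FLAVAOutput(projected_image_embeddings=torch.randn(2, 768))
assert out.image is None
assert out.projected_image_embeddings.shape == (2, 768)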
@@ -101,6 +110,8 @@ def __init__(
         mm_encoder: nn.Module,
         image_to_mm_projection: nn.Module,
         text_to_mm_projection: nn.Module,
+        text_projection: nn.Module,
+        image_projection: nn.Module,
         **kwargs: Any,
     ) -> None:
         super().__init__()
@@ -109,6 +120,8 @@ def __init__(
         self.mm_encoder = mm_encoder
         self.image_to_mm_projection = image_to_mm_projection
         self.text_to_mm_projection = text_to_mm_projection
+        self.text_projection = text_projection
+        self.image_projection = image_projection
 
     def forward(
         self,
@@ -127,30 +140,50 @@ def forward(
         else:
             required_embedding = "text"
 
-        image_outputs = self._encode_data_to_embeddings(
+        image_encoding_out = self._encode_data_to_embeddings(
             image,
             required_embedding,
             ["image", "mm"],
-            self.encode_image,
+            partial(self.encode_image, projection=True),
         )
-        text_outputs = self._encode_data_to_embeddings(
+        if len(image_encoding_out) == 2:
+            image_outputs, projected_image_embeddings = (
+                image_encoding_out[0],
+                image_encoding_out[1],
+            )
+        else:
+            image_outputs = image_encoding_out
+            projected_image_embeddings = None
+
+        text_encoding_out = self._encode_data_to_embeddings(
             text,
             required_embedding,
             ["text", "mm"],
-            self.encode_text,
+            partial(self.encode_text, projection=True),
         )
+        if len(text_encoding_out) == 2:
+            text_outputs, projected_text_embeddings = (
+                text_encoding_out[0],
+                text_encoding_out[1],
+            )
+        else:
+            text_outputs = text_encoding_out
+            projected_text_embeddings = None
+
         image_masked_outputs = self._encode_data_to_embeddings(
             image,
             required_embedding,
             ["image", "mm"],
             partial(self.encode_image, image_patches_mask=image_patches_mask),
         )
+        assert type(image_masked_outputs) == FLAVATransformerOutput
         text_masked_outputs = self._encode_data_to_embeddings(
             text_masked,
             required_embedding,
             ["text", "mm"],
             self.encode_text,
         )
+        assert type(text_masked_outputs) == FLAVATransformerOutput
 
         multimodal_outputs = FLAVATransformerOutput()
         multimodal_masked_outputs = FLAVATransformerOutput()
@@ -184,39 +217,60 @@ def forward(
             text_masked=text_masked_outputs,
             multimodal=multimodal_outputs,
             multimodal_masked=multimodal_masked_outputs,
+            projected_image_embeddings=projected_image_embeddings,
+            projected_text_embeddings=projected_text_embeddings,
         )
 
     def encode_image(
-        self, image: Tensor, image_patches_mask: Optional[Tensor] = None
-    ) -> Optional[FLAVATransformerOutput]:
+        self,
+        image: Tensor,
+        image_patches_mask: Optional[Tensor] = None,
+        projection: bool = False,
+    ) -> Union[Tuple[FLAVATransformerOutput, Tensor], Optional[FLAVATransformerOutput]]:
         if image_patches_mask is not None:
-            return self.image_encoder(image, image_patches_mask)
+            encoded_image = self.image_encoder(image, image_patches_mask)
         else:
-            return self.image_encoder(image)
+            encoded_image = self.image_encoder(image)
+        if projection:
+            projected_embeddings = self.image_projection(
+                encoded_image.last_hidden_state[:, 0, :]
+            )
+            return encoded_image, projected_embeddings
+        return encoded_image
 
     def encode_text(
-        self,
-        text: Tensor,
-        text_mask: Optional[Tensor] = None,
-    ) -> Optional[FLAVATransformerOutput]:
+        self, text: Tensor, text_mask: Optional[Tensor] = None, projection: bool = False
+    ) -> Union[Tuple[FLAVATransformerOutput, Tensor], Optional[FLAVATransformerOutput]]:
         # TODO(asg): Give proper parameter names when implementing text encoder
-        return self.text_encoder(
+        encoded_text = self.text_encoder(
             input_ids=text,
             attention_mask=text_mask,
         )
+        if projection:
+            projected_embeddings = self.text_projection(
+                encoded_text.last_hidden_state[:, 0, :]
+            )
+            return encoded_text, projected_embeddings
+        return encoded_text
 
     def _encode_data_to_embeddings(
         self,
         data: Optional[Tensor],
         selected_head_encoder: EMBEDDING_OPTIONS,
         encoder_options: List[EMBEDDING_OPTIONS],
-        encode_callable: Callable[..., FLAVATransformerOutput],
-    ) -> Optional[FLAVATransformerOutput]:
-        output = FLAVATransformerOutput()
+        encode_callable: Callable[
+            ...,
+            Union[
+                Tuple[FLAVATransformerOutput, Tensor], Optional[FLAVATransformerOutput]
+            ],
+        ],
+    ) -> Union[Tuple[FLAVATransformerOutput, Tensor], Optional[FLAVATransformerOutput]]:
+        output: Union[
+            Tuple[FLAVATransformerOutput, Tensor], FLAVATransformerOutput
+        ] = FLAVATransformerOutput()
 
         if data is not None and selected_head_encoder in encoder_options:
            output = encode_callable(data)
-
         return output
 
     def encode_mm(
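The projection flag changes the return arity of encode_image/encode_text, which is why forward above branches on len(...). A toy sketch of that contract, with stand-ins for the real encoder and projection head (shapes are illustrative):

import torch
from torch import nn

class ToyOutput:
    def __init__(self, last_hidden_state):
        self.last_hidden_state = last_hidden_state

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.image_projection = nn.Linear(768, 256)  # stand-in projection head

    def encode_image(self, image, projection=False):
        # Stand-in for the real image encoder: batch x tokens x hidden.
        encoded = ToyOutput(torch.randn(image.size(0), 197, 768))
        if projection:
            # Project the [CLS] token, mirroring the diff above.
            return encoded, self.image_projection(encoded.last_hidden_state[:, 0, :])
        return encoded

model = ToyModel()
out = model.encode_image(torch.randn(2, 3, 224, 224))  # single output
out, proj = model.encode_image(torch.randn(2, 3, 224, 224), projection=True)  # tuple
assert proj.shape == (2, 256)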
@@ -253,19 +307,19 @@ def encode_image(
         image: Tensor,
         cls_index: int = 0,
     ) -> Tensor:
-        transformer_output = self.model.encode_image(image)
-        embeddings = transformer_output.last_hidden_state
-        return self.loss.contrastive_loss.image_projection(embeddings[:, cls_index, :])
+        encoded_result = self.model.encode_image(image, projection=True)
+        encoded_image = encoded_result[1]
+        return encoded_image
 
     def encode_text(
         self,
         text: Tensor,
         text_mask: Optional[Tensor] = None,
         cls_index: int = 0,
     ) -> Tensor:
-        transformer_output = self.model.encode_text(text, text_mask)
-        embeddings = transformer_output.last_hidden_state
-        return self.loss.contrastive_loss.text_projection(embeddings[:, cls_index, :])
+        encoded_result = self.model.encode_text(text, text_mask, projection=True)
+        encoded_text = encoded_result[1]
+        return encoded_text
 
     # TODO: Add options to enable losses selectively
     def forward(
@@ -307,6 +361,8 @@ def forward(
             itm_labels=itm_labels,
             mim_labels=image_labels,
             mlm_labels=mlm_labels,
+            projected_image_embeddings=flava_output.projected_image_embeddings,
+            projected_text_embeddings=flava_output.projected_text_embeddings,
         )
 
 
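The projected embeddings handed to the loss are what a CLIP-style image-text contrastive term consumes. An illustrative sketch of such a term, not the actual torchmultimodal loss implementation:

import torch
import torch.nn.functional as F

def contrastive_sketch(projected_image_embeddings, projected_text_embeddings, temperature=0.07):
    img = F.normalize(projected_image_embeddings, dim=-1)
    txt = F.normalize(projected_text_embeddings, dim=-1)
    logits = img @ txt.t() / temperature  # (batch, batch) cosine similarities
    targets = torch.arange(img.size(0))   # matched pairs sit on the diagonal
    return (F.cross_entropy(logits, targets) + F.cross_entropy(logits.t(), targets)) / 2

loss = contrastive_sketch(torch.randn(4, 768), torch.randn(4, 768))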
@@ -394,6 +450,8 @@ def flava_model(
     multimodal_intermediate_activation: Callable[..., Tensor] = nn.functional.gelu,
     multimodal_attention_probs_dropout_prob: float = 0.0,
     multimodal_layer_norm_eps: float = 1e-12,
+    # projection
+    text_and_image_proj_size: int = 768,
     **kwargs: Any,
 ) -> FLAVAModel:
     image_encoder = flava_image_encoder(
@@ -439,12 +497,17 @@ def flava_model(
     image_to_mm_projection = nn.Linear(image_hidden_size, multimodal_hidden_size)
     text_to_mm_projection = nn.Linear(text_hidden_size, multimodal_hidden_size)
 
+    image_projection = nn.Linear(image_hidden_size, text_and_image_proj_size)
+    text_projection = nn.Linear(text_hidden_size, text_and_image_proj_size)
+
     return FLAVAModel(
         image_encoder=image_encoder,
         text_encoder=text_encoder,
         mm_encoder=mm_encoder,
         image_to_mm_projection=image_to_mm_projection,
         text_to_mm_projection=text_to_mm_projection,
+        text_projection=text_projection,
+        image_projection=image_projection,
     )
 
 
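Hedged usage sketch of the new builder knob; the import path and the 768-dim encoder defaults are assumptions about the surrounding library:

from torchmultimodal.models.flava import flava_model  # import path assumed

# text_and_image_proj_size sets the width of both new projection heads.
model = flava_model(text_and_image_proj_size=512)
assert model.image_projection.out_features == 512
assert model.text_projection.out_features == 512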