# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


from typing import NamedTuple, Optional

import torch

from torch import nn, Tensor
from torch.nn import functional as F
from torchmultimodal.modules.layers.transformer import TransformerOutput


class Blip2Output(NamedTuple):
    """
    BLIP2 model output for loss computation.

    image_embeddings (Tensor): Normalized image embeddings returned by the visual encoder,
        with shape [bsz, seq_len, embed_dim].
    image_features (Tensor): Image features after the Q-Former and projection (used for stage-1 training),
        with shape [bsz, num_query_tokens, embed_dim].
    image_qformer_output (Tensor): Last hidden state of the Q-Former for the given image input.
    text_features (Optional[Tensor]): Text features after the Q-Former and projection, if text input is provided,
        with shape [bsz, embed_dim].
    prediction_scores (Optional[Tensor]): Scores computed for next-word prediction,
        with shape [bsz, seq_len, vocab_size].
    """

    image_embeddings: Tensor
    image_features: Tensor
    image_qformer_output: Tensor
    text_features: Optional[Tensor] = None
    prediction_scores: Optional[Tensor] = None
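

# The sketch below is not part of the original module: it illustrates, under the
# assumptions stated in the comments, how the fields of Blip2Output could feed a
# stage-1 image-text contrastive (ITC) similarity as described in the BLIP-2 paper.
# The helper name and the temperature value are illustrative placeholders, not a
# library API.
def _example_itc_similarity(output: Blip2Output, temperature: float = 0.07) -> Tensor:
    """Example only: image-to-text similarity matrix of shape [bsz_image, bsz_text]."""
    assert output.text_features is not None, "text input is required for ITC"
    # image_features: [bsz_i, num_query_tokens, embed_dim]
    # text_features:  [bsz_t, embed_dim]
    # Compare every query token of every image against every text feature.
    sim_per_query = torch.einsum(
        "iqd,td->iqt", output.image_features, output.text_features
    )
    # Keep the best-matching query token for each image-text pair (max over queries),
    # following the query selection used for ITC in the BLIP-2 paper.
    sim_i2t, _ = sim_per_query.max(dim=1)
    return sim_i2t / temperature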


class BLIP2(nn.Module):
    """
    BLIP-2 (https://arxiv.org/pdf/2301.12597.pdf) provides a pre-training strategy to bootstrap vision-language
    pre-training from frozen image encoders and frozen large language models (LLMs). BLIP-2 bridges the modality
    gap and facilitates cross-modal alignment via the Querying Transformer (Q-Former), a lightweight transformer
    that uses a set of learnable query vectors to extract visual features from the frozen image encoder.

    Args:
        qformer (nn.Module): Querying Transformer (Q-Former)
        vision_encoder (nn.Module): Frozen image encoder
        dim_q (int): Dimension of the query tensor; this value should match dim_q in the Q-Former.
        image_encoder_embedding_dim (int): Embedding dimension of the image encoder;
            this value should match dim_kv in the Q-Former.
        freeze_vision_encoder (bool): Whether to freeze the vision encoder. Default is True.
        cross_attention_freq (int): Frequency of cross-attention blocks in the Q-Former. Default is 2.
        embedding_dim (int): Embedding dimension of the image and text projections. Default is 256.
        num_query_token (int): Number of query tokens in the Q-Former. Default is 32.
        init_query_tokens (bool): Whether to initialize the query token parameters. Default is True.
        decoder_bos_token_id (Optional[int]): bos_token_id used by the decoder. Default is None.
    """

    def __init__(
        self,
        qformer: nn.Module,
        vision_encoder: nn.Module,
        dim_q: int,
        image_encoder_embedding_dim: int,
        freeze_vision_encoder: bool = True,
        cross_attention_freq: int = 2,
        embedding_dim: int = 256,
        num_query_token: int = 32,
        init_query_tokens: bool = True,
        decoder_bos_token_id: Optional[int] = None,
    ):
        super().__init__()
        self.vision_encoder = vision_encoder
        if freeze_vision_encoder:
            for param in self.vision_encoder.parameters():
                param.requires_grad = False
            self.vision_encoder = self.vision_encoder.eval()

        self.qformer = qformer
        self.decoder_bos_token_id = decoder_bos_token_id
        self.dim_q = dim_q
        self.query_tokens = nn.Parameter(torch.zeros(1, num_query_token, self.dim_q))
        if init_query_tokens:
            self.query_tokens.data.normal_(mean=0.0, std=0.02)

        self.vision_proj = nn.Linear(self.dim_q, embedding_dim)
        self.text_proj = nn.Linear(self.dim_q, embedding_dim)
        self.ln_vision = nn.LayerNorm(image_encoder_embedding_dim)
    def forward(
        self,
        image: Tensor,
        input_ids: Optional[Tensor] = None,
        attention_mask: Optional[Tensor] = None,
    ) -> Blip2Output:
        """
        Args:
            image (Tensor): Image input tensor with shape [bsz, C, H, W]
            input_ids (Optional[Tensor]): Text input tensor with shape [bsz, seq_len]
            attention_mask (Optional[Tensor]): Attention mask tensor with shape [bsz, seq_len]

        Returns:
            Blip2Output: the BLIP-2 model output.
        """
        vision_encoder_output = self.vision_encoder(image)
        if isinstance(vision_encoder_output, TransformerOutput):
            vision_encoder_output = vision_encoder_output.last_hidden_state
        assert vision_encoder_output is not None
        image_embeds = self.ln_vision(vision_encoder_output)
        # query_tokens: [bsz, num_query_token, dim_q]
        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_output = self.qformer.model(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            use_cache=True,
        )

        # image_feats: [bsz, num_query_token, embedding_dim]
        image_feats = F.normalize(self.vision_proj(query_output[0]), dim=-1)

        text_feats: Optional[Tensor] = None
        prediction_scores: Optional[Tensor] = None
        if input_ids is not None:
            text_output = self.qformer.model(
                input_ids,
                attention_mask=attention_mask,
                use_cache=False,
            )
            text_feats = F.normalize(self.text_proj(text_output[0][:, 0, :]), dim=-1)

            decoder_input_ids = input_ids.clone()
            if self.decoder_bos_token_id is not None:
                # pyre-ignore
                decoder_input_ids[:, 0] = self.decoder_bos_token_id

            query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(
                input_ids.device
            )
            if attention_mask is not None:
                attention_mask = torch.cat([query_atts, attention_mask], dim=1)

            # use_cache=False here since past_key_values were already cached in the
            # image Q-Former pass above and are passed in directly.
            prediction_scores = self.qformer(
                input_ids=decoder_input_ids,
                attention_mask=attention_mask,
                past_key_values=query_output[1],
                use_cache=False,
            )

        return Blip2Output(
            image_embeddings=image_embeds,
            image_features=image_feats,
            image_qformer_output=query_output[0],
            text_features=text_feats,
            prediction_scores=prediction_scores,
        )
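

# Example usage (a minimal sketch): the builders below are hypothetical placeholders,
# not torchmultimodal APIs; substitute a real Q-Former (exposing a `.model` attribute,
# as used in `forward` above) and a compatible vision encoder. The dimensions shown
# (dim_q=768, image_encoder_embedding_dim=1408) are illustrative example values only.
#
#     qformer = build_qformer(dim_q=768, dim_kv=1408)        # hypothetical builder
#     vision_encoder = build_vision_encoder()                 # hypothetical builder
#     model = BLIP2(
#         qformer=qformer,
#         vision_encoder=vision_encoder,
#         dim_q=768,
#         image_encoder_embedding_dim=1408,
#     )
#     images = torch.randn(2, 3, 224, 224)
#     input_ids = torch.randint(0, 30522, (2, 16))
#     attention_mask = torch.ones_like(input_ids)
#     out = model(images, input_ids, attention_mask)
#     # out.image_features: [2, 32, 256]; out.text_features: [2, 256]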