@@ -45,3 +45,103 @@ To pretrain a ViLT model from scratch on the COCO dataset,
```
mmf_run config=projects/vilt/configs/masked_coco/pretrain.yaml run_type=train_val dataset=masked_coco model=vilt
```
+
+ ## Using the ViLT model from code
+ Here is an example of running the ViLT model from code to do visual question answering (VQA) on a raw image and a text question.
+ The forward pass takes ~15 ms, which is very fast compared to UNITER's ~600 ms (a rough timing sketch is included at the end of this section).
+
+ ```python
+ from argparse import Namespace
+
+ import torch
+ from mmf.common.sample import SampleList
+ from mmf.datasets.processors.bert_processors import VILTTextTokenizer
+ from mmf.datasets.processors.image_processors import VILTImageProcessor
+ from mmf.utils.build import build_model
+ from mmf.utils.configuration import Configuration, load_yaml
+ from mmf.utils.general import get_current_device
+ from mmf.utils.text import VocabDict
+ from omegaconf import OmegaConf
+ from PIL import Image
+ ```
+
+ Build the model config and instantiate the ViLT model.
+ ```python
+ # make the model config for ViLT on VQA2
+ model_name = "vilt"
+ config_args = Namespace(
+     config_override=None,
+     opts=["model=vilt", "dataset=vqa2", "config=configs/defaults.yaml"],
+ )
+ default_config = Configuration(config_args).get_config()
+ model_vqa_config = load_yaml(
+     "/private/home/your/path/to/mmf/projects/vilt/configs/vqa2/defaults.yaml"
+ )
+ config = OmegaConf.merge(default_config, model_vqa_config)
+ OmegaConf.resolve(config)
+ model_config = config.model_config[model_name]
+ model_config.model = model_name
+ vilt_model = build_model(model_config)
+ ```
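+
+ To sanity-check the merged config, one quick option (a minimal sketch, nothing ViLT-specific) is to dump it back to YAML with OmegaConf:
+ ```python
+ # optional: print the resolved model config to verify the VQA2 overrides were merged in
+ print(OmegaConf.to_yaml(model_config))
+ ```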
+
+ Next, load the model weights. `model_checkpoint_path` is the checkpoint downloaded from the model zoo path `vilt.vqa`,
+ currently hosted at `s3://dl.fbaipublicfiles.com/mmf/data/models/vilt/vilt.finetuned.vqa2.tar.gz`.
+ ```python
+ # load the checkpoint weights and put the model in eval mode
+ model_checkpoint_path = "./vilt_vqa2.pth"
+ state_dict = torch.load(model_checkpoint_path)
+ vilt_model.load_state_dict(state_dict, strict=False)
+ vilt_model.eval()
+ vilt_model = vilt_model.to(get_current_device())
+ ```
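+
+ If you haven't downloaded the checkpoint yet, below is a minimal sketch for fetching and unpacking it. It assumes the S3 path above is also served over HTTPS at the same location and that the archive contains the `.pth` file used as `model_checkpoint_path`; adjust the names to whatever the archive actually contains.
+ ```python
+ # download and unpack the checkpoint archive -- the HTTPS mirror and archive layout are assumptions
+ import tarfile
+ import urllib.request
+
+ url = "https://dl.fbaipublicfiles.com/mmf/data/models/vilt/vilt.finetuned.vqa2.tar.gz"
+ archive_path, _ = urllib.request.urlretrieve(url, "vilt.finetuned.vqa2.tar.gz")
+ with tarfile.open(archive_path, "r:gz") as tar:
+     tar.extractall(".")  # inspect tar.getnames() to find the checkpoint file
+ ```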
+
+ Prepare the input image and text.
+ This example uses an image of a man with a hat kissing his daughter.
+ The text is the question posed to the ViLT model for visual question answering.
+ ```python
+ # get the image input
+ image_processor = VILTImageProcessor({"size": [384, 384]})
+ image_path = "./kissing_image.jpg"
+ raw_img = Image.open(image_path).convert("RGB")
+ image = image_processor(raw_img)
+
+ # get the text input
+ text_tokenizer = VILTTextTokenizer({})
+ question = "What is on his head?"
+ processed_text_dict = text_tokenizer({"text": question})
+ ```
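+
+ If you want to check what the processors produced before batching, a quick inspection (just a sketch; the exact keys and shapes depend on the processor configs) is:
+ ```python
+ # inspect the processed inputs; image is expected to be a 3-channel tensor at the
+ # configured 384x384 size, and the tokenizer output a dict of tensors
+ print(image.shape)
+ print(list(processed_text_dict.keys()))
+ ```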
+
+ Wrap everything in a `SampleList`, as expected by the ViLT `BaseModel`.
+ ```python
+ # make batch inputs with a batch dimension of 1
+ sample_dict = {**processed_text_dict, "image": image}
+ sample_dict = {
+     k: v.unsqueeze(0) for k, v in sample_dict.items() if isinstance(v, torch.Tensor)
+ }
+ # targets over the 3129 VQA2 answer classes (a placeholder label for this inference example)
+ sample_dict["targets"] = torch.zeros((1, 3129))
+ sample_dict["targets"][0, 1358] = 1
+ sample_dict["dataset_name"] = "vqa2"
+ sample_dict["dataset_type"] = "test"
+ sample_list = SampleList(sample_dict).to(get_current_device())
+ ```
+
+ Load the VQA answer-id -> answer-string map so we can understand what the model says.
+ The file is currently hosted at `s3://dl.fbaipublicfiles.com/mmf/data/datasets/vqa2/defaults/extras/vocabs/answers_vqa.txt`.
+ ```python
+ # load the vqa2 answer-id -> answer-string vocabulary
+ vocab_file_path = "/private/home/path/to/answers_vqa.txt"
+ answer_vocab = VocabDict(vocab_file_path)
+ ```
+
+ And here's the part you've been waiting for!
+ ```python
+ # do prediction
+ with torch.no_grad():
+     vqa_logits = vilt_model(sample_list)["scores"]
+
+ answer_id = vqa_logits.argmax().item()
+ answer = answer_vocab.idx2word(answer_id)
+ print(chr(27) + "[2J")  # clear the terminal
+ print(f"{question}: {answer}")
+ ```
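+
+ To see how confident the model is, a small follow-up (reusing the objects defined above) looks at the top-5 answers instead of only the argmax:
+ ```python
+ # print the five highest-scoring answers with their softmax probabilities
+ probs = torch.softmax(vqa_logits[0], dim=-1)
+ top_probs, top_ids = torch.topk(probs, k=5)
+ for prob, idx in zip(top_probs.tolist(), top_ids.tolist()):
+     print(f"{answer_vocab.idx2word(idx)}: {prob:.3f}")
+ ```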
+
+ Expected output: `What is on his head?: hat`
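+
+ To reproduce the rough forward-pass timing quoted at the top of this section, here is a simple latency sketch (numbers will vary with hardware, and this is not a rigorous benchmark):
+ ```python
+ # crude latency measurement for the ViLT forward pass; reuses sample_list from above
+ import time
+
+ with torch.no_grad():
+     for _ in range(10):  # warmup
+         vilt_model(sample_list)
+     if torch.cuda.is_available():
+         torch.cuda.synchronize()
+     start = time.perf_counter()
+     for _ in range(100):
+         vilt_model(sample_list)
+     if torch.cuda.is_available():
+         torch.cuda.synchronize()
+     elapsed = time.perf_counter() - start
+
+ print(f"average forward pass: {elapsed / 100 * 1000:.1f} ms")
+ ```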