
Commit 3205392

[docs] Add ViLT model code usage example
Add an example of running ViLT from code, outside of the MMF CLI, at the end of the ViLT tutorial. The example runs ViLT VQA on a raw image and text.

ghstack-source-id: 4f7804e
Pull Request resolved: #1179
1 parent d412c25


website/docs/projects/vilt.md

Lines changed: 100 additions & 0 deletions
@@ -45,3 +45,103 @@ To pretrain a ViLT model from scratch on the COCO dataset,
```
mmf_run config=projects/vilt/configs/masked_coco/pretrain.yaml run_type=train_val dataset=masked_coco model=vilt
```

## Using the ViLT model from code
Here is an example of running the ViLT model from code to do visual question answering (VQA) on a raw image and text.
The forward pass takes ~15 ms, which is very fast compared to UNITER's ~600 ms.

```python
from argparse import Namespace

import torch
from mmf.common.sample import SampleList
from mmf.datasets.processors.bert_processors import VILTTextTokenizer
from mmf.datasets.processors.image_processors import VILTImageProcessor
from mmf.utils.build import build_model
from mmf.utils.configuration import Configuration, load_yaml
from mmf.utils.general import get_current_device
from mmf.utils.text import VocabDict
from omegaconf import OmegaConf
from PIL import Image
```

Build the model config and instantiate the ViLT model.
```python
# make the model config for ViLT on VQA2
model_name = "vilt"
config_args = Namespace(
    config_override=None,
    opts=["model=vilt", "dataset=vqa2", "config=configs/defaults.yaml"],
)
# build MMF's default config, then merge in the ViLT VQA2 project config
default_config = Configuration(config_args).get_config()
model_vqa_config = load_yaml(
    "/private/home/your/path/to/mmf/projects/vilt/configs/vqa2/defaults.yaml"
)
config = OmegaConf.merge(default_config, model_vqa_config)
OmegaConf.resolve(config)
# pull out the ViLT model config and instantiate the model
model_config = config.model_config[model_name]
model_config.model = model_name
vilt_model = build_model(model_config)
```

Load the model weights. `model_checkpoint_path` is the checkpoint downloaded from the model zoo path `vilt.vqa`,
currently hosted at `s3://dl.fbaipublicfiles.com/mmf/data/models/vilt/vilt.finetuned.vqa2.tar.gz`.
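If you don't have the checkpoint locally, here is a minimal sketch for fetching it. It assumes the `s3://` path above is also served over HTTPS at the same location and that the archive unpacks to a single `.pth` checkpoint (the exact filename inside the archive may differ).
```python
# Hedged sketch: assumes the s3:// path above is also reachable over HTTPS
# and that the archive contains the .pth checkpoint (exact filename may differ).
import tarfile
import urllib.request

archive_url = (
    "https://dl.fbaipublicfiles.com/mmf/data/models/vilt/vilt.finetuned.vqa2.tar.gz"
)
archive_path = "./vilt.finetuned.vqa2.tar.gz"
urllib.request.urlretrieve(archive_url, archive_path)

with tarfile.open(archive_path) as tar:
    tar.extractall("./")  # rename the extracted checkpoint to ./vilt_vqa2.pth if needed
```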
```python
# load the finetuned weights and move the model to the current device
model_checkpoint_path = "./vilt_vqa2.pth"
state_dict = torch.load(model_checkpoint_path)
vilt_model.load_state_dict(state_dict, strict=False)
vilt_model.eval()
vilt_model = vilt_model.to(get_current_device())
```

Prepare the input image and text.
This example uses an image of a man with a hat kissing his daughter.
The text is the question posed to the ViLT model for visual question answering.
```python
# get image input
image_processor = VILTImageProcessor({"size": [384, 384]})
image_path = "./kissing_image.jpg"
raw_img = Image.open(image_path).convert("RGB")
image = image_processor(raw_img)

# get text input
text_tokenizer = VILTTextTokenizer({})
question = "What is on his head?"
processed_text_dict = text_tokenizer({"text": question})
```

Wrap everything up in a `SampleList`, as expected by the ViLT `BaseModel`.
```python
# make batch inputs
sample_dict = {**processed_text_dict, "image": image}
sample_dict = {
    k: v.unsqueeze(0) for k, v in sample_dict.items() if isinstance(v, torch.Tensor)
}
# one-hot targets over the 3129 VQA2 answer classes
sample_dict["targets"] = torch.zeros((1, 3129))
sample_dict["targets"][0, 1358] = 1
sample_dict["dataset_name"] = "vqa2"
sample_dict["dataset_type"] = "test"
sample_list = SampleList(sample_dict).to(get_current_device())
```

Load the VQA answer id -> answer string map to understand what the model says!
The vocab file is currently available at `s3://dl.fbaipublicfiles.com/mmf/data/datasets/vqa2/defaults/extras/vocabs/answers_vqa.txt`.
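As with the checkpoint above, here is a small sketch for fetching the vocab file, assuming the `s3://` path is also served over HTTPS; point `vocab_file_path` below at wherever you save it.
```python
# Hedged sketch: assumes the s3:// path above is also reachable over HTTPS.
import urllib.request

urllib.request.urlretrieve(
    "https://dl.fbaipublicfiles.com/mmf/data/datasets/vqa2/defaults/extras/vocabs/answers_vqa.txt",
    "./answers_vqa.txt",
)
```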
```python
# load the vqa2 id -> answer mapping
vocab_file_path = "/private/home/path/to/answers_vqa.txt"
answer_vocab = VocabDict(vocab_file_path)
```

And here's the part you've been waiting for!
```python
# do prediction
with torch.no_grad():
    vqa_logits = vilt_model(sample_list)["scores"]
    answer_id = vqa_logits.argmax().item()
    answer = answer_vocab.idx2word(answer_id)
    print(chr(27) + "[2J")  # clear the terminal
    print(f"{question}: {answer}")
```

Expected output: `What is on his head?: hat`
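
To check the rough forward-pass timing mentioned at the top of this section, here is a minimal sketch; it reuses `vilt_model` and `sample_list` from above, and the exact numbers will of course depend on your hardware.
```python
# Rough timing sketch; treat ~15 ms as a ballpark, not a guarantee.
import time

with torch.no_grad():
    for _ in range(5):  # warm-up iterations
        vilt_model(sample_list)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    vilt_model(sample_list)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    elapsed_ms = (time.perf_counter() - start) * 1000
    print(f"forward pass: {elapsed_ms:.1f} ms")
```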
