Skip to content

Commit a583cb5

Browse files
committed
last tweak to vat
1 parent 2587101 commit a583cb5

File tree

2 files changed

+11
-3
lines changed

2 files changed

+11
-3
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "vit-pytorch"
7-
version = "1.14.2"
7+
version = "1.14.4"
88
description = "Vision Transformer (ViT) - Pytorch"
99
readme = { file = "README.md", content-type = "text/markdown" }
1010
license = { file = "LICENSE" }

vit_pytorch/vat.py

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -329,6 +329,7 @@ def forward(
329329
extra = None, # (b d) - batch, dim extra
330330
tasks = None, # (b)
331331
actions = None, # (b k d) - batch, action chunk length, action dimension
332+
return_hiddens = False
332333
):
333334
batch = video_or_image.shape[0]
334335
return_loss = exists(actions)
@@ -408,6 +409,8 @@ def forward(
408409

409410
# cross attention
410411

412+
hiddens = [action_tokens]
413+
411414
for (maybe_film, maybe_self_attn, cross_attn, ff), layer_context in zip(self.layers, context):
412415

413416
if exists(tasks):
@@ -420,6 +423,8 @@ def forward(
420423

421424
action_tokens = ff(action_tokens) + action_tokens
422425

426+
hiddens.append(action_tokens)
427+
423428
# maybe unpack extra
424429

425430
if has_extra:
@@ -432,7 +437,10 @@ def forward(
432437
pred_action = self.to_pred_action(action_tokens)
433438

434439
if not return_loss:
435-
return pred_action
440+
if not return_hiddens:
441+
return pred_action
442+
443+
return pred_action, stack(hiddens)
436444

437445
assert pred_action.shape[1] == actions.shape[1]
438446

@@ -484,6 +492,6 @@ def forward(
484492

485493
# after much training
486494

487-
pred_actions = vat(images, tasks = tasks, extra = extra)
495+
pred_actions, hiddens = vat(images, tasks = tasks, extra = extra, return_hiddens = True)
488496

489497
assert pred_actions.shape == (2, 7, 20)

0 commit comments

Comments
 (0)