@@ -329,6 +329,7 @@ def forward(
329329 extra = None , # (b d) - batch, dim extra
330330 tasks = None , # (b)
331331 actions = None , # (b k d) - batch, action chunk length, action dimension
332+ return_hiddens = False
332333 ):
333334 batch = video_or_image .shape [0 ]
334335 return_loss = exists (actions )
@@ -408,6 +409,8 @@ def forward(
408409
409410 # cross attention
410411
412+ hiddens = [action_tokens ]
413+
411414 for (maybe_film , maybe_self_attn , cross_attn , ff ), layer_context in zip (self .layers , context ):
412415
413416 if exists (tasks ):
@@ -420,6 +423,8 @@ def forward(
420423
421424 action_tokens = ff (action_tokens ) + action_tokens
422425
426+ hiddens .append (action_tokens )
427+
423428 # maybe unpack extra
424429
425430 if has_extra :
@@ -432,7 +437,10 @@ def forward(
432437 pred_action = self .to_pred_action (action_tokens )
433438
434439 if not return_loss :
435- return pred_action
440+ if not return_hiddens :
441+ return pred_action
442+
443+ return pred_action , stack (hiddens )
436444
437445 assert pred_action .shape [1 ] == actions .shape [1 ]
438446
@@ -484,6 +492,6 @@ def forward(
484492
485493 # after much training
486494
487- pred_actions = vat (images , tasks = tasks , extra = extra )
495+ pred_actions , hiddens = vat (images , tasks = tasks , extra = extra , return_hiddens = True )
488496
489497 assert pred_actions .shape == (2 , 7 , 20 )
0 commit comments