
Conversation


@xiaopoc xiaopoc commented Aug 7, 2025

Description of the change:
Currently, embed_prefix() in pi0.py processes each image in obs.images sequentially through the vision transformer, which results in redundant kernel launches and increased runtime.

Motivation:
Batching all images together and encoding them in a single forward pass reduces kernel launch overhead and enables better kernel fusion, which lowers inference latency (~5 ms speedup on an RTX 4090).

Proposed Change:

  • Stack images across cameras along a new axis, e.g., stacked_images = jnp.stack(list(obs.images.values()), axis=1)
  • Flatten and encode all images in one forward pass via self.PaliGemma.img(...)
  • Update the token aggregation and attention masks accordingly (see the sketch after this list)
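
A minimal sketch of what the batched path inside embed_prefix() could look like. It assumes obs.images maps camera names to [B, H, W, C] arrays, that self.PaliGemma.img accepts a flattened batch and returns (tokens, _) with tokens of shape [B * num_cams, num_patches, emb], and that obs.image_masks holds a per-camera [B] validity mask; the exact signatures, kwargs, and field names beyond those quoted above are assumptions, not the actual openpi API.

```python
import jax.numpy as jnp

# Fix the camera order so tokens and masks stay aligned.
image_names = list(obs.images.keys())
num_cams = len(image_names)

# Stack along a new camera axis: [B, num_cams, H, W, C].
stacked_images = jnp.stack([obs.images[name] for name in image_names], axis=1)

# Flatten batch and camera axes so every image goes through the ViT in one call.
b, n, h, w, c = stacked_images.shape
flat_images = stacked_images.reshape(b * n, h, w, c)

# Single forward pass through the vision encoder (the train kwarg is an assumption).
image_tokens, _ = self.PaliGemma.img(flat_images, train=False)

# [B * num_cams, num_patches, emb] -> [B, num_cams * num_patches, emb],
# i.e. camera 0's patch tokens first, then camera 1's, and so on.
num_patches, emb = image_tokens.shape[-2], image_tokens.shape[-1]
image_tokens = image_tokens.reshape(b, n * num_patches, emb)

# Per-token attention mask: repeat each camera's validity flag over its patches
# (assumes obs.image_masks maps camera name -> [B] boolean array).
image_masks = jnp.stack([obs.image_masks[name] for name in image_names], axis=1)  # [B, num_cams]
input_mask = jnp.repeat(image_masks, num_patches, axis=1)                          # [B, num_cams * num_patches]
```

The resulting camera-major token layout must match whatever the rest of embed_prefix() assumes when concatenating the language tokens and building the prefix attention mask, so downstream mask construction and positional handling stay consistent with the previous per-camera loop.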
