From 5f3742364e494365aa65d7b3838f6038c4ee5a05 Mon Sep 17 00:00:00 2001 From: JAX Toolbox Date: Thu, 7 Aug 2025 11:22:21 -0700 Subject: [PATCH] Reduce the latency by increasing the batch size for the vision transformer --- src/openpi/models/pi0.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/src/openpi/models/pi0.py b/src/openpi/models/pi0.py index 1531108e6..d358021a1 100644 --- a/src/openpi/models/pi0.py +++ b/src/openpi/models/pi0.py @@ -179,19 +179,25 @@ def embed_prefix( ar_mask = [] tokens = [] # embed images - for name in obs.images: - image_tokens, _ = self.PaliGemma.img(obs.images[name], train=False) + # batched: + image_names = list(obs.images.keys()) + stacked_images = jnp.stack(list(obs.images.values()), axis=1) + batch_size, num_cams = stacked_images.shape[:2] + reshaped_images = stacked_images.reshape(-1, *stacked_images.shape[2:]) + + all_image_tokens, _ = self.PaliGemma.img(reshaped_images, train=False) + all_image_tokens = all_image_tokens.reshape(batch_size, num_cams, all_image_tokens.shape[1], -1) + + for i, name in enumerate(image_names): + image_tokens = all_image_tokens[:, i] tokens.append(image_tokens) input_mask.append( - einops.repeat( - obs.image_masks[name], - "b -> b s", - s=image_tokens.shape[1], - ) + einops.repeat(obs.image_masks[name], "b -> b s", s=image_tokens.shape[1]) ) - # image tokens attend to each other - ar_mask += [False] * image_tokens.shape[1] + + # Set attention masks + ar_mask = [False] * (image_tokens.shape[1] * len(image_names)) # add language (aka tokenized inputs) if obs.tokenized_prompt is not None: @@ -323,3 +329,4 @@ def cond(carry): x_0, _ = jax.lax.while_loop(cond, step, (noise, 1.0)) return x_0 +