16 changes: 16 additions & 0 deletions packages/openpi-client/src/openpi_client/action_chunk_broker.py
@@ -17,6 +17,12 @@ class ActionChunkBroker(_base_policy.BasePolicy):
"""

def __init__(self, policy: _base_policy.BasePolicy, action_horizon: int):
"""Initialize the ActionChunkBroker with a policy and action horizon.

Args:
policy: The underlying policy to wrap for chunked action delivery.
action_horizon: The number of action steps in each chunk from the policy.
"""
self._policy = policy
self._action_horizon = action_horizon
self._cur_step: int = 0
@@ -25,11 +31,20 @@ def __init__(self, policy: _base_policy.BasePolicy, action_horizon: int):

@override
def infer(self, obs: Dict) -> Dict: # noqa: UP006
"""Return the next action from the current chunk or fetch a new chunk if needed.

Args:
obs: Observation dictionary to pass to the underlying policy when fetching new chunks.

Returns:
Dictionary containing the action for the current step, extracted from the chunk.
"""
if self._last_results is None:
self._last_results = self._policy.infer(obs)
self._cur_step = 0

def slicer(x):
"""Extract the current step from array data or return non-array data unchanged."""
if isinstance(x, np.ndarray):
return x[self._cur_step, ...]
else:
@@ -45,6 +60,7 @@ def slicer(x):

@override
def reset(self) -> None:
"""Reset the broker state and the underlying policy."""
self._policy.reset()
self._last_results = None
self._cur_step = 0
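
For context, a minimal sketch of how this broker is typically driven (illustrative only, not part of the diff; DummyPolicy and the "actions" key are assumptions, but the broker slices any np.ndarray in the result along its first axis):

import numpy as np
from openpi_client import base_policy as _base_policy
from openpi_client.action_chunk_broker import ActionChunkBroker

class DummyPolicy(_base_policy.BasePolicy):
    def infer(self, obs):
        # One chunk of 4 actions, each 2-dimensional.
        return {"actions": np.arange(8, dtype=np.float32).reshape(4, 2)}

    def reset(self):
        pass

broker = ActionChunkBroker(DummyPolicy(), action_horizon=4)
for _ in range(8):
    # Each call returns one row of the cached chunk; after 4 calls the broker
    # transparently queries DummyPolicy.infer() again for a fresh chunk.
    step = broker.infer({"state": np.zeros(2, dtype=np.float32)})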
@@ -8,11 +8,25 @@ class PolicyAgent(_agent.Agent):
"""An agent that uses a policy to determine actions."""

def __init__(self, policy: _base_policy.BasePolicy) -> None:
"""Initialize the policy agent with a given policy.

Args:
policy: The policy instance used to infer actions from observations.
"""
self._policy = policy

@override
def get_action(self, observation: dict) -> dict:
"""Get an action by inferring from the observation using the policy.

Args:
observation: The current observation state as a dictionary.

Returns:
The action determined by the policy as a dictionary.
"""
return self._policy.infer(observation)

def reset(self) -> None:
"""Reset the policy to its initial state."""
self._policy.reset()
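
PolicyAgent is a thin adapter between the runtime's Agent interface and any BasePolicy. A short sketch, reusing the illustrative DummyPolicy above (the module path is assumed from the openpi-client package layout):

from openpi_client.runtime.agents.policy_agent import PolicyAgent

agent = PolicyAgent(policy=DummyPolicy())
action = agent.get_action({"state": [0.0, 0.0]})  # forwards to policy.infer()
agent.reset()                                     # forwards to policy.reset()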
@@ -16,6 +16,13 @@ class WebsocketClientPolicy(_base_policy.BasePolicy):
"""

def __init__(self, host: str = "0.0.0.0", port: Optional[int] = None, api_key: Optional[str] = None) -> None:
"""Initialize the websocket client policy.

Args:
host: The hostname or IP address of the server to connect to.
port: The port number to connect to. If None, no port is appended to the URI.
api_key: Optional API key for authentication. If provided, it will be sent in the Authorization header.
"""
self._uri = f"ws://{host}"
if port is not None:
self._uri += f":{port}"
@@ -24,9 +31,25 @@ def __init__(self, host: str = "0.0.0.0", port: Optional[int] = None, api_key: Optional[str] = None) -> None:
self._ws, self._server_metadata = self._wait_for_server()

def get_server_metadata(self) -> Dict:
"""Get metadata received from the server during connection.

Returns:
Dictionary containing server metadata that was received during the initial connection.
"""
return self._server_metadata

def _wait_for_server(self) -> Tuple[websockets.sync.client.ClientConnection, Dict]:
"""Establish connection to the server and retrieve metadata.

Continuously attempts to connect to the server until successful, with 5-second intervals between attempts.
Once connected, receives and unpacks the server metadata.

Returns:
Tuple containing the websocket connection and the server metadata dictionary.

Raises:
Any exception raised while connecting or unpacking the metadata, other than ConnectionRefusedError, which is caught and the connection retried.
"""
logging.info(f"Waiting for server at {self._uri}...")
while True:
try:
@@ -42,6 +65,17 @@

@override
def infer(self, obs: Dict) -> Dict: # noqa: UP006
"""Send observation to server and receive inference result.

Args:
obs: Dictionary containing observation data to be sent to the server.

Returns:
Dictionary containing the inference result from the server.

Raises:
RuntimeError: If the server responds with an error message (string instead of bytes).
"""
data = self._packer.pack(obs)
self._ws.send(data)
response = self._ws.recv()
@@ -52,4 +86,8 @@ def infer(self, obs: Dict) -> Dict:  # noqa: UP006

@override
def reset(self) -> None:
"""Reset the policy state.

This implementation does nothing as the websocket client maintains no local state that needs resetting.
"""
pass
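
Client-side usage is then a couple of lines (host and port are placeholders, and the module path is assumed from the package layout; the constructor blocks, retrying every 5 seconds, until a server accepts the connection):

import numpy as np
from openpi_client.websocket_client_policy import WebsocketClientPolicy

policy = WebsocketClientPolicy(host="localhost", port=8000)
print(policy.get_server_metadata())   # metadata sent by the server on connect
action = policy.infer({"state": np.zeros(2, dtype=np.float32)})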
87 changes: 87 additions & 0 deletions src/openpi/models_pytorch/gemma_pytorch.py
@@ -10,13 +10,30 @@


class PaliGemmaWithExpertModel(nn.Module):
"""
A PyTorch module that combines the PaliGemma vision-language model with a Gemma expert model.

This model integrates a PaliGemma model for multimodal processing with an additional
Gemma expert model for specialized language processing tasks.
"""

def __init__(
self,
vlm_config,
action_expert_config,
use_adarms=None,
precision: Literal["bfloat16", "float32"] = "bfloat16",
):
"""
Initialize the PaliGemma with Expert model.

Args:
vlm_config: Configuration object for the vision-language model.
action_expert_config: Configuration object for the action expert model.
use_adarms: List of two booleans indicating whether to use AdaRMS for each model.
Defaults to [False, False] if None.
precision: Precision type for model parameters, either "bfloat16" or "float32".
"""
if use_adarms is None:
use_adarms = [False, False]
super().__init__()
@@ -61,6 +78,15 @@ def __init__(
self.to_bfloat16_for_selected_params(precision)

def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"):
"""
Convert model parameters to specified precision while keeping certain parameters in float32.

Args:
precision: Target precision for most parameters, either "bfloat16" or "float32".

Raises:
ValueError: If precision is not "bfloat16" or "float32".
"""
if precision == "bfloat16":
self.to(dtype=torch.bfloat16)
elif precision == "float32":
@@ -83,9 +109,27 @@ def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"):
param.data = param.data.to(dtype=torch.float32)

def embed_image(self, image: torch.Tensor):
"""
Extract image features using the PaliGemma vision tower.

Args:
image: Input image tensor.

Returns:
torch.Tensor: Image feature embeddings.
"""
return self.paligemma.model.get_image_features(image)

def embed_language_tokens(self, tokens: torch.Tensor):
"""
Convert language tokens to embeddings using the PaliGemma language model.

Args:
tokens: Input token tensor.

Returns:
torch.Tensor: Token embeddings.
"""
return self.paligemma.language_model.embed_tokens(tokens)

def forward(
@@ -97,6 +141,26 @@ def forward(
use_cache: bool | None = None,
adarms_cond: list[torch.Tensor] | None = None,
):
"""
Forward pass through the combined PaliGemma and expert models.

Processes input embeddings through either the PaliGemma language model alone,
the Gemma expert model alone, or both models in a combined fashion with
shared attention computation.

Args:
attention_mask: Mask to avoid attention on padding tokens.
position_ids: Position indices for positional embeddings.
past_key_values: Cached key-value pairs from previous forward passes.
inputs_embeds: List of two embedding tensors, one for each model.
use_cache: Whether to return cached key-value pairs.
adarms_cond: Conditioning tensors for AdaRMS normalization.

Returns:
tuple: A tuple containing:
- List of output tensors [prefix_output, suffix_output]
- Past key values for caching
"""
if adarms_cond is None:
adarms_cond = [None, None]
if inputs_embeds[1] is None:
@@ -156,6 +220,19 @@ def forward(

# Define the complete layer computation function for gradient checkpointing
def compute_layer_complete(layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond):
"""
Compute a complete transformer layer for both models with shared attention.

Args:
layer_idx: Index of the current layer.
inputs_embeds: Input embeddings for both models.
attention_mask: Attention mask tensor.
position_ids: Position indices.
adarms_cond: AdaRMS conditioning tensors.

Returns:
list: Output embeddings for both models after layer processing.
"""
models = [self.paligemma.language_model, self.gemma_expert.model]

query_states = []
@@ -260,6 +337,16 @@ def compute_layer_complete(layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond):
# final norm
# Define final norm computation function for gradient checkpointing
def compute_final_norms(inputs_embeds, adarms_cond):
"""
Apply final layer normalization to both model outputs.

Args:
inputs_embeds: Input embeddings for both models.
adarms_cond: AdaRMS conditioning tensors.

Returns:
list: Normalized output embeddings for both models.
"""
outputs_embeds = []
for i, hidden_states in enumerate(inputs_embeds):
out_emb, _ = models[i].norm(hidden_states, cond=adarms_cond[i])
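
The gradient checkpointing that compute_layer_complete and compute_final_norms are written for follows the standard torch.utils.checkpoint pattern, sketched generically here (the layer body is a stand-in for the real per-layer computation):

import torch
from torch.utils.checkpoint import checkpoint

def compute_layer(x):
    # Activations inside this function are not stored during the forward
    # pass; they are recomputed during backward to save memory.
    return x * torch.sigmoid(x)

x = torch.randn(4, 8, requires_grad=True)
y = checkpoint(compute_layer, x, use_reentrant=False)
y.sum().backward()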
27 changes: 27 additions & 0 deletions src/openpi/serving/websocket_policy_server.py
@@ -25,16 +25,26 @@ def __init__(
port: int | None = None,
metadata: dict | None = None,
) -> None:
"""Initialize the WebSocket policy server.

Args:
policy: The policy instance to serve over WebSocket.
host: The host address to bind the server to.
port: The port number to bind the server to, or None for automatic assignment.
metadata: Additional metadata to send to clients upon connection.
"""
self._policy = policy
self._host = host
self._port = port
self._metadata = metadata or {}
logging.getLogger("websockets.server").setLevel(logging.INFO)

def serve_forever(self) -> None:
"""Start the server and run it indefinitely in a blocking manner."""
asyncio.run(self.run())

async def run(self):
"""Run the WebSocket server asynchronously with the configured handler."""
async with _server.serve(
self._handler,
self._host,
@@ -46,6 +56,14 @@ async def run(self):
await server.serve_forever()

async def _handler(self, websocket: _server.ServerConnection):
"""Handle incoming WebSocket connections and process inference requests.

Sends metadata upon connection, then continuously receives observations,
runs inference, and sends back actions with timing information.

Args:
websocket: The WebSocket connection to handle.
"""
logger.info(f"Connection from {websocket.remote_address} opened")
packer = msgpack_numpy.Packer()

@@ -84,6 +102,15 @@ async def _handler(self, websocket: _server.ServerConnection):


def _health_check(connection: _server.ServerConnection, request: _server.Request) -> _server.Response | None:
"""Handle health check requests on the /healthz endpoint.

Args:
connection: The server connection instance.
request: The incoming HTTP request.

Returns:
An HTTP OK response for the /healthz path, or None so that normal request handling continues.
"""
if request.path == "/healthz":
return connection.respond(http.HTTPStatus.OK, "OK\n")
# Continue with the normal request handling.
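
Putting it together with the client sketch above, a minimal serving loop (class and module path assumed from src/openpi/serving; any BasePolicy can be served):

from openpi.serving.websocket_policy_server import WebsocketPolicyServer

server = WebsocketPolicyServer(
    policy=DummyPolicy(),              # the illustrative policy from above
    host="0.0.0.0",
    port=8000,
    metadata={"model": "example"},
)
server.serve_forever()                 # blocks; GET /healthz returns 200 OK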