Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 79 additions & 1 deletion openhands-agent-server/openhands/agent_server/env_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from pathlib import Path
from types import UnionType
from typing import Annotated, Literal, Union, get_args, get_origin
Expand Down Expand Up @@ -71,6 +72,79 @@ def from_env(self, key: str) -> None | MissingType:
return None


@dataclass
class EnumEnvParser(EnvParser):
"""Parser for enum types, supporting both enum.Enum subclasses and Literal types."""

enum_type: type

def from_env(self, key: str) -> str | MissingType:
if key not in os.environ:
return MISSING

value = os.environ[key]

# Handle enum.Enum subclasses
if inspect.isclass(self.enum_type) and issubclass(self.enum_type, Enum):
# Try exact match first (for string values)
try:
return self.enum_type(value).value
except ValueError:
pass

# Try converting to int and matching (for integer values)
try:
int_value = int(value)
return self.enum_type(int_value).value
except (ValueError, TypeError):
pass

# Try case-insensitive match by name
for enum_member in self.enum_type:
if enum_member.name.upper() == value.upper():
return enum_member.value

# Try case-insensitive match by value (if value is string)
for enum_member in self.enum_type:
if (
isinstance(enum_member.value, str)
and enum_member.value.upper() == value.upper()
):
return enum_member.value

# If no match found, raise an error with helpful message
valid_values = [member.value for member in self.enum_type]
valid_names = [member.name for member in self.enum_type]
raise ValueError(
f"Invalid value '{value}' for {self.enum_type.__name__}. "
f"Valid values: {valid_values}. Valid names: {valid_names}"
)

# Handle Literal types (get_origin returns Literal, get_args returns the values)
origin = get_origin(self.enum_type)
if origin is Literal:
literal_values = get_args(self.enum_type)

# Try exact match first
if value in literal_values:
return value

# Try case-insensitive match for string literals
for literal_value in literal_values:
if (
isinstance(literal_value, str)
and literal_value.upper() == value.upper()
):
return literal_value

# For Literal types, return the value as-is to allow Pydantic validation
# to handle the error. This is important for discriminated unions.
return value

# Fallback: return the string value as-is
return value


@dataclass
class ModelEnvParser(EnvParser):
parsers: dict[str, EnvParser]
Expand Down Expand Up @@ -236,7 +310,11 @@ def get_env_parser(target_type: type, parsers: dict[type, EnvParser]) -> EnvPars
assert args[1] in (str, int, float, bool)
return DictEnvParser()
if origin is Literal:
return StrEnvParser()
return EnumEnvParser(target_type)

# Check if target_type is an enum.Enum subclass
if inspect.isclass(target_type) and issubclass(target_type, Enum):
return EnumEnvParser(target_type)

if origin and issubclass(origin, BaseModel):
target_type = origin
Expand Down
Loading
Loading