Skip to content

Commit 78c58d5

Browse files
authored
fix: Adapt assistant sdk to hlp (#242)
* fix: add schemas to assistant plans * fix: support func type in code and class in schemas * fix: add schemas to api call * fix: raise an error on code parsing failed * fix: change types * fix: add tests * fix: adapt code to 3.8 * fix: adapt code to 3.8 * fix: support model schemas of python 3.11 * fix: support model schemas of python 3.11 * fix: PR comments * fix: PR comments * fix: add user defined plan example to integration tests
1 parent 770c6c5 commit 78c58d5

File tree

10 files changed

+226
-8
lines changed

10 files changed

+226
-8
lines changed

ai21/clients/common/beta/assistant/plans.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,62 @@
11
from __future__ import annotations
22

3+
import inspect
34
from abc import ABC, abstractmethod
4-
from typing import Any, Dict
5+
from typing import Any, Dict, Type, Callable, List
56

7+
from pydantic import BaseModel
8+
9+
from ai21.errors import CodeParsingError
10+
from ai21.models._pydantic_compatibility import _to_schema
611
from ai21.models.responses.plan_response import PlanResponse, ListPlanResponse
12+
from ai21.types import NOT_GIVEN, NotGiven
713
from ai21.utils.typing import remove_not_given
814

915

1016
class BasePlans(ABC):
1117
_module_name = "plans"
1218

19+
def _parse_schema(self, schema: Type[BaseModel] | Dict[str, Any]) -> Dict[str, Any]:
20+
if inspect.isclass(schema) and issubclass(schema, BaseModel):
21+
return _to_schema(schema)
22+
return schema
23+
24+
def _parse_code(self, code: str | Callable) -> str:
25+
if callable(code):
26+
try:
27+
return inspect.getsource(code).strip()
28+
except OSError as e:
29+
raise CodeParsingError(str(e))
30+
except Exception:
31+
raise CodeParsingError()
32+
return code
33+
1334
@abstractmethod
1435
def create(
1536
self,
1637
*,
1738
assistant_id: str,
18-
code: str,
39+
code: str | Callable,
40+
schemas: List[Dict[str, Any]] | List[Type[BaseModel]] | NotGiven = NOT_GIVEN,
1941
**kwargs,
2042
) -> PlanResponse:
2143
pass
2244

2345
def _create_body(
2446
self,
2547
*,
26-
code: str,
48+
code: str | Callable,
49+
schemas: List[Dict[str, Any]] | List[BaseModel] | NotGiven = NOT_GIVEN,
2750
**kwargs,
2851
) -> Dict[str, Any]:
52+
code_str = self._parse_code(code)
53+
2954
return remove_not_given(
3055
{
31-
"code": code,
56+
"code": code_str,
57+
"schemas": (
58+
[self._parse_schema(schema) for schema in schemas] if schemas is not NOT_GIVEN else NOT_GIVEN
59+
),
3260
**kwargs,
3361
}
3462
)
@@ -57,5 +85,6 @@ def modify(
5785
assistant_id: str,
5886
plan_id: str,
5987
code: str,
88+
schemas: List[Dict[str, Any]] | NotGiven = NOT_GIVEN,
6089
) -> PlanResponse:
6190
pass

ai21/clients/studio/resources/beta/assistant/assistants_plans.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11
from __future__ import annotations
22

3+
from typing import List, Any, Dict, Type
4+
5+
from pydantic import BaseModel
6+
37
from ai21.clients.common.beta.assistant.plans import BasePlans
48
from ai21.clients.studio.resources.studio_resource import (
59
AsyncStudioResource,
610
StudioResource,
711
)
812
from ai21.models.responses.plan_response import PlanResponse, ListPlanResponse
13+
from ai21.types import NotGiven, NOT_GIVEN
914

1015

1116
class AssistantPlans(StudioResource, BasePlans):
@@ -14,10 +19,12 @@ def create(
1419
*,
1520
assistant_id: str,
1621
code: str,
22+
schemas: List[Dict[str, Any]] | List[Type[BaseModel]] | NotGiven = NOT_GIVEN,
1723
**kwargs,
1824
) -> PlanResponse:
1925
body = self._create_body(
2026
code=code,
27+
schemas=schemas,
2128
**kwargs,
2229
)
2330

@@ -44,8 +51,12 @@ def modify(
4451
assistant_id: str,
4552
plan_id: str,
4653
code: str,
54+
schemas: List[Dict[str, Any]] | List[Type[BaseModel]] | NotGiven = NOT_GIVEN,
4755
) -> PlanResponse:
48-
body = dict(code=code)
56+
body = self._create_body(
57+
code=code,
58+
schemas=schemas,
59+
)
4960

5061
return self._patch(
5162
path=f"/assistants/{assistant_id}/{self._module_name}/{plan_id}", body=body, response_cls=PlanResponse
@@ -58,10 +69,12 @@ async def create(
5869
*,
5970
assistant_id: str,
6071
code: str,
72+
schemas: List[Dict[str, Any]] | List[Type[BaseModel]] | NotGiven = NOT_GIVEN,
6173
**kwargs,
6274
) -> PlanResponse:
6375
body = self._create_body(
6476
code=code,
77+
schemas=schemas,
6578
**kwargs,
6679
)
6780

@@ -90,8 +103,12 @@ async def modify(
90103
assistant_id: str,
91104
plan_id: str,
92105
code: str,
106+
schemas: List[Dict[str, Any]] | List[Type[BaseModel]] | NotGiven = NOT_GIVEN,
93107
) -> PlanResponse:
94-
body = dict(code=code)
108+
body = self._create_body(
109+
code=code,
110+
schemas=schemas,
111+
)
95112

96113
return await self._patch(
97114
path=f"/assistants/{assistant_id}/{self._module_name}/{plan_id}", body=body, response_cls=PlanResponse

ai21/errors.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,3 +111,10 @@ def __init__(self, chunk: str, error_message: Optional[str] = None):
111111
class InternalDependencyException(AI21APIError):
112112
def __init__(self, details: Optional[str] = None):
113113
super().__init__(530, details)
114+
115+
116+
class CodeParsingError(AI21Error):
117+
def __init__(self, details: Optional[str] = None):
118+
message = f"Code can't be parsed{'' if details is None else f': {details}'}"
119+
super().__init__(message)
120+
self.message = message

ai21/models/_pydantic_compatibility.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from typing import Dict, Any
3+
from typing import Dict, Any, Type
44

55
from pydantic import VERSION, BaseModel
66

@@ -33,3 +33,10 @@ def _from_json(obj: "AI21BaseModel", json_str: str, **kwargs) -> BaseModel: # n
3333
return obj.model_validate_json(json_str, **kwargs)
3434

3535
return obj.parse_raw(json_str, **kwargs)
36+
37+
38+
def _to_schema(model_object: Type[BaseModel], **kwargs) -> Dict[str, Any]:
39+
if IS_PYDANTIC_V2:
40+
return model_object.model_json_schema(**kwargs)
41+
42+
return model_object.schema(**kwargs)

ai21/models/responses/plan_response.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from datetime import datetime
2-
from typing import List, Optional
2+
from typing import List, Optional, Dict, Any
33

44
from ai21.models.ai21_base_model import AI21BaseModel
55

@@ -10,6 +10,7 @@ class PlanResponse(AI21BaseModel):
1010
updated_at: datetime
1111
assistant_id: str
1212
code: str
13+
schemas: List[Dict[str, Any]]
1314

1415

1516
class ListPlanResponse(AI21BaseModel):
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from ai21 import AI21Client
2+
from pydantic import BaseModel
3+
4+
TIMEOUT = 20
5+
6+
7+
def test_func():
8+
pass
9+
10+
11+
class ExampleSchema(BaseModel):
12+
name: str
13+
id: str
14+
15+
16+
def main():
17+
ai21_client = AI21Client()
18+
19+
assistant = ai21_client.beta.assistants.create(name="My Assistant")
20+
21+
plan = ai21_client.beta.assistants.plans.create(assistant_id=assistant.id, code=test_func, schemas=[ExampleSchema])
22+
route = ai21_client.beta.assistants.routes.create(
23+
assistant_id=assistant.id, plan_id=plan.id, name="My Route", examples=["hi"], description="My Route Description"
24+
)
25+
print(f"Route: {route}")

tests/integration_tests/clients/test_studio.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
("chat/chat_function_calling.py",),
2626
("chat/chat_function_calling_multiple_tools.py",),
2727
("chat/chat_response_format.py",),
28+
("assistant/user_defined_plans.py",),
2829
],
2930
ids=[
3031
"when_tokenization__should_return_ok",
@@ -35,6 +36,7 @@
3536
"when_chat_completions_with_function_calling__should_return_ok",
3637
"when_chat_completions_with_function_calling_multiple_tools_should_return_ok",
3738
"when_chat_completions_with_response_format__should_return_ok",
39+
"when_assistant_with_user_defined_plans_should_return_ok",
3840
],
3941
)
4042
def test_studio(test_file_name: str):

tests/unittests/clients/studio/resources/assistant/__init__.py

Whitespace-only changes.

tests/unittests/clients/studio/resources/assistant/plans/__init__.py

Whitespace-only changes.
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
from typing import Callable, List, Dict, Any, Type, Union
2+
3+
from pydantic import BaseModel
4+
from ai21.clients.common.beta.assistant.plans import BasePlans
5+
from ai21.errors import CodeParsingError
6+
from ai21.models.responses.plan_response import PlanResponse, ListPlanResponse
7+
from ai21.types import NotGiven, NOT_GIVEN
8+
import pytest
9+
10+
11+
class PlanTestClass(BasePlans):
12+
def create(
13+
self,
14+
*,
15+
assistant_id: str,
16+
code: Union[str, Callable],
17+
schemas: Union[List[Dict[str, Any]], List[Type[BaseModel]], NotGiven] = NOT_GIVEN,
18+
**kwargs,
19+
) -> PlanResponse:
20+
pass
21+
22+
def list(self, *, assistant_id: str) -> ListPlanResponse:
23+
pass
24+
25+
def retrieve(self, *, assistant_id: str, plan_id: str) -> PlanResponse:
26+
pass
27+
28+
def modify(
29+
self, *, assistant_id: str, plan_id: str, code: str, schemas: Union[List[Dict[str, Any]], NotGiven] = NOT_GIVEN
30+
) -> PlanResponse:
31+
pass
32+
33+
34+
def test_create_body__when_pass_code_str__should_return_dict():
35+
# Arrange
36+
code = "code"
37+
38+
# Act
39+
result = PlanTestClass()._create_body(code=code)
40+
41+
# Assert
42+
assert result == {"code": code}
43+
44+
45+
def test_create_body__when_pass_code_callable__should_return_dict():
46+
# Arrange
47+
def code():
48+
return "code"
49+
50+
# Act
51+
result = PlanTestClass()._create_body(code=code)
52+
53+
# Assert
54+
assert result == {"code": 'def code():\n return "code"'}
55+
56+
57+
def test_create_body__when_pass_code_and_dict_schemas__should_return_dict_with_schemas():
58+
# Arrange
59+
code = "code"
60+
schemas = [{"type": "object", "properties": {"name": {"type": "string"}}}]
61+
62+
# Act
63+
result = PlanTestClass()._create_body(code=code, schemas=schemas)
64+
65+
# Assert
66+
assert result == {"code": code, "schemas": schemas}
67+
68+
69+
class TestSchema(BaseModel):
70+
name: str
71+
age: int
72+
73+
74+
def test_create_body__when_pass_code_and_pydantic_schemas__should_return_dict_with_converted_schemas():
75+
# Arrange
76+
code = "code"
77+
schemas = [TestSchema]
78+
79+
# Act
80+
result = PlanTestClass()._create_body(code=code, schemas=schemas)
81+
82+
# Assert
83+
expected_schema = {
84+
"properties": {"age": {"title": "Age", "type": "integer"}, "name": {"title": "Name", "type": "string"}},
85+
"required": ["name", "age"],
86+
"title": "TestSchema",
87+
"type": "object",
88+
}
89+
assert result == {"code": code, "schemas": [expected_schema]}
90+
91+
92+
def test_create_body__when_pass_code_and_not_given_schemas__should_return_dict_without_schemas():
93+
# Arrange
94+
code = "code"
95+
96+
# Act
97+
result = PlanTestClass()._create_body(code=code, schemas=NOT_GIVEN)
98+
99+
# Assert
100+
assert result == {"code": code}
101+
102+
103+
def test_create_body__when_pass_empty_schemas_list__should_return_dict_with_empty_schemas():
104+
# Arrange
105+
code = "code"
106+
schemas = []
107+
108+
# Act
109+
result = PlanTestClass()._create_body(code=code, schemas=schemas)
110+
111+
# Assert
112+
assert result == {"code": code, "schemas": schemas}
113+
114+
115+
def test_create_body__when_cannot_get_source_code__should_raise_code_parsing_error():
116+
# Arrange
117+
class CallableWithoutSource:
118+
def __call__(self):
119+
return "result"
120+
121+
# Override __code__ to simulate a built-in function or method
122+
@property
123+
def __code__(self):
124+
raise AttributeError("'CallableWithoutSource' object has no attribute '__code__'")
125+
126+
code = CallableWithoutSource()
127+
128+
# Act & Assert
129+
with pytest.raises(CodeParsingError):
130+
PlanTestClass()._create_body(code=code)

0 commit comments

Comments
 (0)