Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
37cbad2
[+] async client version with provider interface extension. Several p…
kapulkin Jan 16, 2025
d00c078
[+] tests for async client with providers, that have async implementa…
kapulkin Jan 16, 2025
91bcbd3
Merge branch 'andrewyng:main' into async_version
kapulkin Jan 23, 2025
ce8d513
Merge commit 'bd6b23fd72b3a391d96659b7faf5cdaed6a415dc' into async_ve…
kapulkin Feb 4, 2025
949bb36
[*] outdated comment is removed
kapulkin Feb 4, 2025
0e242a0
[*] openai code style polishing
kapulkin Feb 4, 2025
de4d97c
[f] import fix, dependency for async pytest tests
kapulkin Feb 4, 2025
e244751
[*] formatting fixes
kapulkin Feb 4, 2025
f4a3f35
[*] make chat_completions_create_async() throw NotImplementedError in…
kapulkin Feb 5, 2025
b2520e4
[*] additional args support in default chat_completions_create_async(…
kapulkin Feb 5, 2025
c329720
Merge branch 'andrewyng:main' into kapulkin/async_version
kapulkin Feb 7, 2025
73d63f1
Merge commit '9dc9ae9a45470f6632e85d2f53087b84510c25f3' into kapulkin…
kapulkin Feb 19, 2025
dba87be
[+] AsyncProvider to perform async llm call
kapulkin Feb 19, 2025
4d97afd
Merge branch 'andrewyng:main' into kapulkin/async_version
kapulkin Mar 23, 2025
6ec230f
tools and thinking support for async client
kapulkin Mar 23, 2025
0a8c558
Merge branch 'andrewyng:main' into kapulkin/async_version
kapulkin Mar 26, 2025
f4cb3f1
[+] automatic_tool_calling to control tool call workflow
kapulkin Mar 28, 2025
10940f9
Merge commit 'cf9df9a107c3fbbc86ddfcfbddb2607622485db6' into kapulkin…
kapulkin Apr 4, 2025
83d99d2
[f] fixed support for provider_configs in AsyncClient constructor
kapulkin Apr 4, 2025
58db39a
Merge commit '83d99d2e2f8878b43b3c47aac991979a2adf0d2f' into kapulkin…
kapulkin Apr 4, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions aisuite/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .client import Client
from .async_client import AsyncClient
from .framework.message import Message
from .utils.tools import Tools
81 changes: 81 additions & 0 deletions aisuite/async_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from .client import Client, Chat, Completions
from .base_client import BaseClient
from .provider import ProviderFactory
from .tool_runner import ToolRunner


class AsyncClient(BaseClient):
    """Async variant of the aisuite client.

    Mirrors ``Client`` but constructs every provider with ``is_async=True``,
    so completions are awaited via each provider's
    ``chat_completions_create_async`` implementation.
    """

    def __init__(self, provider_configs: dict = None):
        """Initialize the async client.

        Args:
            provider_configs (dict | None): Mapping of provider key (e.g.
                ``"openai"``) to that provider's configuration dict.
                Defaults to no providers.
        """
        # Use None instead of a mutable default ({}) so a single shared dict
        # cannot leak state across AsyncClient instances.
        super().__init__(provider_configs if provider_configs is not None else {}, is_async=True)

    def configure(self, provider_configs: dict = None):
        """Merge new provider configs and rebuild providers as async instances."""
        super().configure(provider_configs, True)

    @property
    def chat(self):
        """Return the async chat API interface (created lazily on first access)."""
        if not self._chat:
            self._chat = AsyncChat(self)
        return self._chat


class AsyncChat(Chat):
    """Chat namespace for AsyncClient; exposes awaitable completions."""

    def __init__(self, client: "AsyncClient"):
        # Bind the owning client and eagerly build the async completions
        # endpoint, mirroring the synchronous Chat structure.
        self.client = client
        self._completions = AsyncCompletions(self.client)


class AsyncCompletions(Completions):
    """Async completions endpoint; inherits response post-processing
    (``_extract_thinking_content``) from the synchronous Completions."""

    async def create(self, model: str, messages: list, **kwargs):
        """
        Create async chat completion based on the model, messages, and any extra arguments.
        Supports automatic tool execution when max_turns is specified.

        Args:
            model: Model identifier in the form ``"provider:model"``.
            messages: Conversation messages passed to the provider.
            **kwargs: Extra provider arguments; ``max_turns`` (popped),
                ``tools`` and ``automatic_tool_calling`` are inspected here.

        Returns:
            The provider response, or the ToolRunner result when a
            multi-turn tool loop is requested.

        Raises:
            ValueError: If the model string is malformed, the provider key
                is unsupported, or the provider cannot be loaded.
        """
        # Check that correct format is used
        if ":" not in model:
            raise ValueError(
                f"Invalid model format. Expected 'provider:model', got '{model}'"
            )

        # Extract the provider key from the model identifier, e.g., "google:gemini-xx"
        provider_key, model_name = model.split(":", 1)

        # Validate if the provider is supported
        supported_providers = ProviderFactory.get_supported_providers()
        if provider_key not in supported_providers:
            raise ValueError(
                f"Invalid provider key '{provider_key}'. Supported providers: {supported_providers}. "
                "Make sure the model string is formatted correctly as 'provider:model'."
            )

        # Initialize provider if not already initialized (lazy creation for
        # providers not given to the constructor).
        if provider_key not in self.client.providers:
            config = self.client.provider_configs.get(provider_key, {})
            self.client.providers[provider_key] = ProviderFactory.create_provider(
                provider_key, config, is_async=True
            )

        provider = self.client.providers.get(provider_key)
        if not provider:
            raise ValueError(f"Could not load provider for '{provider_key}'.")

        # Extract tool-related parameters.
        # NOTE(review): max_turns is popped, but tools/automatic_tool_calling
        # use .get(), so on the default (non-tool-loop) path
        # automatic_tool_calling stays in kwargs and is forwarded to the
        # provider call below — confirm providers tolerate/expect it.
        max_turns = kwargs.pop("max_turns", None)
        tools = kwargs.get("tools", None)
        automatic_tool_calling = kwargs.get("automatic_tool_calling", False)

        # Multi-turn tool execution path: only taken when both max_turns and
        # tools are supplied.
        if max_turns is not None and tools is not None:
            # NOTE(review): provider/model_name/messages/tools/max_turns are
            # passed BOTH to the ToolRunner constructor and to run_async, and
            # the two messages.copy() calls create two distinct list objects —
            # verify against ToolRunner's signature that this duplication is
            # intended.
            tool_runner = ToolRunner(provider, model_name, messages.copy(), tools, max_turns, automatic_tool_calling)
            return await tool_runner.run_async(
                provider,
                model_name,
                messages.copy(),
                tools,
                max_turns,
            )

        # Default behavior without tool execution
        # Delegate the chat completion to the correct provider's async implementation
        response = await provider.chat_completions_create_async(model_name, messages, **kwargs)
        return self._extract_thinking_content(response)
67 changes: 67 additions & 0 deletions aisuite/base_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from abc import ABC, abstractmethod, abstractproperty

from .provider import ProviderFactory

class BaseClient(ABC):
    """Shared base for the sync (``Client``) and async (``AsyncClient``)
    aisuite clients.

    Holds provider configuration, lazily-built provider instances, and the
    cached ``_chat`` interface; subclasses supply the concrete ``chat``
    property.
    """

    def __init__(self, provider_configs: dict = None, is_async: bool = False):
        """
        Initialize the client with provider configurations.
        Use the ProviderFactory to create provider instances.

        Args:
            provider_configs (dict | None): A dictionary containing provider configurations.
                Each key should be a provider string (e.g., "google" or "aws-bedrock"),
                and the value should be a dictionary of configuration options for that provider.
                For example:
                {
                    "openai": {"api_key": "your_openai_api_key"},
                    "aws-bedrock": {
                        "aws_access_key": "your_aws_access_key",
                        "aws_secret_key": "your_aws_secret_key",
                        "aws_region": "us-west-2"
                    }
                }
            is_async (bool): When True, providers are created with their
                async implementations.
        """
        self.providers = {}
        # Use None instead of a mutable default ({}) so one dict is never
        # shared between client instances. A caller-supplied dict is kept
        # by reference, matching configure()'s in-place update semantics.
        self.provider_configs = provider_configs if provider_configs is not None else {}
        self._chat = None
        self._initialize_providers(is_async)

    def _initialize_providers(self, is_async):
        """Helper method to initialize or update providers from the stored configs."""
        for provider_key, config in self.provider_configs.items():
            provider_key = self._validate_provider_key(provider_key)
            self.providers[provider_key] = ProviderFactory.create_provider(
                provider_key, config, is_async
            )

    def _validate_provider_key(self, provider_key):
        """
        Validate if the provider key corresponds to a supported provider.

        Returns the key unchanged; raises ValueError for unknown providers.
        """
        supported_providers = ProviderFactory.get_supported_providers()

        if provider_key not in supported_providers:
            raise ValueError(
                f"Invalid provider key '{provider_key}'. Supported providers: {supported_providers}. "
                "Make sure the model string is formatted correctly as 'provider:model'."
            )

        return provider_key

    def configure(self, provider_configs: dict = None, is_async: bool = False):
        """
        Configure the client with provider configurations.

        Merges the given configs into the existing ones; a None argument is
        a no-op.
        """
        if provider_configs is None:
            return

        self.provider_configs.update(provider_configs)
        self._initialize_providers(is_async)  # NOTE: This will override existing provider instances.

    # abstractproperty is deprecated since Python 3.3; the supported idiom
    # is @property stacked over @abstractmethod.
    @property
    @abstractmethod
    def chat(self):
        """Return the chat API interface."""
        raise NotImplementedError("Chat is not implemented for this client.")
126 changes: 8 additions & 118 deletions aisuite/client.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from .provider import ProviderFactory
from .base_client import BaseClient
import os
from .utils.tools import Tools
from .tool_runner import ToolRunner


class Client:
class Client(BaseClient):
def __init__(self, provider_configs: dict = {}):
"""
Initialize the client with provider configurations.
Expand All @@ -23,42 +25,10 @@ def __init__(self, provider_configs: dict = {}):
}
}
"""
self.providers = {}
self.provider_configs = provider_configs
self._chat = None
self._initialize_providers()

def _initialize_providers(self):
"""Helper method to initialize or update providers."""
for provider_key, config in self.provider_configs.items():
provider_key = self._validate_provider_key(provider_key)
self.providers[provider_key] = ProviderFactory.create_provider(
provider_key, config
)

def _validate_provider_key(self, provider_key):
"""
Validate if the provider key corresponds to a supported provider.
"""
supported_providers = ProviderFactory.get_supported_providers()

if provider_key not in supported_providers:
raise ValueError(
f"Invalid provider key '{provider_key}'. Supported providers: {supported_providers}. "
"Make sure the model string is formatted correctly as 'provider:model'."
)

return provider_key
super().__init__(provider_configs, is_async=False)

def configure(self, provider_configs: dict = None):
"""
Configure the client with provider configurations.
"""
if provider_configs is None:
return

self.provider_configs.update(provider_configs)
self._initialize_providers() # NOTE: This will override existing provider instances.
super().configure(provider_configs, False)

@property
def chat(self):
Expand Down Expand Up @@ -111,88 +81,6 @@ def _extract_thinking_content(self, response):

return response

def _tool_runner(
self,
provider,
model_name: str,
messages: list,
tools: any,
max_turns: int,
**kwargs,
):
"""
Handle tool execution loop for max_turns iterations.

Args:
provider: The provider instance to use for completions
model_name: Name of the model to use
messages: List of conversation messages
tools: Tools instance or list of callable tools
max_turns: Maximum number of tool execution turns
**kwargs: Additional arguments to pass to the provider

Returns:
The final response from the model with intermediate responses and messages
"""
# Handle tools validation and conversion
if isinstance(tools, Tools):
tools_instance = tools
kwargs["tools"] = tools_instance.tools()
else:
# Check if passed tools are callable
if not all(callable(tool) for tool in tools):
raise ValueError("One or more tools is not callable")
tools_instance = Tools(tools)
kwargs["tools"] = tools_instance.tools()

turns = 0
intermediate_responses = [] # Store intermediate responses
intermediate_messages = [] # Store all messages including tool interactions

while turns < max_turns:
# Make the API call
response = provider.chat_completions_create(model_name, messages, **kwargs)
response = self._extract_thinking_content(response)

# Store intermediate response
intermediate_responses.append(response)

# Check if there are tool calls in the response
tool_calls = (
getattr(response.choices[0].message, "tool_calls", None)
if hasattr(response, "choices")
else None
)

# Store the model's message
intermediate_messages.append(response.choices[0].message)

if not tool_calls:
# Set the intermediate data in the final response
response.intermediate_responses = intermediate_responses[
:-1
] # Exclude final response
response.choices[0].intermediate_messages = intermediate_messages
return response

# Execute tools and get results
results, tool_messages = tools_instance.execute_tool(tool_calls)

# Add tool messages to intermediate messages
intermediate_messages.extend(tool_messages)

# Add the assistant's response and tool results to messages
messages.extend([response.choices[0].message, *tool_messages])

turns += 1

# Set the intermediate data in the final response
response.intermediate_responses = intermediate_responses[
:-1
] # Exclude final response
response.choices[0].intermediate_messages = intermediate_messages
return response

def create(self, model: str, messages: list, **kwargs):
"""
Create chat completion based on the model, messages, and any extra arguments.
Expand Down Expand Up @@ -229,10 +117,12 @@ def create(self, model: str, messages: list, **kwargs):
# Extract tool-related parameters
max_turns = kwargs.pop("max_turns", None)
tools = kwargs.get("tools", None)
automatic_tool_calling = kwargs.get("automatic_tool_calling", False)

# Check environment variable before allowing multi-turn tool execution
if max_turns is not None and tools is not None:
return self._tool_runner(
tool_runner = ToolRunner(provider, model_name, messages.copy(), tools, max_turns, automatic_tool_calling)
return tool_runner.run(
provider,
model_name,
messages.copy(),
Expand Down
11 changes: 9 additions & 2 deletions aisuite/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,24 @@ def chat_completions_create(self, model, messages):
"""Abstract method for chat completion calls, to be implemented by each provider."""
pass

class AsyncProvider(ABC):
    """Interface for providers that support awaitable chat completions.

    Concrete subclasses (e.g. ``AnthropicAsyncProvider``) are selected by
    ``ProviderFactory.create_provider(..., is_async=True)``.
    """

    @abstractmethod
    async def chat_completions_create_async(self, model, messages, **kwargs):
        """Method for async chat completion calls, to be implemented by each provider."""
        raise NotImplementedError("Async chat completion calls are not implemented for this provider.")


class ProviderFactory:
"""Factory to dynamically load provider instances based on naming conventions."""

PROVIDERS_DIR = Path(__file__).parent / "providers"

@classmethod
def create_provider(cls, provider_key, config):
def create_provider(cls, provider_key, config, is_async=False):
"""Dynamically load and create an instance of a provider based on the naming convention."""
# Convert provider_key to the expected module and class names
provider_class_name = f"{provider_key.capitalize()}Provider"
async_suffix = "Async" if is_async else ""
provider_class_name = f"{provider_key.capitalize()}{async_suffix}Provider"
provider_module_name = f"{provider_key}_provider"

module_path = f"aisuite.providers.{provider_module_name}"
Expand Down
28 changes: 27 additions & 1 deletion aisuite/providers/anthropic_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import anthropic
import json
from aisuite.provider import Provider
from aisuite.provider import Provider, AsyncProvider
from aisuite.framework import ChatCompletionResponse
from aisuite.framework.message import Message, ChatCompletionMessageToolCall, Function

Expand Down Expand Up @@ -222,3 +222,29 @@ def _prepare_kwargs(self, kwargs):
kwargs["tools"] = self.converter.convert_tool_spec(kwargs["tools"])

return kwargs

class AnthropicAsyncProvider(AsyncProvider):
    """Anthropic provider backed by the AsyncAnthropic client.

    Converts aisuite-style requests/responses through
    AnthropicMessageConverter and awaits the Anthropic messages API.
    """

    def __init__(self, **config):
        """Initialize the Anthropic provider with the given configuration."""
        self.async_client = anthropic.AsyncAnthropic(**config)
        self.converter = AnthropicMessageConverter()

    async def chat_completions_create_async(self, model, messages, **kwargs):
        """Create a chat completion using the async Anthropic API."""
        request_kwargs = self._prepare_kwargs(kwargs)
        system_message, converted_messages = self.converter.convert_request(messages)

        raw_response = await self.async_client.messages.create(
            model=model,
            system=system_message,
            messages=converted_messages,
            **request_kwargs,
        )
        return self.converter.convert_response(raw_response)

    def _prepare_kwargs(self, kwargs):
        """Copy kwargs, apply the default token cap, and convert tool specs."""
        # Work on a copy so the caller's dict is never mutated.
        prepared = dict(kwargs)
        prepared.setdefault("max_tokens", DEFAULT_MAX_TOKENS)

        if "tools" in prepared:
            prepared["tools"] = self.converter.convert_tool_spec(prepared["tools"])

        return prepared
Loading