1 change: 1 addition & 0 deletions README.md
@@ -15,6 +15,7 @@ Currently supported providers are:
- Google
- Groq
- HuggingFace Ollama
- LMStudio
- Mistral
- OpenAI
- Sambanova
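With LMStudio in the provider list, requests can be routed to it through aisuite's standard entry point. A minimal usage sketch, assuming a locally running LM Studio instance; the model name "my-local-model" is a placeholder:

```python
# Sketch under assumptions: LM Studio is running locally and a model named
# "my-local-model" (placeholder) is available to load.
import aisuite as ai

client = ai.Client()

response = client.chat.completions.create(
    model="lmstudio:my-local-model",  # aisuite's "provider:model" routing
    messages=[{"role": "user", "content": "Hello!"}],
)
print(response.choices[0].message.content)
```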
46 changes: 46 additions & 0 deletions aisuite/providers/lmstudio_provider.py
@@ -0,0 +1,46 @@
import lmstudio

from aisuite.framework import ChatCompletionResponse
from aisuite.provider import LLMError, Provider


class LmstudioProvider(Provider):
    def __init__(self, **config):
        self.client = lmstudio.Client(**config)
        self.model: lmstudio.LLM | None = None

    def _chat(self, model: str, messages, **kwargs) -> lmstudio.PredictionResult[str]:
        """
        Makes a request to the lmstudio chat completions endpoint using the official client.
        """
        # get a handle to the specified model (loading it if necessary)
        llm = self.client.llm.model(model)
        # send the request to the model
        result = llm.respond({"messages": messages}, **kwargs)

        return result

    def chat_completions_create(self, model, messages, **kwargs):
        """
        Makes a request to the lmstudio chat completions endpoint using the official client
        and converts the output to conform to the OpenAI response model.
        """
        try:
            # --- send request to lmstudio endpoint
            result = self._chat(model=model, messages=messages, **kwargs)

            # --- Return the normalized response
            normalized_response = self._normalize_response(result)
            return normalized_response

        # Wrap any exception in LLMError.
        except Exception as e:
            raise LLMError("An error occurred.") from e

    def _normalize_response(self, response_data: lmstudio.PredictionResult):
        """
        Normalize the lmstudio response to a common format (ChatCompletionResponse).
        """
        normalized_response = ChatCompletionResponse()
        normalized_response.choices[0].message.content = response_data.content
        return normalized_response
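The provider can also be exercised directly, bypassing the aisuite client. A short sketch, again with a hypothetical model name, assuming a reachable local LM Studio server:

```python
# Illustrative only: assumes LM Studio's local server is reachable with
# default client settings and "my-local-model" (placeholder) can be loaded.
from aisuite.providers.lmstudio_provider import LmstudioProvider

provider = LmstudioProvider()
response = provider.chat_completions_create(
    model="my-local-model",
    messages=[{"role": "user", "content": "Say hi."}],
)
# _normalize_response copies the PredictionResult text into the
# OpenAI-style shape shared by all aisuite providers.
print(response.choices[0].message.content)
```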
8 changes: 6 additions & 2 deletions pyproject.toml
@@ -6,7 +6,7 @@ authors = ["Andrew Ng, Rohit P"]
readme = "README.md"

[tool.poetry.dependencies]
python = "^3.10"
python = "^3.11"
anthropic = { version = "^0.30.1", optional = true }
boto3 = { version = "^1.34.144", optional = true }
cohere = { version = "^5.12.0", optional = true }
@@ -17,9 +17,11 @@ openai = { version = "^1.35.8", optional = true }
ibm-watsonx-ai = { version = "^1.1.16", optional = true }
docstring-parser = { version = "^0.14.0", optional = true }
cerebras_cloud_sdk = { version = "^1.19.0", optional = true }
lmstudio = { version = "^1.0.1", optional = true }

# Optional dependencies for different providers
httpx = "~0.27.0"

[tool.poetry.extras]
anthropic = ["anthropic"]
aws = ["boto3"]
@@ -30,11 +32,12 @@ deepseek = ["openai"]
google = ["vertexai"]
groq = ["groq"]
huggingface = []
lmstudio = ["lmstudio"]
mistral = ["mistralai"]
ollama = []
openai = ["openai"]
watsonx = ["ibm-watsonx-ai"]
all = ["anthropic", "aws", "cerebras_cloud_sdk", "google", "groq", "mistral", "openai", "cohere", "watsonx"] # To install all providers
all = ["anthropic", "aws", "cerebras_cloud_sdk", "google", "groq", "mistral", "openai", "cohere", "watsonx", "lmstudio"] # To install all providers

[tool.poetry.group.dev.dependencies]
pre-commit = "^3.7.1"
@@ -54,6 +57,7 @@ datasets = "^2.20.0"
vertexai = "^1.63.0"
ibm-watsonx-ai = "^1.1.16"
cerebras_cloud_sdk = "^1.19.0"
lmstudio = "^1.0.1"

[tool.poetry.group.test]
optional = true
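Since `lmstudio` is declared as an optional extra, the dependency is only pulled in on request, e.g. via `pip install "aisuite[lmstudio]"` or through the `all` extra, keeping the base package lean.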
38 changes: 38 additions & 0 deletions tests/providers/test_lmstudio_provider.py
@@ -0,0 +1,38 @@
from dataclasses import dataclass
from unittest.mock import patch

from aisuite.providers.lmstudio_provider import LmstudioProvider


@dataclass
class MockResponse:
    content: str


def test_lmstudio_provider():
    """High-level test that the provider is initialized and chat completions are requested successfully."""

    user_greeting = "Hello!"
    message_history = [{"role": "user", "content": user_greeting}]
    selected_model = "our-favorite-model"
    config = {
        "temperature": 0.6,
        "maxTokens": 5000,
    }
    response_text_content = "mocked-text-response-from-model"

    provider = LmstudioProvider()
    mock_response = MockResponse(response_text_content)

    with patch.object(provider, "_chat", return_value=mock_response) as mock_create:
        response = provider.chat_completions_create(
            messages=message_history,
            model=selected_model,
            config=config,
        )

    mock_create.assert_called_with(
        model=selected_model, messages=message_history, config=config
    )

    assert response.choices[0].message.content == response_text_content
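Patching `_chat` keeps the test hermetic: no LM Studio server is required, and the assertions cover both the pass-through of `model`, `messages`, and `config` and the OpenAI-style normalization of the mocked result.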