diff --git a/src/huggingface_hub/inference/_client.py b/src/huggingface_hub/inference/_client.py index 5e39dee55c..97997aeb88 100644 --- a/src/huggingface_hub/inference/_client.py +++ b/src/huggingface_hub/inference/_client.py @@ -130,7 +130,7 @@ class InferenceClient: Note: for better compatibility with OpenAI's client, `model` has been aliased as `base_url`. Those 2 arguments are mutually exclusive. If a URL is passed as `model` or `base_url` for chat completion, the `(/v1)/chat/completions` suffix path will be appended to the URL. provider (`str`, *optional*): - Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `"replicate"`, "sambanova"` or `"together"`. + Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `"publicai"`, `"replicate"`, `"sambanova"` or `"together"`. Defaults to "auto" i.e. the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers. If model is a URL or `base_url` is passed, then `provider` is not used. token (`str`, *optional*): diff --git a/src/huggingface_hub/inference/_generated/_async_client.py b/src/huggingface_hub/inference/_generated/_async_client.py index 95eaf3e7e5..398dbe7fc1 100644 --- a/src/huggingface_hub/inference/_generated/_async_client.py +++ b/src/huggingface_hub/inference/_generated/_async_client.py @@ -118,7 +118,7 @@ class AsyncInferenceClient: Note: for better compatibility with OpenAI's client, `model` has been aliased as `base_url`. Those 2 arguments are mutually exclusive. 
If a URL is passed as `model` or `base_url` for chat completion, the `(/v1)/chat/completions` suffix path will be appended to the URL. provider (`str`, *optional*): - Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `"replicate"`, "sambanova"` or `"together"`. + Name of the provider to use for inference. Can be `"black-forest-labs"`, `"cerebras"`, `"cohere"`, `"fal-ai"`, `"featherless-ai"`, `"fireworks-ai"`, `"groq"`, `"hf-inference"`, `"hyperbolic"`, `"nebius"`, `"novita"`, `"nscale"`, `"openai"`, `"publicai"`, `"replicate"`, `"sambanova"` or `"together"`. Defaults to "auto" i.e. the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers. If model is a URL or `base_url` is passed, then `provider` is not used. token (`str`, *optional*): diff --git a/src/huggingface_hub/inference/_providers/__init__.py b/src/huggingface_hub/inference/_providers/__init__.py index ec4866c30d..ce9406c74f 100644 --- a/src/huggingface_hub/inference/_providers/__init__.py +++ b/src/huggingface_hub/inference/_providers/__init__.py @@ -36,6 +36,7 @@ from .novita import NovitaConversationalTask, NovitaTextGenerationTask, NovitaTextToVideoTask from .nscale import NscaleConversationalTask, NscaleTextToImageTask from .openai import OpenAIConversationalTask +from .publicai import PublicAIConversationalTask from .replicate import ReplicateImageToImageTask, ReplicateTask, ReplicateTextToImageTask, ReplicateTextToSpeechTask from .sambanova import SambanovaConversationalTask, SambanovaFeatureExtractionTask from .together import TogetherConversationalTask, TogetherTextGenerationTask, TogetherTextToImageTask @@ -58,6 +59,7 @@ "novita", "nscale", "openai", + "publicai", "replicate", "sambanova", "together", @@ -144,6 +146,9 @@ "openai": { 
"conversational": OpenAIConversationalTask(), }, + "publicai": { + "conversational": PublicAIConversationalTask(), + }, "replicate": { "image-to-image": ReplicateImageToImageTask(), "text-to-image": ReplicateTextToImageTask(), diff --git a/src/huggingface_hub/inference/_providers/_common.py b/src/huggingface_hub/inference/_providers/_common.py index 687464a934..ae2a25fbdb 100644 --- a/src/huggingface_hub/inference/_providers/_common.py +++ b/src/huggingface_hub/inference/_providers/_common.py @@ -31,6 +31,7 @@ "hyperbolic": {}, "nebius": {}, "nscale": {}, + "publicai": {}, "replicate": {}, "sambanova": {}, "together": {}, diff --git a/src/huggingface_hub/inference/_providers/publicai.py b/src/huggingface_hub/inference/_providers/publicai.py new file mode 100644 index 0000000000..4c88528e4f --- /dev/null +++ b/src/huggingface_hub/inference/_providers/publicai.py @@ -0,0 +1,6 @@ +from ._common import BaseConversationalTask + + +class PublicAIConversationalTask(BaseConversationalTask): + def __init__(self): + super().__init__(provider="publicai", base_url="https://api.publicai.co") diff --git a/tests/test_inference_client.py b/tests/test_inference_client.py index cf384db0d1..76ece5ff3f 100644 --- a/tests/test_inference_client.py +++ b/tests/test_inference_client.py @@ -118,6 +118,9 @@ "text-generation": "NousResearch/Nous-Hermes-Llama2-13b", "conversational": "meta-llama/Llama-3.1-8B-Instruct", }, + "publicai": { + "conversational": "swiss-ai/Apertus-8B-Instruct-2509", + }, "replicate": { "text-to-image": "ByteDance/SDXL-Lightning", }, diff --git a/tests/test_inference_providers.py b/tests/test_inference_providers.py index 333eb57d33..5dc7dcd145 100644 --- a/tests/test_inference_providers.py +++ b/tests/test_inference_providers.py @@ -44,6 +44,7 @@ from huggingface_hub.inference._providers.novita import NovitaConversationalTask, NovitaTextGenerationTask from huggingface_hub.inference._providers.nscale import NscaleConversationalTask, NscaleTextToImageTask from 
huggingface_hub.inference._providers.openai import OpenAIConversationalTask +from huggingface_hub.inference._providers.publicai import PublicAIConversationalTask from huggingface_hub.inference._providers.replicate import ( ReplicateImageToImageTask, ReplicateTask, @@ -1140,6 +1141,15 @@ def test_text_to_image_get_response(self): assert response == b"image_bytes" +class TestPublicAIProvider: + def test_prepare_url(self): + helper = PublicAIConversationalTask() + assert ( + helper._prepare_url("publicai_token", "username/repo_name") + == "https://api.publicai.co/v1/chat/completions" + ) + + class TestOpenAIProvider: def test_prepare_url(self): helper = OpenAIConversationalTask()