Skip to content

Commit 1085b9c

Browse files
authored
Add GenericProvider (#14)
* Use MetaProvider for xai and openai models. * Add GenericProvider as default provider for models. * Use generic provider as default only if the model_id is not custom endpoint. * Remove default_provider property. * Update the logic for checking model_id in get_provider().
1 parent 6c1015a commit 1085b9c

File tree

2 files changed

+26
-10
lines changed

2 files changed

+26
-10
lines changed

libs/oci/langchain_oci/chat_models/oci_generative_ai.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -545,8 +545,8 @@ def process_stream_tool_calls(
545545
return tool_call_chunks
546546

547547

548-
class MetaProvider(Provider):
549-
"""Provider implementation for Meta."""
548+
class GenericProvider(Provider):
549+
"""Provider for models using generic API spec."""
550550

551551
stop_sequence_key: str = "stop"
552552

@@ -934,6 +934,11 @@ def process_stream_tool_calls(
934934
return tool_call_chunks
935935

936936

937+
class MetaProvider(GenericProvider):
938+
"""Provider for Meta models. This provider is for backward compatibility."""
939+
pass
940+
941+
937942
class ChatOCIGenAI(BaseChatModel, OCIGenAIBase):
938943
"""ChatOCIGenAI chat model integration.
939944
@@ -1018,6 +1023,7 @@ def _provider_map(self) -> Mapping[str, Provider]:
10181023
return {
10191024
"cohere": CohereProvider(),
10201025
"meta": MetaProvider(),
1026+
"generic": GenericProvider(),
10211027
}
10221028

10231029
@property

libs/oci/langchain_oci/llms/oci_generative_ai.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ def completion_response_to_text(self, response: Any) -> str:
4040
return response.data.inference_response.generated_texts[0].text
4141

4242

43-
class MetaProvider(Provider):
43+
class GenericProvider(Provider):
44+
"""Provider for models using generic API spec."""
4445
stop_sequence_key: str = "stop"
4546

4647
def __init__(self) -> None:
@@ -50,6 +51,11 @@ def __init__(self) -> None:
5051

5152
def completion_response_to_text(self, response: Any) -> str:
5253
return response.data.inference_response.choices[0].text
54+
55+
56+
class MetaProvider(GenericProvider):
57+
"""Provider for Meta models. This provider is for backward compatibility."""
58+
pass
5359

5460

5561
class OCIAuthType(Enum):
@@ -202,14 +208,17 @@ def _identifying_params(self) -> Mapping[str, Any]:
202208
def _get_provider(self, provider_map: Mapping[str, Any]) -> Any:
203209
if self.provider is not None:
204210
provider = self.provider
211+
elif self.model_id is None:
212+
raise ValueError(
213+
"model_id is required to derive the provider, "
214+
"please provide the provider explicitly or specify "
215+
"the model_id to derive the provider."
216+
)
217+
elif self.model_id.startswith(CUSTOM_ENDPOINT_PREFIX):
218+
raise ValueError("provider is required for custom endpoints.")
205219
else:
206-
if self.model_id is None:
207-
raise ValueError(
208-
"model_id is required to derive the provider, "
209-
"please provide the provider explicitly or specify "
210-
"the model_id to derive the provider."
211-
)
212-
provider = self.model_id.split(".")[0].lower()
220+
221+
provider = provider_map.get(self.model_id.split(".")[0].lower(), "generic")
213222

214223
if provider not in provider_map:
215224
raise ValueError(
@@ -269,6 +278,7 @@ def _provider_map(self) -> Mapping[str, Any]:
269278
return {
270279
"cohere": CohereProvider(),
271280
"meta": MetaProvider(),
281+
"generic": GenericProvider(),
272282
}
273283

274284
@property

0 commit comments

Comments
 (0)