Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions libs/oci/langchain_oci/chat_models/oci_generative_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,8 +545,8 @@ def process_stream_tool_calls(
return tool_call_chunks


class MetaProvider(Provider):
"""Provider implementation for Meta."""
class GenericProvider(Provider):
"""Provider for models using generic API spec."""

stop_sequence_key: str = "stop"

Expand Down Expand Up @@ -934,6 +934,11 @@ def process_stream_tool_calls(
return tool_call_chunks


class MetaProvider(GenericProvider):
    """Deprecated alias of :class:`GenericProvider`, kept for backward compatibility.

    Meta models are served through the generic API spec; existing code that
    references ``MetaProvider`` keeps working through this subclass.
    """
    # NOTE: the original body held a redundant `pass` after the docstring;
    # a docstring alone is a valid class body, so it was removed.


class ChatOCIGenAI(BaseChatModel, OCIGenAIBase):
"""ChatOCIGenAI chat model integration.

Expand Down Expand Up @@ -1011,6 +1016,11 @@ class ChatOCIGenAI(BaseChatModel, OCIGenAIBase):
def _llm_type(self) -> str:
"""Return the type of the language model."""
return "oci_generative_ai_chat"

@property
def _default_provider(self) -> Provider:
    """Fallback provider when the model's vendor has no dedicated entry.

    Returns a fresh :class:`GenericProvider`, which speaks the generic
    chat API spec.
    """
    fallback = GenericProvider()
    return fallback

@property
def _provider_map(self) -> Mapping[str, Provider]:
Expand All @@ -1023,7 +1033,10 @@ def _provider_map(self) -> Mapping[str, Provider]:
@property
def _provider(self) -> Any:
"""Get the internal provider object"""
return self._get_provider(provider_map=self._provider_map)
return self._get_provider(
provider_map=self._provider_map,
default=self._default_provider
)

def _prepare_request(
self,
Expand Down
23 changes: 20 additions & 3 deletions libs/oci/langchain_oci/llms/oci_generative_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ def completion_response_to_text(self, response: Any) -> str:
return response.data.inference_response.generated_texts[0].text


class MetaProvider(Provider):
class GenericProvider(Provider):
"""Provider for models using generic API spec."""
stop_sequence_key: str = "stop"

def __init__(self) -> None:
Expand All @@ -50,6 +51,11 @@ def __init__(self) -> None:

def completion_response_to_text(self, response: Any) -> str:
    """Extract the generated text from a generic-spec completion response.

    Reads the first choice's ``text`` field out of the nested
    ``response.data.inference_response`` payload.
    """
    inference = response.data.inference_response
    first_choice = inference.choices[0]
    return first_choice.text


class MetaProvider(GenericProvider):
    """Deprecated alias of :class:`GenericProvider`, kept for backward compatibility.

    Meta models use the generic completion API spec; existing code that
    references ``MetaProvider`` keeps working through this subclass.
    """
    # NOTE: the original body held a redundant `pass` after the docstring;
    # a docstring alone is a valid class body, so it was removed.


class OCIAuthType(Enum):
Expand Down Expand Up @@ -199,7 +205,7 @@ def _identifying_params(self) -> Mapping[str, Any]:
**{"model_kwargs": _model_kwargs},
}

def _get_provider(self, provider_map: Mapping[str, Any]) -> Any:
def _get_provider(self, provider_map: Mapping[str, Any], default: Provider) -> Any:
if self.provider is not None:
provider = self.provider
else:
Expand All @@ -212,11 +218,14 @@ def _get_provider(self, provider_map: Mapping[str, Any]) -> Any:
provider = self.model_id.split(".")[0].lower()

if provider not in provider_map:
if default:
return default
raise ValueError(
f"Invalid provider derived from model_id: {self.model_id} "
"Please explicitly pass in the supported provider "
"when using custom endpoint"
)
return GenericProvider()
return provider_map[provider]


Expand Down Expand Up @@ -262,6 +271,11 @@ class OCIGenAI(LLM, OCIGenAIBase):
def _llm_type(self) -> str:
"""Return type of llm."""
return "oci_generative_ai_completion"

@property
def _default_provider(self) -> Provider:
    """Fallback provider when the model's vendor has no dedicated entry.

    Returns a fresh :class:`GenericProvider`, which speaks the generic
    completion API spec.
    """
    fallback = GenericProvider()
    return fallback

@property
def _provider_map(self) -> Mapping[str, Any]:
Expand All @@ -274,7 +288,10 @@ def _provider_map(self) -> Mapping[str, Any]:
@property
def _provider(self) -> Any:
"""Get the internal provider object"""
return self._get_provider(provider_map=self._provider_map)
return self._get_provider(
provider_map=self._provider_map,
default=self._default_provider
)

def _prepare_invocation_object(
self, prompt: str, stop: Optional[List[str]], kwargs: Dict[str, Any]
Expand Down