import json
import re
import uuid
from abc import ABC, abstractmethod
from operator import itemgetter
from typing import (
Any,
Callable,
Dict,
Iterator,
List,
Literal,
Mapping,
Optional,
Sequence,
Type,
Union,
)
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import (
BaseChatModel,
generate_from_stream,
)
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
ChatMessage,
HumanMessage,
SystemMessage,
ToolCall,
ToolMessage,
)
from langchain_core.messages.tool import ToolCallChunk
from langchain_core.output_parsers import (
JsonOutputParser,
PydanticOutputParser,
)
from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.output_parsers.openai_tools import (
JsonOutputKeyToolsParser,
PydanticToolsParser,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_function
from pydantic import BaseModel, ConfigDict
from langchain_community.llms.oci_generative_ai import OCIGenAIBase
from langchain_community.llms.utils import enforce_stop_tokens
CUSTOM_ENDPOINT_PREFIX = "ocid1.generativeaiendpoint"
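# Maps JSON Schema type names to the Python type names used when describing
# tool parameters to the model (illustrative note; the mapping is consumed by
# the provider-specific tool conversion logic).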
JSON_TO_PYTHON_TYPES = {
"string": "str",
"number": "float",
"boolean": "bool",
"integer": "int",
"array": "List",
"object": "Dict",
"any": "any",
}
def _is_pydantic_class(obj: Any) -> bool:
return isinstance(obj, type) and issubclass(obj, BaseModel)
def _remove_signature_from_tool_description(name: str, description: str) -> str:
"""
Removes the `{name}{signature} - ` prefix and Args: section from tool description.
The signature is usually present for tools created with the @tool decorator,
whereas the Args: section may be present in function doc blocks.
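For example (illustrative), for a tool named ``add``, a description like
``add(a: int, b: int) -> int - Adds two integers.`` followed by an ``Args:``
block is reduced to ``Adds two integers.``.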
"""
description = re.sub(rf"^{name}\(.*?\) -(?:> \w+? -)? ", "", description)
description = re.sub(r"(?s)(?:\n?\n\s*?)?Args:.*$", "", description)
return description
def _format_oci_tool_calls(
tool_calls: Optional[List[Any]] = None,
) -> List[Dict]:
"""
Formats an OCI GenAI API response into the tool-call format used by LangChain.
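Each returned entry has the shape (illustrative):
``{"id": "<random hex>", "type": "function", "function": {"name": "<tool name>", "arguments": "<JSON-encoded parameters>"}}``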
"""
if not tool_calls:
return []
formatted_tool_calls = []
for tool_call in tool_calls:
formatted_tool_calls.append(
{
"id": uuid.uuid4().hex[:],
"function": {
"name": tool_call.name,
"arguments": json.dumps(tool_call.parameters),
},
"type": "function",
}
)
return formatted_tool_calls
def _convert_oci_tool_call_to_langchain(tool_call: Any) -> ToolCall:
"""Convert a OCI GenAI tool call into langchain_core.messages.ToolCall"""
_id = uuid.uuid4().hex
return ToolCall(name=tool_call.name, args=tool_call.parameters, id=_id)
class Provider(ABC):
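"""Adapter interface implemented per model provider (e.g. Cohere, Meta) to map
LangChain messages and responses to that provider's OCI GenAI request format."""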
@property
@abstractmethod
def stop_sequence_key(self) -> str: ...
@abstractmethod
def chat_response_to_text(self, response: Any) -> str: ...
@abstractmethod
def chat_stream_to_text(self, event_data: Dict) -> str: ...
@abstractmethod
def is_chat_stream_end(self, event_data: Dict) -> bool: ...
@abstractmethod
def chat_generation_info(self, response: Any) -> Dict[str, Any]: ...
@abstractmethod
def chat_stream_generation_info(self, event_data: Dict) -> Dict[str, Any]: ...
@abstractmethod
def get_role(self, message: BaseMessage) -> str: ...
@abstractmethod
def messages_to_oci_params(
self, messages: Any, **kwargs: Any
) -> Dict[str, Any]: ...
@abstractmethod
def convert_to_oci_tool(
self,
tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool],
) -> Dict[str, Any]: ...
class CohereProvider(Provider):
stop_sequence_key: str = "stop_sequences"
def __init__(self) -> None:
from oci.generative_ai_inference import models
self.oci_chat_request = models.CohereChatRequest
self.oci_tool = models.CohereTool
self.oci_tool_param = models.CohereParameterDefinition
self.oci_tool_result = models.CohereToolResult
self.oci_tool_call = models.CohereToolCall
self.oci_chat_message = {
"USER": models.CohereUserMessage,
"CHATBOT": models.CohereChatBotMessage,
"SYSTEM": models.CohereSystemMessage,
"TOOL": models.CohereToolMessage,
}
self.chat_api_format = models.BaseChatRequest.API_FORMAT_COHERE
def chat_response_to_text(self, response: Any) -> str:
return response.data.chat_response.text
def chat_stream_to_text(self, event_data: Dict) -> str:
if "text" in event_data:
if "finishReason" in event_data or "toolCalls" in event_data:
return ""
else:
return event_data["text"]
else:
return ""
def is_chat_stream_end(self, event_data: Dict) -> bool:
return "finishReason" in event_data
def chat_generation_info(self, response: Any) -> Dict[str, Any]:
generation_info: Dict[str, Any] = {
"documents": response.data.chat_response.documents,
"citations": response.data.chat_response.citations,
"search_queries": response.data.chat_response.search_queries,
"is_search_required": response.data.chat_response.is_search_required,
"finish_reason": response.data.chat_response.finish_reason,
}
if response.data.chat_response.tool_calls:
# Only populate tool_calls when the response contains one or more calls.
generation_info["tool_calls"] = _format_oci_tool_calls(
response.data.chat_response.tool_calls
)
return generation_info
def chat_stream_generation_info(self, event_data: Dict) -> Dict[str, Any]:
generation_info: Dict[str, Any] = {
"documents": event_data.get("documents"),
"citations": event_data.get("citations"),
"finish_reason": event_data.get("finishReason"),
}
if "toolCalls" in event_data:
generation_info["tool_calls"] = []
for tool_call in event_data["toolCalls"]:
generation_info["tool_calls"].append(
{
"id": uuid.uuid4().hex[:],
"function": {
"name": tool_call["name"],
"arguments": json.dumps(tool_call["parameters"]),
},
"type": "function",
}
)
generation_info = {k: v for k, v in generation_info.items() if v is not None}
return generation_info
def get_role(self, message: BaseMessage) -> str:
if isinstance(message, HumanMessage):
return "USER"
elif isinstance(message, AIMessage):
return "CHATBOT"
elif isinstance(message, SystemMessage):
return "SYSTEM"
elif isinstance(message, ToolMessage):
return "TOOL"
else:
raise ValueError(f"Got unknown type {message}")
def messages_to_oci_params(
self, messages: Sequence[ChatMessage], **kwargs: Any
) -> Dict[str, Any]:
is_force_single_step = kwargs.get("is_force_single_step") or False
oci_chat_history = []
for msg in messages[:-1]:
if self.get_role(msg) == "USER" or self.get_role(msg) == "SYSTEM":
oci_chat_history.append(
self.oci_chat_message[self.get_role(msg)](message=msg.content)
)
elif isinstance(msg, AIMessage):
if msg.tool_calls and is_force_single_step:
continue
tool_calls = (
[
self.oci_tool_call(name=tc["name"], parameters=tc["args"])
for tc in msg.tool_calls
]
if msg.tool_calls
else None
)
msg_content = msg.content if msg.content else " "
oci_chat_history.append(
self.oci_chat_message[self.get_role(msg)](
message=msg_content, tool_calls=tool_calls
)
)
# Get the messages for the current chat turn
current_chat_turn_messages = []
for message in messages[::-1]:
current_chat_turn_messages.append(message)
if isinstance(message, HumanMessage):
break
current_chat_turn_messages = current_chat_turn_messages[::-1]
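# Pair each ToolMessage in the current turn with the matching tool call
# (by id) on the most recent AIMessage that issued tool calls, and wrap
# each pair as an OCI tool result.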
oci_tool_results: Union[List[Any], None] = []
for message in current_chat_turn_messages:
if isinstance(message, ToolMessage):
tool_message = message
previous_ai_msgs = [
message
for message in current_chat_turn_messages
if isinstance(message, AIMessage) and message.tool_calls
]
if previous_ai_msgs:
previous_ai_msg = previous_ai_msgs[-1]
for lc_tool_call in previous_ai_msg.tool_calls:
if lc_tool_call["id"] == tool_message.tool_call_id:
tool_result = self.oci_tool_result()
tool_result.call = self.oci_tool_call(
name=lc_tool_call["name"],
parameters=lc_tool_call["args"],
)
tool_result.outputs = [{"output": tool_message.content}]
oci_tool_results.append(tool_result)
if not oci_tool_results:
oci_tool_results = None
message_str = "" if oci_tool_results else messages[-1].content
oci_params = {
"message": message_str,
"chat_history": oci_chat_history,
"tool_results": oci_tool_results,
"api_format": self.chat_api_format,
}
return {k: v for k, v in oci_params.items() if v is not None}
class ChatOCIGenAI(BaseChatModel, OCIGenAIBase):
"""ChatOCIGenAI chat model integration.
Setup:
Install ``langchain-community`` and the ``oci`` SDK.
.. code-block:: bash
pip install -U langchain-community oci
Key init args - completion params:
model_id: str
ID of the OCIGenAI chat model to use, e.g., cohere.command-r-16k.
is_stream: bool
Whether to stream back partial progress.
model_kwargs: Optional[Dict]
Keyword arguments to pass to the specific model used, e.g., temperature, max_tokens.
Key init args - client params:
service_endpoint: str
The endpoint URL for the OCIGenAI service, e.g., https://inference.generativeai.us-chicago-1.oci.oraclecloud.com.
compartment_id: str
The compartment OCID.
auth_type: str
The authentication type to use, e.g., API_KEY (default), SECURITY_TOKEN, INSTANCE_PRINCIPAL, RESOURCE_PRINCIPAL.
auth_profile: Optional[str]
The name of the profile in ~/.oci/config. If not specified, DEFAULT is used.
provider: str
Provider name of the model. Defaults to None; if the provider cannot be derived from the model_id, it must be supplied by the user.
See full list of supported init args and their descriptions in the params section.
Instantiate:
.. code-block:: python
from langchain_community.chat_models import ChatOCIGenAI
chat = ChatOCIGenAI(
model_id="cohere.command-r-16k",
service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
compartment_id="MY_OCID",
model_kwargs={"temperature": 0.7, "max_tokens": 500},
)
Invoke:
.. code-block:: python
messages = [
SystemMessage(content="You are an AI assistant."),
AIMessage(content="Hi there human!"),
HumanMessage(content="tell me a joke."),
]
response = chat.invoke(messages)
Stream:
.. code-block:: python
for r in chat.stream(messages):
print(r.content, end="", flush=True)
Response metadata:
.. code-block:: python
response = chat.invoke(messages)
print(response.response_metadata)
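Tool calling (an illustrative sketch; assumes a ``get_weather`` function decorated with the ``@tool`` decorator is defined elsewhere):
.. code-block:: python
chat_with_tools = chat.bind_tools([get_weather])
ai_msg = chat_with_tools.invoke("What is the weather in Chicago?")
print(ai_msg.tool_calls)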
""" # noqa: E501
model_config = ConfigDict(
extra="forbid",
arbitrary_types_allowed=True,
)
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "oci_generative_ai_chat"
@property
def _provider_map(self) -> Mapping[str, Any]:
"""Get the provider map"""
return {
"cohere": CohereProvider(),
"meta": MetaProvider(),
}
@property
def _provider(self) -> Any:
"""Get the internal provider object"""
return self._get_provider(provider_map=self._provider_map)
def _prepare_request(
self,
messages: List[BaseMessage],
stop: Optional[List[str]],
stream: bool,
**kwargs: Any,
) -> Dict[str, Any]:
try:
from oci.generative_ai_inference import models
except ImportError as ex:
raise ModuleNotFoundError(
"Could not import oci python package. "
"Please make sure you have the oci package installed."
) from ex
oci_params = self._provider.messages_to_oci_params(messages, **kwargs)
oci_params["is_stream"] = stream
_model_kwargs = self.model_kwargs or {}
if stop is not None:
_model_kwargs[self._provider.stop_sequence_key] = stop
chat_params = {**_model_kwargs, **kwargs, **oci_params}
if not self.model_id:
raise ValueError("Model ID is required to chat")
if self.model_id.startswith(CUSTOM_ENDPOINT_PREFIX):
serving_mode = models.DedicatedServingMode(endpoint_id=self.model_id)
else:
serving_mode = models.OnDemandServingMode(model_id=self.model_id)
request = models.ChatDetails(
compartment_id=self.compartment_id,
serving_mode=serving_mode,
chat_request=self._provider.oci_chat_request(**chat_params),
)
return request
def with_structured_output(
self,
schema: Optional[Union[Dict, Type[BaseModel]]] = None,
*,
method: Literal["function_calling", "json_mode"] = "function_calling",
include_raw: bool = False,
**kwargs: Any,
) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
"""Model wrapper that returns outputs formatted to match the given schema.
Args:
schema: The output schema as a dict or a Pydantic class. If a Pydantic class
then the model output will be an object of that class. If a dict then
the model output will be a dict. With a Pydantic class the returned
attributes will be validated, whereas with a dict they will not be. If
`method` is "function_calling" and `schema` is a dict, then the dict
must match the OCI Generative AI function-calling spec.
method:
The method for steering model generation, either "function_calling"
or "json_mode". If "function_calling" then the schema will be converted
to an OCI function and the returned model will make use of the
function-calling API. If "json_mode" then Cohere's JSON mode will be
used. Note that if using "json_mode" then you must include instructions
for formatting the output into the desired schema into the model call.
include_raw:
If False then only the parsed structured output is returned. If
an error occurs during model output parsing it will be raised. If True
then both the raw model response (a BaseMessage) and the parsed model
response will be returned. If an error occurs during output parsing it
will be caught and returned as well. The final output is always a dict
with keys "raw", "parsed", and "parsing_error".
Returns:
A Runnable that takes any ChatModel input and returns as output:
If include_raw is True then a dict with keys:
raw: BaseMessage
parsed: Optional[_DictOrPydantic]
parsing_error: Optional[BaseException]
If include_raw is False then just _DictOrPydantic is returned,
where _DictOrPydantic depends on the schema:
If schema is a Pydantic class then _DictOrPydantic is the Pydantic
class.
If schema is a dict then _DictOrPydantic is a dict.
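Example (an illustrative sketch; assumes a configured ``ChatOCIGenAI`` instance named ``chat``):
.. code-block:: python
from pydantic import BaseModel

class Joke(BaseModel):
    setup: str
    punchline: str

structured_chat = chat.with_structured_output(Joke)
joke = structured_chat.invoke("Tell me a joke about parrots")
# ``joke`` is a ``Joke`` instance parsed from the model's tool call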
""" # noqa: E501
if kwargs:
raise ValueError(f"Received unsupported arguments {kwargs}")
is_pydantic_schema = _is_pydantic_class(schema)
if method == "function_calling":
if schema is None:
raise ValueError(
"schema must be specified when method is 'function_calling'. "
"Received None."
)
llm = self.bind_tools([schema], **kwargs)
tool_name = getattr(self._provider.convert_to_oci_tool(schema), "name")
if is_pydantic_schema:
output_parser: OutputParserLike = PydanticToolsParser(
tools=[schema], # type: ignore[list-item]
first_tool_only=True, # type: ignore[list-item]
)
else:
output_parser = JsonOutputKeyToolsParser(
key_name=tool_name, first_tool_only=True
)
elif method == "json_mode":
llm = self.bind(response_format={"type": "json_object"})
output_parser = (
PydanticOutputParser(pydantic_object=schema) # type: ignore[type-var, arg-type]
if is_pydantic_schema
else JsonOutputParser()
)
else:
raise ValueError(
f"Unrecognized method argument. "
f"Expected `function_calling` or `json_mode`."
f"Received: `{method}`."
)
if include_raw:
parser_assign = RunnablePassthrough.assign(
parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
)
parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
parser_with_fallback = parser_assign.with_fallbacks(
[parser_none], exception_key="parsing_error"
)
return RunnableMap(raw=llm) | parser_with_fallback
else:
return llm | output_parser
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""Call out to a OCIGenAI chat model.
Args:
messages: list of LangChain messages
stop: Optional list of stop words to use.
Returns:
LangChain ChatResult
Example:
.. code-block:: python
messages = [
HumanMessage(content="hello!"),
AIMessage(content="Hi there human!"),
HumanMessage(content="Meow!")
]
response = llm.invoke(messages)
"""
if self.is_stream:
stream_iter = self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return generate_from_stream(stream_iter)
request = self._prepare_request(messages, stop=stop, stream=False, **kwargs)
response = self.client.chat(request)
content = self._provider.chat_response_to_text(response)
if stop is not None:
content = enforce_stop_tokens(content, stop)
generation_info = self._provider.chat_generation_info(response)
llm_output = {
"model_id": response.data.model_id,
"model_version": response.data.model_version,
"request_id": response.request_id,
"content-length": response.headers["content-length"],
}
if "tool_calls" in generation_info:
tool_calls = [
_convert_oci_tool_call_to_langchain(tool_call)
for tool_call in response.data.chat_response.tool_calls
]
else:
tool_calls = []
message = AIMessage(
content=content,
additional_kwargs=generation_info,
tool_calls=tool_calls,
)
return ChatResult(
generations=[
ChatGeneration(message=message, generation_info=generation_info)
],
llm_output=llm_output,
)
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
request = self._prepare_request(messages, stop=stop, stream=True, **kwargs)
response = self.client.chat(request)
for event in response.data.events():
event_data = json.loads(event.data)
if not self._provider.is_chat_stream_end(event_data): # still streaming
delta = self._provider.chat_stream_to_text(event_data)
chunk = ChatGenerationChunk(message=AIMessageChunk(content=delta))
if run_manager:
run_manager.on_llm_new_token(delta, chunk=chunk)
yield chunk
else: # stream end
generation_info = self._provider.chat_stream_generation_info(event_data)
tool_call_chunks = []
if tool_calls := generation_info.get("tool_calls"):
content = self._provider.chat_stream_to_text(event_data)
try:
tool_call_chunks = [
ToolCallChunk(
name=tool_call["function"].get("name"),
args=tool_call["function"].get("arguments"),
id=tool_call.get("id"),
index=tool_call.get("index"),
)
for tool_call in tool_calls
]
except KeyError:
pass
else:
content = ""
message = AIMessageChunk(
content=content,
additional_kwargs=generation_info,
tool_call_chunks=tool_call_chunks,
)
yield ChatGenerationChunk(
message=message,
generation_info=generation_info,
)