Source code for langchain_community.llms.sparkllm

from __future__ import annotations

import base64
import hashlib
import hmac
import json
import logging
import queue
import threading
from datetime import datetime
from queue import Queue
from time import mktime
from typing import Any, Dict, Generator, Iterator, List, Optional
from urllib.parse import urlencode, urlparse, urlunparse
from wsgiref.handlers import format_date_time

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.utils import get_from_dict_or_env

logger = logging.getLogger(__name__)


class SparkLLM(LLM):
    """iFlyTek Spark large language model.

    To use, you should pass ``app_id``, ``api_key`` and ``api_secret``
    as named parameters to the constructor, or set the environment
    variables ``IFLYTEK_SPARK_APP_ID``, ``IFLYTEK_SPARK_API_KEY`` and
    ``IFLYTEK_SPARK_API_SECRET``.

    Example:
        .. code-block:: python

            client = SparkLLM(
                spark_app_id="<app_id>",
                spark_api_key="<api_key>",
                spark_api_secret="<api_secret>"
            )
    """

    client: Any = None  #: :meta private:
    spark_app_id: Optional[str] = None
    spark_api_key: Optional[str] = None
    spark_api_secret: Optional[str] = None
    spark_api_url: Optional[str] = None
    spark_llm_domain: Optional[str] = None
    spark_user_id: str = "lc_user"
    streaming: bool = False
    request_timeout: int = 30
    temperature: float = 0.5
    top_k: int = 4
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        values["spark_app_id"] = get_from_dict_or_env(
            values,
            "spark_app_id",
            "IFLYTEK_SPARK_APP_ID",
        )
        values["spark_api_key"] = get_from_dict_or_env(
            values,
            "spark_api_key",
            "IFLYTEK_SPARK_API_KEY",
        )
        values["spark_api_secret"] = get_from_dict_or_env(
            values,
            "spark_api_secret",
            "IFLYTEK_SPARK_API_SECRET",
        )
        values["spark_api_url"] = get_from_dict_or_env(
            values,
            "spark_api_url",
            "IFLYTEK_SPARK_API_URL",
            "wss://spark-api.xf-yun.com/v3.1/chat",
        )
        values["spark_llm_domain"] = get_from_dict_or_env(
            values,
            "spark_llm_domain",
            "IFLYTEK_SPARK_LLM_DOMAIN",
            "generalv3",
        )
        # put extra params into model_kwargs
        values["model_kwargs"]["temperature"] = (
            values["temperature"] or cls.temperature
        )
        values["model_kwargs"]["top_k"] = values["top_k"] or cls.top_k

        values["client"] = _SparkLLMClient(
            app_id=values["spark_app_id"],
            api_key=values["spark_api_key"],
            api_secret=values["spark_api_secret"],
            api_url=values["spark_api_url"],
            spark_domain=values["spark_llm_domain"],
            model_kwargs=values["model_kwargs"],
        )
        return values

    @property
    def _llm_type(self) -> str:
        """Return the type of this llm."""
        return "spark-llm-chat"

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling the SparkLLM API."""
        normal_params = {
            "spark_llm_domain": self.spark_llm_domain,
            "stream": self.streaming,
            "request_timeout": self.request_timeout,
            "top_k": self.top_k,
            "temperature": self.temperature,
        }
        return {**normal_params, **self.model_kwargs}

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to SparkLLM with one prompt per generation.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by SparkLLM.

        Example:
            .. code-block:: python

                response = client("Tell me a joke.")
        """
        if self.streaming:
            completion = ""
            for chunk in self._stream(prompt, stop, run_manager, **kwargs):
                completion += chunk.text
            return completion

        completion = ""
        self.client.arun(
            [{"role": "user", "content": prompt}],
            self.spark_user_id,
            self.model_kwargs,
            self.streaming,
        )
        for content in self.client.subscribe(timeout=self.request_timeout):
            if "data" not in content:
                continue
            completion = content["data"]["content"]
        return completion

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        # Run the websocket on a background thread (as `_call` does) so
        # chunks can be consumed from the queue as they arrive.
        self.client.arun(
            [{"role": "user", "content": prompt}],
            self.spark_user_id,
            self.model_kwargs,
            self.streaming,
        )
        for content in self.client.subscribe(timeout=self.request_timeout):
            if "data" not in content:
                continue
            delta = content["data"]
            if run_manager:
                run_manager.on_llm_new_token(delta["content"])
            yield GenerationChunk(text=delta["content"])
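
# Usage sketch (illustrative, not part of the module): with the
# ``IFLYTEK_SPARK_APP_ID``, ``IFLYTEK_SPARK_API_KEY`` and
# ``IFLYTEK_SPARK_API_SECRET`` environment variables set, a blocking
# call and a streaming call look like:
#
#     llm = SparkLLM()
#     print(llm.invoke("Tell me a joke."))
#
#     streaming_llm = SparkLLM(streaming=True)
#     for chunk in streaming_llm.stream("Tell me a joke."):
#         print(chunk, end="", flush=True)
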
class _SparkLLMClient:
    """Call the SparkLLM interface provided by Xfyun, iFLYTEK's open
    platform for AI capabilities, via websocket-client.
    """

    def __init__(
        self,
        app_id: str,
        api_key: str,
        api_secret: str,
        api_url: Optional[str] = None,
        spark_domain: Optional[str] = None,
        model_kwargs: Optional[dict] = None,
    ):
        try:
            import websocket

            self.websocket_client = websocket
        except ImportError:
            raise ImportError(
                "Could not import websocket client python package. "
                "Please install it with `pip install websocket-client`."
            )

        self.api_url = (
            "wss://spark-api.xf-yun.com/v3.1/chat" if not api_url else api_url
        )
        self.app_id = app_id
        self.model_kwargs = model_kwargs
        self.spark_domain = spark_domain or "generalv3"
        self.queue: Queue[Dict] = Queue()
        self.blocking_message = {"content": "", "role": "assistant"}
        self.api_key = api_key
        self.api_secret = api_secret

    @staticmethod
    def _create_url(api_url: str, api_key: str, api_secret: str) -> str:
        """Generate a request url signed with the api key and api secret."""
        # generate timestamp by RFC1123
        date = format_date_time(mktime(datetime.now().timetuple()))

        # urlparse
        parsed_url = urlparse(api_url)
        host = parsed_url.netloc
        path = parsed_url.path

        signature_origin = f"host: {host}\ndate: {date}\nGET {path} HTTP/1.1"

        # encrypt using hmac-sha256
        signature_sha = hmac.new(
            api_secret.encode("utf-8"),
            signature_origin.encode("utf-8"),
            digestmod=hashlib.sha256,
        ).digest()

        signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding="utf-8")

        authorization_origin = (
            f'api_key="{api_key}", algorithm="hmac-sha256", '
            f'headers="host date request-line", signature="{signature_sha_base64}"'
        )
        authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(
            encoding="utf-8"
        )

        # generate url
        params_dict = {"authorization": authorization, "date": date, "host": host}
        encoded_params = urlencode(params_dict)
        url = urlunparse(
            (
                parsed_url.scheme,
                parsed_url.netloc,
                parsed_url.path,
                parsed_url.params,
                encoded_params,
                parsed_url.fragment,
            )
        )
        return url

    def run(
        self,
        messages: List[Dict],
        user_id: str,
        model_kwargs: Optional[dict] = None,
        streaming: bool = False,
    ) -> None:
        self.websocket_client.enableTrace(False)
        ws = self.websocket_client.WebSocketApp(
            _SparkLLMClient._create_url(
                self.api_url,
                self.api_key,
                self.api_secret,
            ),
            on_message=self.on_message,
            on_error=self.on_error,
            on_close=self.on_close,
            on_open=self.on_open,
        )
        ws.messages = messages  # type: ignore[attr-defined]
        ws.user_id = user_id  # type: ignore[attr-defined]
        ws.model_kwargs = self.model_kwargs if model_kwargs is None else model_kwargs  # type: ignore[attr-defined]
        ws.streaming = streaming  # type: ignore[attr-defined]
        ws.run_forever()

    def arun(
        self,
        messages: List[Dict],
        user_id: str,
        model_kwargs: Optional[dict] = None,
        streaming: bool = False,
    ) -> threading.Thread:
        ws_thread = threading.Thread(
            target=self.run,
            args=(
                messages,
                user_id,
                model_kwargs,
                streaming,
            ),
        )
        ws_thread.start()
        return ws_thread

    def on_error(self, ws: Any, error: Optional[Any]) -> None:
        self.queue.put({"error": error})
        ws.close()

    def on_close(self, ws: Any, close_status_code: int, close_reason: str) -> None:
        logger.debug(
            {
                "log": {
                    "close_status_code": close_status_code,
                    "close_reason": close_reason,
                }
            }
        )
        self.queue.put({"done": True})

    def on_open(self, ws: Any) -> None:
        self.blocking_message = {"content": "", "role": "assistant"}
        data = json.dumps(
            self.gen_params(
                messages=ws.messages, user_id=ws.user_id, model_kwargs=ws.model_kwargs
            )
        )
        ws.send(data)

    def on_message(self, ws: Any, message: str) -> None:
        data = json.loads(message)
        code = data["header"]["code"]
        if code != 0:
            self.queue.put(
                {"error": f"Code: {code}, Error: {data['header']['message']}"}
            )
            ws.close()
        else:
            choices = data["payload"]["choices"]
            status = choices["status"]
            content = choices["text"][0]["content"]
            if ws.streaming:
                self.queue.put({"data": choices["text"][0]})
            else:
                self.blocking_message["content"] += content
            if status == 2:
                if not ws.streaming:
                    self.queue.put({"data": self.blocking_message})
                usage_data = (
                    data.get("payload", {}).get("usage", {}).get("text", {})
                    if data
                    else {}
                )
                self.queue.put({"usage": usage_data})
                ws.close()

    def gen_params(
        self, messages: list, user_id: str, model_kwargs: Optional[dict] = None
    ) -> dict:
        data: Dict = {
            "header": {"app_id": self.app_id, "uid": user_id},
            "parameter": {"chat": {"domain": self.spark_domain}},
            "payload": {"message": {"text": messages}},
        }
        if model_kwargs:
            data["parameter"]["chat"].update(model_kwargs)
        logger.debug(f"Spark Request Parameters: {data}")
        return data

    def subscribe(self, timeout: Optional[int] = 30) -> Generator[Dict, None, None]:
        while True:
            try:
                content = self.queue.get(timeout=timeout)
            except queue.Empty:
                raise TimeoutError(
                    f"SparkLLMClient timed out after waiting {timeout} seconds "
                    "for the LLM API response"
                )
            if "error" in content:
                raise ConnectionError(content["error"])
            if "usage" in content:
                yield content
                continue
            if "done" in content:
                break
            if "data" not in content:
                break
            yield content
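
# Minimal sketch of how SparkLLM drives the client above (hypothetical
# credentials; ``arun`` runs the websocket on a background thread while
# ``subscribe`` drains the shared queue until a ``done`` message arrives):
#
#     client = _SparkLLMClient(
#         app_id="<app_id>",
#         api_key="<api_key>",
#         api_secret="<api_secret>",
#     )
#     client.arun(
#         [{"role": "user", "content": "Tell me a joke."}],
#         user_id="lc_user",
#         streaming=False,
#     )
#     for message in client.subscribe(timeout=30):
#         if "data" in message:
#             print(message["data"]["content"])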