Source code for langchain.smith.evaluation.progress

"""一个简单的控制台进度条。"""

import threading
from typing import Any, Dict, Optional, Sequence
from uuid import UUID

from langchain_core.callbacks import base as base_callbacks
from langchain_core.documents import Document
from langchain_core.outputs import LLMResult


class ProgressBarCallback(base_callbacks.BaseCallbackHandler):
    """A simple progress bar for the console.

    The bar advances by one every time an ``on_*_end`` or ``on_*_error``
    callback (chain, retriever, LLM, tool) is received for a run whose
    ``parent_run_id`` is ``None``. Callbacks for runs that have a parent are
    ignored, so nested runs do not advance the bar.
    """

    def __init__(self, total: int, ncols: int = 50, **kwargs: Any):
        """Initialize the progress bar and draw it once in its empty state.

        Args:
            total: Total number of items to process.
            ncols: Width of the progress bar in characters.
            **kwargs: Accepted for call-site compatibility; not used here.
        """
        self.total = total
        self.ncols = ncols
        self.counter = 0
        # Guards ``counter`` and the redraw so concurrent callbacks neither
        # lose an update nor interleave their output.
        self.lock = threading.Lock()
        self._print_bar()

    def increment(self) -> None:
        """Increment the counter by one and redraw the progress bar."""
        with self.lock:
            self.counter += 1
            self._print_bar()

    def _print_bar(self) -> None:
        """Print the progress bar to the console, overwriting the current line."""
        if self.total > 0:
            # Cap at 1.0 so extra increments cannot push the arrow past
            # ``ncols`` and overflow the brackets.
            progress = min(self.counter / self.total, 1.0)
        else:
            # Nothing to process: draw an empty bar instead of raising
            # ZeroDivisionError (which would also break ``__init__``).
            progress = 0.0
        # The ">" head occupies one column, hence the ``- 1`` on the dash
        # count; a negative count simply yields no dashes.
        arrow = "-" * int(round(progress * self.ncols) - 1) + ">"
        spaces = " " * (self.ncols - len(arrow))
        # "\r" returns to the start of the line so each redraw replaces the
        # previous one; ``end=""`` keeps the cursor on that line.
        print(f"\r[{arrow + spaces}] {self.counter}/{self.total}", end="")  # noqa: T201

    def _increment_if_root(self, parent_run_id: Optional[UUID]) -> None:
        """Advance the bar only when the finished run has no parent run."""
        if parent_run_id is None:
            self.increment()

    def on_chain_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Advance the bar when a chain run without a parent fails."""
        self._increment_if_root(parent_run_id)

    def on_chain_end(
        self,
        outputs: Dict[str, Any],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Advance the bar when a chain run without a parent finishes."""
        self._increment_if_root(parent_run_id)

    def on_retriever_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Advance the bar when a retriever run without a parent fails."""
        self._increment_if_root(parent_run_id)

    def on_retriever_end(
        self,
        documents: Sequence[Document],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Advance the bar when a retriever run without a parent finishes."""
        self._increment_if_root(parent_run_id)

    def on_llm_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Advance the bar when an LLM run without a parent fails."""
        self._increment_if_root(parent_run_id)

    def on_llm_end(
        self,
        response: LLMResult,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Advance the bar when an LLM run without a parent finishes."""
        self._increment_if_root(parent_run_id)

    def on_tool_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Advance the bar when a tool run without a parent fails."""
        self._increment_if_root(parent_run_id)

    def on_tool_end(
        self,
        output: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Advance the bar when a tool run without a parent finishes."""
        self._increment_if_root(parent_run_id)