"""Callback Handler that prints to streamlit."""
from __future__ import annotations
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult
from langchain_community.callbacks.streamlit.mutable_expander import MutableExpander
if TYPE_CHECKING:
from streamlit.delta_generator import DeltaGenerator
def _convert_newlines(text: str) -> str:
"""Convert newline characters to markdown newline sequences
(space, space, newline).
"""
return text.replace("\n", " \n")
# Emoji / Streamlit emoji shortcodes used to prefix thought labels.
# The literal emoji are written as escapes so the source stays ASCII-only
# and cannot be mangled by a wrong text encoding.
CHECKMARK_EMOJI = "\u2705"  # ✅ white heavy check mark
THINKING_EMOJI = ":thinking_face:"
HISTORY_EMOJI = ":books:"
EXCEPTION_EMOJI = "\u26a0\ufe0f"  # ⚠️ warning sign + emoji presentation selector
class LLMThoughtState(Enum):
    """Enumerator of the LLMThought state."""

    # The LLM is thinking about what to do next. We don't know which tool we'll run.
    THINKING = "THINKING"
    # The LLM has decided to run a tool. We don't have results from the tool yet.
    RUNNING_TOOL = "RUNNING_TOOL"
    # The thought is finished (tool results received, or the agent's final
    # thought was completed).
    COMPLETE = "COMPLETE"
class LLMThoughtLabeler:
    """
    Generates markdown labels for LLMThought containers. Pass a custom
    subclass of this to StreamlitCallbackHandler to override its default
    labeling logic.

    NOTE(review): LLMThought.complete() calls
    ``get_tool_label(tool, is_complete=True)`` on its labeler, but this class
    does not define that method. A subclass must provide it if that code path
    is reachable — confirm.
    """

    @staticmethod
    def get_initial_label() -> str:
        """Return the markdown label for a new LLMThought that doesn't have
        an associated tool yet.
        """
        return f"{THINKING_EMOJI} **Thinking...**"

    @staticmethod
    def get_history_label() -> str:
        """Return a markdown label for the special 'history' container
        that contains overflow thoughts.
        """
        return f"{HISTORY_EMOJI} **History**"

    @staticmethod
    def get_final_agent_thought_label() -> str:
        """Return the markdown label for the agent's final thought -
        the "Now I have the answer" thought, that doesn't involve
        a tool.
        """
        return f"{CHECKMARK_EMOJI} **Complete!**"
class LLMThought:
    """A thought in the LLM's thought stream.

    Each thought owns one ``MutableExpander`` created inside the parent
    container. Streamed LLM tokens are accumulated and rendered into it as
    markdown.
    """

    def __init__(
        self,
        parent_container: DeltaGenerator,
        labeler: LLMThoughtLabeler,
        expanded: bool,
        collapse_on_complete: bool,
    ):
        """Initialize the LLMThought.

        Args:
            parent_container: The container we're writing into.
            labeler: The labeler to use for this thought.
            expanded: Whether the thought should be expanded by default.
            collapse_on_complete: Whether the thought's expander should be
                collapsed when the thought completes.
        """
        self._container = MutableExpander(
            parent_container=parent_container,
            label=labeler.get_initial_label(),
            expanded=expanded,
        )
        self._state = LLMThoughtState.THINKING
        # Markdown accumulated from streamed tokens, plus the index returned by
        # the container for the element that renders it (None = not written yet).
        self._llm_token_stream = ""
        self._llm_token_writer_idx: Optional[int] = None
        # NOTE(review): ToolRecord is not defined in this file and nothing here
        # assigns _last_tool. The annotation is only safe because the module
        # uses `from __future__ import annotations` — confirm where it lives.
        self._last_tool: Optional[ToolRecord] = None
        self._collapse_on_complete = collapse_on_complete
        self._labeler = labeler

    @property
    def container(self) -> MutableExpander:
        """The container we're writing into."""
        return self._container

    @property
    def last_tool(self) -> Optional[ToolRecord]:
        """The last tool executed by this thought."""
        return self._last_tool

    def _reset_llm_token_stream(self) -> None:
        """Clear the buffered token text and forget which element it was
        written to, so the next token is written with ``index=None``."""
        self._llm_token_stream = ""
        self._llm_token_writer_idx = None

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Start a fresh token stream for a new LLM call.

        Extra keyword arguments are accepted (and ignored) for consistency
        with the other callback methods.
        """
        self._reset_llm_token_stream()

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Append a streamed token and re-render the accumulated text."""
        # This is only called when the LLM is initialized with `streaming=True`
        self._llm_token_stream += _convert_newlines(token)
        # Rewrite the full accumulated text each time, reusing the element
        # index returned by the previous write.
        self._llm_token_writer_idx = self._container.markdown(
            self._llm_token_stream, index=self._llm_token_writer_idx
        )

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Finish the current token stream."""
        # `response` is the concatenation of all the tokens received by the LLM.
        # If we're receiving streaming tokens from `on_llm_new_token`, this response
        # data is redundant
        self._reset_llm_token_stream()

    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        """Render an error notice followed by the exception itself."""
        self._container.markdown("**LLM encountered an error...**")
        self._container.exception(error)

    def on_agent_action(
        self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
    ) -> Any:
        """No-op; see comment below."""
        # Called when we're about to kick off a new tool. The `action` data
        # tells us the tool we're about to use, and the input we'll give it.
        # We don't output anything here, because we'll receive this same data
        # when `on_tool_start` is called immediately after.
        pass

    def complete(self, final_label: Optional[str] = None) -> None:
        """Finish the thought.

        Args:
            final_label: Label for the finished container. If None while a
                tool is running, the labeler's tool label is used. Otherwise
                None is passed straight to ``MutableExpander.update``.
        """
        if final_label is None and self._state == LLMThoughtState.RUNNING_TOOL:
            assert (
                self._last_tool is not None
            ), "_last_tool should never be null when _state == RUNNING_TOOL"
            # NOTE(review): LLMThoughtLabeler in this file defines no
            # get_tool_label, and nothing visible sets RUNNING_TOOL — confirm.
            final_label = self._labeler.get_tool_label(
                self._last_tool, is_complete=True
            )
        self._state = LLMThoughtState.COMPLETE
        if self._collapse_on_complete:
            self._container.update(new_label=final_label, new_expanded=False)
        else:
            self._container.update(new_label=final_label)

    def clear(self) -> None:
        """Remove the thought from the screen. A cleared thought can't be reused."""
        self._container.clear()
class StreamlitCallbackHandler(BaseCallbackHandler):
    """Callback handler that writes to a Streamlit app.

    LLM output is rendered into per-thought expanders. When more than
    ``max_thought_containers`` containers are on screen, the oldest completed
    thoughts are copied into a collapsed "History" expander and removed from
    their original position.
    """

    def __init__(
        self,
        parent_container: DeltaGenerator,
        *,
        max_thought_containers: int = 4,
        expand_new_thoughts: bool = True,
        collapse_completed_thoughts: bool = True,
        thought_labeler: Optional[LLMThoughtLabeler] = None,
    ):
        """Create a StreamlitCallbackHandler instance.

        Parameters
        ----------
        parent_container
            The `st.container` that will contain all the Streamlit elements that the
            Handler creates.
        max_thought_containers
            The max number of completed LLM thought containers to show at once. When
            this threshold is reached, a new thought will cause the oldest thoughts to
            be collapsed into a "History" expander. Values below 1 are treated
            as 1. Defaults to 4.
        expand_new_thoughts
            Each LLM "thought" gets its own `st.expander`. This param controls whether
            that expander is expanded by default. Defaults to True.
        collapse_completed_thoughts
            If True, LLM thought expanders will be collapsed when completed.
            Defaults to True.
        thought_labeler
            An optional custom LLMThoughtLabeler instance. If unspecified, the handler
            will use the default thought labeling logic. Defaults to None.
        """
        self._parent_container = parent_container
        # Created up front, before any thought, so the History expander is
        # placed ahead of thoughts added to the parent later.
        self._history_parent = parent_container.container()
        self._history_container: Optional[MutableExpander] = None
        self._current_thought: Optional[LLMThought] = None
        self._completed_thoughts: List[LLMThought] = []
        self._max_thought_containers = max(max_thought_containers, 1)
        self._expand_new_thoughts = expand_new_thoughts
        self._collapse_completed_thoughts = collapse_completed_thoughts
        self._thought_labeler = thought_labeler or LLMThoughtLabeler()

    def _require_current_thought(self) -> LLMThought:
        """Return our current LLMThought.

        Raises:
            RuntimeError: If there is no current thought.
        """
        if self._current_thought is None:
            raise RuntimeError("Current LLMThought is unexpectedly None!")
        return self._current_thought

    def _get_last_completed_thought(self) -> Optional[LLMThought]:
        """Return our most recent completed LLMThought, or None if we don't have one."""
        return self._completed_thoughts[-1] if self._completed_thoughts else None

    @property
    def _num_thought_containers(self) -> int:
        """The number of 'thought containers' we're currently showing: the
        number of completed thought containers, the history container (if it exists),
        and the current thought container (if it exists).
        """
        count = len(self._completed_thoughts)
        if self._history_container is not None:
            count += 1
        if self._current_thought is not None:
            count += 1
        return count

    def _complete_current_thought(self, final_label: Optional[str] = None) -> None:
        """Complete the current thought, optionally assigning it a new label,
        and move it to the _completed_thoughts list.

        Raises:
            RuntimeError: If there is no current thought.
        """
        thought = self._require_current_thought()
        thought.complete(final_label)
        self._completed_thoughts.append(thought)
        self._current_thought = None

    def _prune_old_thought_containers(self) -> None:
        """If we have too many thoughts onscreen, move the oldest completed
        thoughts into the 'history container'.

        With max_thought_containers == 1 there is no room for a history
        container, so the oldest completed thoughts are simply cleared.
        """
        while (
            self._num_thought_containers > self._max_thought_containers
            and self._completed_thoughts
        ):
            # Create our history container if it doesn't exist, and if
            # max_thought_containers is > 1. (if max_thought_containers is 1, we don't
            # have room to show history.)
            if self._history_container is None and self._max_thought_containers > 1:
                self._history_container = MutableExpander(
                    self._history_parent,
                    label=self._thought_labeler.get_history_label(),
                    expanded=False,
                )

            oldest_thought = self._completed_thoughts.pop(0)
            if self._history_container is not None:
                # Copy the label and contents into History before the original
                # container is removed from the screen.
                self._history_container.markdown(oldest_thought.container.label)
                self._history_container.append_copy(oldest_thought.container)
            oldest_thought.clear()

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Create a thought if none is active, then forward the start event."""
        if self._current_thought is None:
            self._current_thought = LLMThought(
                parent_container=self._parent_container,
                expanded=self._expand_new_thoughts,
                collapse_on_complete=self._collapse_completed_thoughts,
                labeler=self._thought_labeler,
            )

        self._current_thought.on_llm_start(serialized, prompts)

        # We don't prune_old_thought_containers here, because our container won't
        # be visible until it has a child.

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Forward a streamed token to the current thought."""
        self._require_current_thought().on_llm_new_token(token, **kwargs)
        self._prune_old_thought_containers()

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Forward the end of an LLM call to the current thought."""
        self._require_current_thought().on_llm_end(response, **kwargs)
        self._prune_old_thought_containers()

    def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        """Forward an LLM error to the current thought."""
        self._require_current_thought().on_llm_error(error, **kwargs)
        self._prune_old_thought_containers()

    def on_text(
        self,
        text: str,
        color: Optional[str] = None,
        end: str = "",
        **kwargs: Any,
    ) -> None:
        """Ignored by this handler."""
        pass

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Ignored by this handler."""
        pass

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Ignored by this handler."""
        pass

    def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
        """Ignored by this handler."""
        pass

    def on_agent_action(
        self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
    ) -> Any:
        """Forward an agent action to the current thought."""
        self._require_current_thought().on_agent_action(action, color, **kwargs)
        self._prune_old_thought_containers()

    def on_agent_finish(
        self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
    ) -> None:
        """Complete the current thought (if any) with the final-answer label."""
        if self._current_thought is not None:
            # NOTE(review): unlike _complete_current_thought, this does not
            # append the thought to _completed_thoughts, so pruning never moves
            # it into History — confirm this is intended.
            self._current_thought.complete(
                self._thought_labeler.get_final_agent_thought_label()
            )
            self._current_thought = None