Source code for langchain_core.output_parsers.xml

import re
import xml
import xml.etree.ElementTree as ET
from typing import Any, AsyncIterator, Dict, Iterator, List, Literal, Optional, Union
from xml.etree.ElementTree import TreeBuilder

from langchain_core.exceptions import OutputParserException
from langchain_core.messages import BaseMessage
from langchain_core.output_parsers.transform import BaseTransformOutputParser
from langchain_core.runnables.utils import AddableDict

# Prompt template describing the XML output format expected from the model.
# The `{tags}` placeholder is filled in with `str.format` by
# `XMLOutputParser.get_format_instructions`. The text is sent verbatim, so the
# wording, the escaped `\n` sequences and the whitespace are all significant.
XML_FORMAT_INSTRUCTIONS = """The output should be formatted as a XML file.
1. Output should conform to the tags below. 
2. If tags are not given, make them on your own.
3. Remember to always open and close all the tags.

As an example, for the tags ["foo", "bar", "baz"]:
1. String "<foo>\n   <bar>\n      <baz></baz>\n   </bar>\n</foo>" is a well-formatted instance of the schema. 
2. String "<foo>\n   <bar>\n   </foo>" is a badly-formatted instance.
3. String "<foo>\n   <tag>\n   </tag>\n</foo>" is a badly-formatted instance.

Here are the output tags:
```
{tags}
```"""  # noqa: E501


class _StreamingParser:
    """Streaming parser for XML.

    The implementation lives in its own class so that ``transform`` and
    ``atransform`` of XMLOutputParser share a single code path and cannot
    drift apart.
    """

    def __init__(self, parser: Literal["defusedxml", "xml"]) -> None:
        """Initialize the streaming parser.

        Args:
            parser: Parser to use for XML parsing. Can be either 'defusedxml'
                or 'xml'. See the documentation on XMLOutputParser for more
                information.

        Raises:
            ImportError: If ``parser`` is 'defusedxml' and the defusedxml
                package is not installed.
        """
        if parser == "defusedxml":
            try:
                from defusedxml import ElementTree as DET  # type: ignore
            except ImportError as e:
                raise ImportError(
                    "defusedxml is not installed. "
                    "Please install it to use the defusedxml parser. "
                    "You can install it with `pip install defusedxml`."
                ) from e
            _parser = DET.DefusedXMLParser(target=TreeBuilder())
        else:
            # None lets XMLPullParser fall back to its default parser.
            _parser = None
        self.pull_parser = ET.XMLPullParser(["start", "end"], _parser=_parser)
        # Matches the opening of the first XML tag; text before it is dropped.
        self.xml_start_re = re.compile(r"<[a-zA-Z:_]")
        # Tags of the elements that are currently open, outermost first.
        self.current_path: List[str] = []
        # True once a child of the innermost open element has been closed.
        self.current_path_has_children = False
        # Text received so far that has not been handed to the pull parser.
        self.buffer = ""
        self.xml_started = False

    def parse(self, chunk: Union[str, BaseMessage]) -> Iterator[AddableDict]:
        """Parse a chunk of text.

        Args:
            chunk: A chunk of text to parse. This can be a string or a
                BaseMessage. Messages whose content is not a string are
                ignored.

        Yields:
            AddableDict: A dictionary representing a parsed XML element. One
            dictionary is yielded for every element that closes without
            having had any child elements; it is nested under the tags of the
            elements that are still open.

        Raises:
            xml.etree.ElementTree.ParseError: If the parser reports an error
                while at least one element is still open.
        """
        if isinstance(chunk, BaseMessage):
            # extract text
            chunk_content = chunk.content
            if not isinstance(chunk_content, str):
                # ignore non-string messages (e.g., function calls)
                return
            chunk = chunk_content
        # add chunk to buffer of unprocessed text
        self.buffer += chunk
        # if xml string hasn't started yet, continue to next chunk
        if not self.xml_started:
            if match := self.xml_start_re.search(self.buffer):
                # if xml string has started, remove all text before it
                self.buffer = self.buffer[match.start() :]
                self.xml_started = True
            else:
                return
        # feed buffer to parser
        self.pull_parser.feed(self.buffer)
        self.buffer = ""
        # yield all events
        try:
            for event, elem in self.pull_parser.read_events():
                if event == "start":
                    # update current path
                    self.current_path.append(elem.tag)
                    self.current_path_has_children = False
                elif event == "end":
                    # remove last element from current path
                    self.current_path.pop()
                    # yield element only if none of its children was yielded
                    if not self.current_path_has_children:
                        yield nested_element(self.current_path, elem)
                    # prevent yielding of parent element
                    if self.current_path:
                        self.current_path_has_children = True
                    else:
                        # the outermost element closed; wait for a new start
                        self.xml_started = False
        except ET.ParseError:
            # This might be junk at the end of the XML input.
            # Let's check whether the current path is empty.
            if not self.current_path:
                # If it is empty, we can ignore this error.
                return
            else:
                raise

    def close(self) -> None:
        """Close the parser, ignoring any parse error raised while closing."""
        try:
            self.pull_parser.close()
        except ET.ParseError:
            # Ignore. This will ignore any incomplete XML at the end of the input
            pass


class XMLOutputParser(BaseTransformOutputParser):
    """Parse an output using xml format."""

    tags: Optional[List[str]] = None
    encoding_matcher: re.Pattern = re.compile(
        r"<([^>]*encoding[^>]*)>\n(.*)", re.MULTILINE | re.DOTALL
    )
    parser: Literal["defusedxml", "xml"] = "defusedxml"
    """Parser to use for XML parsing. Can be either 'defusedxml' or 'xml'.

    * 'defusedxml' is the default parser and is used to prevent XML
      vulnerabilities present in Python's standard library xml.
      `defusedxml` is a wrapper around the standard library parser that
      sets up the parser with secure defaults.
    * 'xml' is the standard library parser.

    Use `xml` only if you are sure that your distribution of the standard
    library is not vulnerable to XML vulnerabilities.

    Please review the following resources for more information:

    * https://docs.python.org/3/library/xml.html#xml-vulnerabilities
    * https://github.com/tiran/defusedxml

    The standard library relies on libexpat for parsing XML:
    https://github.com/libexpat/libexpat
    """

    def get_format_instructions(self) -> str:
        """Return the format instructions with ``self.tags`` filled in."""
        return XML_FORMAT_INSTRUCTIONS.format(tags=self.tags)

    def parse(self, text: str) -> Dict[str, Union[str, List[Any]]]:
        """Parse a complete XML completion into a nested dictionary.

        If the text contains a triple-backtick fence, only the fenced content
        is parsed. A leading declaration matched by ``encoding_matcher`` is
        dropped, and surrounding whitespace is stripped before parsing.

        Args:
            text: The completion text to parse.

        Returns:
            The dictionary produced by ``_root_to_dict`` for the root element.

        Raises:
            ImportError: If ``self.parser`` is 'defusedxml' and the defusedxml
                package is not installed.
            OutputParserException: If the text is not well-formed XML.
        """
        # Try to find XML string within triple backticks
        # Imports are temporarily placed here to avoid issue with caching on CI
        # likely if you're reading this you can move them to the top of the file
        if self.parser == "defusedxml":
            try:
                from defusedxml import ElementTree as DET  # type: ignore
            except ImportError as e:
                raise ImportError(
                    "defusedxml is not installed. "
                    "Please install it to use the defusedxml parser. "
                    "You can install it with `pip install defusedxml`. "
                    "See https://github.com/tiran/defusedxml for more details."
                ) from e
            _ET = DET  # Use the defusedxml parser
        else:
            _ET = ET  # Use the standard library parser

        match = re.search(r"```(xml)?(.*)```", text, re.DOTALL)
        if match is not None:
            # If match found, use the content within the backticks
            text = match.group(2)
        encoding_match = self.encoding_matcher.search(text)
        if encoding_match:
            text = encoding_match.group(2)

        text = text.strip()
        try:
            # Parse with the module selected above so that the configured
            # parser (defusedxml by default) is the one actually used.
            # NOTE(review): exceptions that defusedxml raises for forbidden
            # constructs are presumably not ParseError and would propagate
            # unchanged -- confirm this is the desired behavior.
            root = _ET.fromstring(text)
            return self._root_to_dict(root)
        except ET.ParseError as e:
            msg = f"Failed to parse XML format from completion {text}. Got: {e}"
            raise OutputParserException(msg, llm_output=text) from e

    def _transform(
        self, input: Iterator[Union[str, BaseMessage]]
    ) -> Iterator[AddableDict]:
        """Stream parsed elements from an iterator of text/message chunks."""
        streaming_parser = _StreamingParser(self.parser)
        for chunk in input:
            yield from streaming_parser.parse(chunk)
        streaming_parser.close()

    async def _atransform(
        self, input: AsyncIterator[Union[str, BaseMessage]]
    ) -> AsyncIterator[AddableDict]:
        """Async counterpart of ``_transform`` using the same parser class."""
        streaming_parser = _StreamingParser(self.parser)
        async for chunk in input:
            for output in streaming_parser.parse(chunk):
                yield output
        streaming_parser.close()

    def _root_to_dict(self, root: ET.Element) -> Dict[str, Union[str, List[Any]]]:
        """Convert an XML tree to a Python dictionary.

        Args:
            root: The element to convert.

        Returns:
            ``{root.tag: root.text}`` when the element's text contains any
            non-whitespace character; otherwise ``{root.tag: [...]}`` with one
            entry per child element, in document order.
        """
        if root.text and bool(re.search(r"\S", root.text)):
            # If root text contains any non-whitespace character it
            # returns {root.tag: root.text}
            return {root.tag: root.text}
        result: Dict = {root.tag: []}
        for child in root:
            if len(child) == 0:
                # Leaf child: keep its text as-is (may be None when empty).
                result[root.tag].append({child.tag: child.text})
            else:
                result[root.tag].append(self._root_to_dict(child))
        return result

    @property
    def _type(self) -> str:
        """Identifier string for this parser type."""
        return "xml"
def nested_element(path: List[str], elem: ET.Element) -> Any:
    """Get a nested element from a path.

    Args:
        path: Tags of the enclosing elements, outermost first.
        elem: The element whose tag and text form the innermost mapping.

    Returns:
        An AddableDict mapping ``elem.tag`` to ``elem.text``, wrapped once per
        entry of ``path`` as ``{tag: [inner]}`` so that ``path[0]`` ends up as
        the outermost key.
    """
    # Build from the inside out: start with the leaf mapping, then wrap it in
    # a single-item list under each enclosing tag, innermost tag first.
    wrapped = AddableDict({elem.tag: elem.text})
    for enclosing_tag in reversed(path):
        wrapped = AddableDict({enclosing_tag: [wrapped]})
    return wrapped