Skip to content

Entity

EntityExtractor

Bases: BaseExtractor

实体提取器。使用默认模型 `tomaarsen/span-marker-mbert-base-multinerd` 和 SpanMarker 库将实体提取到元数据字段 `entities` 中。

使用 `pip install span-marker` 安装 SpanMarker。

Source code in `llama_index/extractors/entity/base.py` (lines 31–144)
class EntityExtractor(BaseExtractor):
    """Entity extractor.

    Extracts named entities into node metadata using the SpanMarker library,
    by default with the `tomaarsen/span-marker-mbert-base-multinerd` model.

    Accepted entities are stored under the `entities` metadata key, or, when
    `label_entities` is True, under one key per (mapped) entity label.

    Install SpanMarker with `pip install span-marker`.
    """

    model_name: str = Field(
        default=DEFAULT_ENTITY_MODEL,
        description="The model name of the SpanMarker model to use.",
    )
    prediction_threshold: float = Field(
        default=0.5,
        description="The confidence threshold for accepting predictions.",
        # `ge`/`le` are the constraint keywords pydantic understands; the
        # previous `gte`/`lte` were not recognised, so the bounds were never
        # actually enforced.
        ge=0.0,
        le=1.0,
    )
    span_joiner: str = Field(
        default=" ", description="The separator between entity names."
    )
    label_entities: bool = Field(
        default=False, description="Include entity class labels or not."
    )
    device: Optional[str] = Field(
        default=None, description="Device to run model on, i.e. 'cuda', 'cpu'"
    )
    entity_map: Dict[str, str] = Field(
        default_factory=dict,
        description="Mapping of entity class names to usable names.",
    )

    # Word tokenizer applied to node text before it is fed to the model.
    _tokenizer: Callable = PrivateAttr()
    # Loaded SpanMarker model used for prediction.
    _model: Any = PrivateAttr()

    def __init__(
        self,
        model_name: str = DEFAULT_ENTITY_MODEL,
        prediction_threshold: float = 0.5,
        span_joiner: str = " ",
        label_entities: bool = False,
        device: Optional[str] = None,
        entity_map: Optional[Dict[str, str]] = None,
        tokenizer: Optional[Callable[[str], List[str]]] = None,
        **kwargs: Any,
    ):
        """Create an extractor that pulls entities from text into node metadata.

        Args:
            model_name (str):
                Name of the SpanMarker model to load.
            prediction_threshold (float):
                Spans are kept only if their score is strictly greater than
                this value. Must be within [0, 1]. Defaults to 0.5.
            span_joiner (str):
                String used to join the words of a span. Defaults to " ".
            label_entities (bool):
                Whether to store entities under their (mapped) type label
                instead of a single `entities` key. Can be slightly error
                prone, but may be useful for downstream tasks.
                Defaults to False.
            device (Optional[str]):
                Device for the SpanMarker model, e.g. "cpu" or "cuda". If
                None, the model is left wherever
                `SpanMarkerModel.from_pretrained` places it (presumably CPU).
            entity_map (Optional[Dict[str, str]]):
                Mapping from entity class names to labels; merged over
                `DEFAULT_ENTITY_MAP`, with these entries taking precedence.
            tokenizer (Optional[Callable[[str], List[str]]]):
                Tokenizer that splits text into words.
                Defaults to NLTK `word_tokenize`.
        """
        # Copy the module-level default before merging user overrides;
        # updating DEFAULT_ENTITY_MAP in place would leak one instance's
        # custom labels into every extractor created afterwards.
        base_entity_map = dict(DEFAULT_ENTITY_MAP)
        if entity_map is not None:
            base_entity_map.update(entity_map)

        super().__init__(
            model_name=model_name,
            prediction_threshold=prediction_threshold,
            span_joiner=span_joiner,
            label_entities=label_entities,
            device=device,
            entity_map=base_entity_map,
            **kwargs,
        )

        # Private attributes are assigned only after the pydantic model has
        # been initialised; this is safe under both pydantic v1 and v2
        # (v2 rejects private-attribute assignment before `__init__`).
        model = SpanMarkerModel.from_pretrained(model_name)
        if device is not None:
            model = model.to(device)
        self._model = model
        self._tokenizer = tokenizer or word_tokenize

    @classmethod
    def class_name(cls) -> str:
        return "EntityExtractor"

    def _extract_node_entities(self, node: BaseNode) -> Dict[str, List[str]]:
        """Run the model over one node's text and group accepted spans by key."""
        # A dict is used as an insertion-ordered set: it de-duplicates spans
        # while keeping first-occurrence order. A plain `set` of strings
        # iterates in an order that can differ between interpreter runs
        # (hash randomisation), which made the emitted lists unstable.
        grouped: Dict[str, Dict[str, None]] = {}

        node_text = node.get_content(metadata_mode=self.metadata_mode)
        words = self._tokenizer(node_text)
        for span in self._model.predict(words):
            if span["score"] > self.prediction_threshold:
                ent_label = self.entity_map.get(span["label"], span["label"])
                metadata_label = ent_label if self.label_entities else "entities"
                entity_text = self.span_joiner.join(span["span"])
                grouped.setdefault(metadata_label, {})[entity_text] = None

        return {key: list(values) for key, values in grouped.items()}

    async def aextract(self, nodes: Sequence[BaseNode]) -> List[Dict]:
        """Extract entity metadata for each node.

        Returns one metadata dict per node, in the same order as `nodes`.
        Each dict maps a metadata key ("entities", or the mapped entity label
        when `label_entities` is True) to a list of unique entity strings in
        order of first occurrence. Nodes with no accepted spans get `{}`.
        """
        # NOTE(review): model inference runs synchronously inside this
        # coroutine, so it blocks the event loop for its duration.
        node_queue: Iterable[BaseNode] = get_tqdm_iterable(
            nodes, self.show_progress, "Extracting entities"
        )
        return [self._extract_node_entities(node) for node in node_queue]