class EntityExtractor(BaseExtractor):
    """Entity extractor.

    Extracts `entities` into node metadata using the default model
    `tomaarsen/span-marker-mbert-base-multinerd` and the SpanMarker library.

    Install SpanMarker with `pip install span-marker`.
    """

    model_name: str = Field(
        default=DEFAULT_ENTITY_MODEL,
        description="The model name of the SpanMarker model to use.",
    )
    prediction_threshold: float = Field(
        default=0.5,
        description="The confidence threshold for accepting predictions.",
        # BUG FIX: pydantic constraint keywords are `ge`/`le`, not `gte`/`lte`.
        # The original spellings are unknown to pydantic, so the 0.0-1.0
        # bounds were never enforced.
        ge=0.0,
        le=1.0,
    )
    span_joiner: str = Field(
        default=" ", description="The separator between entity names."
    )
    label_entities: bool = Field(
        default=False, description="Include entity class labels or not."
    )
    device: Optional[str] = Field(
        default=None, description="Device to run model on, i.e. 'cuda', 'cpu'"
    )
    entity_map: Dict[str, str] = Field(
        default_factory=dict,
        description="Mapping of entity class names to usable names.",
    )

    # Private runtime state: the tokenizer callable and the loaded
    # SpanMarker model (not part of the serialized pydantic schema).
    _tokenizer: Callable = PrivateAttr()
    _model: Any = PrivateAttr()

    def __init__(
        self,
        model_name: str = DEFAULT_ENTITY_MODEL,
        prediction_threshold: float = 0.5,
        span_joiner: str = " ",
        label_entities: bool = False,
        device: Optional[str] = None,
        entity_map: Optional[Dict[str, str]] = None,
        tokenizer: Optional[Callable[[str], List[str]]] = None,
        **kwargs: Any,
    ):
        """Entity extractor for extracting entities from text and inserting
        them into node metadata.

        Args:
            model_name (str):
                Name of the SpanMarker model to use.
            prediction_threshold (float):
                Minimum prediction threshold for accepting an entity.
                Defaults to 0.5.
            span_joiner (str):
                String used to join the tokens of a span. Defaults to " ".
            label_entities (bool):
                Whether to label entities with their type. Setting this to
                True can be slightly error-prone, but may be useful for
                downstream tasks. Defaults to False.
            device (Optional[str]):
                Device to use for the SpanMarker model, i.e. "cpu" or "cuda".
                Loads onto "cpu" by default.
            entity_map (Optional[Dict[str, str]]):
                Mapping of entity class names to usable labels.
            tokenizer (Optional[Callable[[str], List[str]]]):
                Tokenizer used to split text into words.
                Defaults to NLTK's word_tokenize.
        """
        self._model = SpanMarkerModel.from_pretrained(model_name)
        if device is not None:
            self._model = self._model.to(device)

        self._tokenizer = tokenizer or word_tokenize

        # BUG FIX: copy the defaults before merging. The original code bound
        # `base_entity_map` directly to the module-level DEFAULT_ENTITY_MAP
        # and updated it in place, so one instance's custom `entity_map`
        # leaked into every subsequently created instance.
        base_entity_map = dict(DEFAULT_ENTITY_MAP)
        if entity_map is not None:
            base_entity_map.update(entity_map)

        super().__init__(
            model_name=model_name,
            prediction_threshold=prediction_threshold,
            span_joiner=span_joiner,
            label_entities=label_entities,
            device=device,
            entity_map=base_entity_map,
            **kwargs,
        )

    @classmethod
    def class_name(cls) -> str:
        """Return the serialization name for this extractor class."""
        return "EntityExtractor"

    async def aextract(self, nodes: Sequence[BaseNode]) -> List[Dict]:
        """Extract node-level entity metadata.

        For each node, tokenizes its content, runs SpanMarker prediction,
        and collects entity strings whose score exceeds
        `prediction_threshold`. Entities land under the "entities" metadata
        key, or under their (mapped) class label when `label_entities` is
        True.

        Args:
            nodes (Sequence[BaseNode]): Nodes to extract entities from.

        Returns:
            List[Dict]: One metadata dict per input node, in order.
        """
        metadata_list: List[Dict] = [{} for _ in nodes]
        metadata_queue: Iterable[int] = get_tqdm_iterable(
            range(len(nodes)), self.show_progress, "Extracting entities"
        )

        for i in metadata_queue:
            metadata = metadata_list[i]
            node_text = nodes[i].get_content(metadata_mode=self.metadata_mode)
            words = self._tokenizer(node_text)
            spans = self._model.predict(words)
            for span in spans:
                if span["score"] > self.prediction_threshold:
                    ent_label = self.entity_map.get(span["label"], span["label"])
                    metadata_label = ent_label if self.label_entities else "entities"

                    # Accumulate into a set to deduplicate repeated mentions.
                    if metadata_label not in metadata:
                        metadata[metadata_label] = set()

                    metadata[metadata_label].add(self.span_joiner.join(span["span"]))

        # Convert metadata values from set to list so they are serializable.
        for metadata in metadata_list:
            for key, val in metadata.items():
                metadata[key] = list(val)

        return metadata_list