Bases: BaseEmbedding
适配任何嵌入模型的适配器。
这是对任何嵌入模型的包装器,它在其上方添加了一个适配器层。
这对于在下游任务中微调嵌入模型非常有用。
嵌入模型可以是任何模型 - 它不需要暴露梯度。
Parameters:
Name |
Type |
Description |
Default |
base_embed_model |
BaseEmbedding
|
|
required
|
adapter_path |
str
|
|
required
|
adapter_cls |
Optional[Type[Any]]
|
|
None
|
transform_query |
bool
|
|
True
|
device |
Optional[str]
|
|
None
|
embed_batch_size |
int
|
|
DEFAULT_EMBED_BATCH_SIZE
|
callback_manager |
Optional[CallbackManager]
|
|
None
|
Source code in llama_index/embeddings/adapter/base.py
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107 | class AdapterEmbeddingModel(BaseEmbedding):
"""适配任何嵌入模型的适配器。
这是对任何嵌入模型的包装器,它在其上方添加了一个适配器层。
这对于在下游任务中微调嵌入模型非常有用。
嵌入模型可以是任何模型 - 它不需要暴露梯度。
Args:
base_embed_model (BaseEmbedding): 基础嵌入模型。
adapter_path (str): 适配器路径。
adapter_cls (Optional[Type[Any]]): 适配器类。默认为None,此时使用线性适配器。
transform_query (bool): 是否转换查询嵌入。默认为True。
device (Optional[str]): 要使用的设备。默认为None。
embed_batch_size (int): 嵌入的批处理大小。默认为10。
callback_manager (Optional[CallbackManager]): 回调管理器。默认为None。
"""
_base_embed_model: BaseEmbedding = PrivateAttr()
_adapter: Any = PrivateAttr()
_transform_query: bool = PrivateAttr()
_device: Optional[str] = PrivateAttr()
_target_device: Any = PrivateAttr()
def __init__(
self,
base_embed_model: BaseEmbedding,
adapter_path: str,
adapter_cls: Optional[Type[Any]] = None,
transform_query: bool = True,
device: Optional[str] = None,
embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
callback_manager: Optional[CallbackManager] = None,
) -> None:
"""初始化参数。"""
import torch
from llama_index.embeddings.adapter.utils import BaseAdapter, LinearLayer
if device is None:
device = infer_torch_device()
logger.info(f"Use pytorch device: {device}")
self._target_device = torch.device(device)
self._base_embed_model = base_embed_model
if adapter_cls is None:
adapter_cls = LinearLayer
else:
adapter_cls = cast(Type[BaseAdapter], adapter_cls)
adapter = adapter_cls.load(adapter_path)
self._adapter = cast(BaseAdapter, adapter)
self._adapter.to(self._target_device)
self._transform_query = transform_query
super().__init__(
embed_batch_size=embed_batch_size,
callback_manager=callback_manager,
model_name=f"Adapter for {base_embed_model.model_name}",
)
@classmethod
def class_name(cls) -> str:
return "AdapterEmbeddingModel"
def _get_query_embedding(self, query: str) -> List[float]:
"""获取查询嵌入。"""
import torch
query_embedding = self._base_embed_model._get_query_embedding(query)
if self._transform_query:
query_embedding_t = torch.tensor(query_embedding).to(self._target_device)
query_embedding_t = self._adapter.forward(query_embedding_t)
query_embedding = query_embedding_t.tolist()
return query_embedding
async def _aget_query_embedding(self, query: str) -> List[float]:
"""获取查询嵌入。"""
import torch
query_embedding = await self._base_embed_model._aget_query_embedding(query)
if self._transform_query:
query_embedding_t = torch.tensor(query_embedding).to(self._target_device)
query_embedding_t = self._adapter.forward(query_embedding_t)
query_embedding = query_embedding_t.tolist()
return query_embedding
def _get_text_embedding(self, text: str) -> List[float]:
return self._base_embed_model._get_text_embedding(text)
async def _aget_text_embedding(self, text: str) -> List[float]:
return await self._base_embed_model._aget_text_embedding(text)
|