Skip to content

Adapter

LinearAdapterEmbeddingModel module-attribute #

LinearAdapterEmbeddingModel = AdapterEmbeddingModel

AdapterEmbeddingModel #

Bases: BaseEmbedding

适配任何嵌入模型的适配器。

这是对任何嵌入模型的包装器,它在其上方添加了一个适配器层。 这对于在下游任务中微调嵌入模型非常有用。 嵌入模型可以是任何模型 - 它不需要暴露梯度。

Parameters:

Name Type Description Default
base_embed_model BaseEmbedding

基础嵌入模型。

required
adapter_path str

适配器路径。

required
adapter_cls Optional[Type[Any]]

适配器类。默认为None,此时使用线性适配器。

None
transform_query bool

是否转换查询嵌入。默认为True。

True
device Optional[str]

要使用的设备。默认为None。

None
embed_batch_size int

嵌入的批处理大小。默认为10。

DEFAULT_EMBED_BATCH_SIZE
callback_manager Optional[CallbackManager]

回调管理器。默认为None。

None
Source code in llama_index/embeddings/adapter/base.py
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
class AdapterEmbeddingModel(BaseEmbedding):
    """适配任何嵌入模型的适配器。

    这是对任何嵌入模型的包装器,它在其上方添加了一个适配器层。
    这对于在下游任务中微调嵌入模型非常有用。
    嵌入模型可以是任何模型 - 它不需要暴露梯度。

    Args:
        base_embed_model (BaseEmbedding): 基础嵌入模型。
        adapter_path (str): 适配器路径。
        adapter_cls (Optional[Type[Any]]): 适配器类。默认为None,此时使用线性适配器。
        transform_query (bool): 是否转换查询嵌入。默认为True。
        device (Optional[str]): 要使用的设备。默认为None。
        embed_batch_size (int): 嵌入的批处理大小。默认为10。
        callback_manager (Optional[CallbackManager]): 回调管理器。默认为None。
"""

    _base_embed_model: BaseEmbedding = PrivateAttr()
    _adapter: Any = PrivateAttr()
    _transform_query: bool = PrivateAttr()
    _device: Optional[str] = PrivateAttr()
    _target_device: Any = PrivateAttr()

    def __init__(
        self,
        base_embed_model: BaseEmbedding,
        adapter_path: str,
        adapter_cls: Optional[Type[Any]] = None,
        transform_query: bool = True,
        device: Optional[str] = None,
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        callback_manager: Optional[CallbackManager] = None,
    ) -> None:
        """初始化参数。"""
        import torch
        from llama_index.embeddings.adapter.utils import BaseAdapter, LinearLayer

        if device is None:
            device = infer_torch_device()
            logger.info(f"Use pytorch device: {device}")
        self._target_device = torch.device(device)

        self._base_embed_model = base_embed_model

        if adapter_cls is None:
            adapter_cls = LinearLayer
        else:
            adapter_cls = cast(Type[BaseAdapter], adapter_cls)

        adapter = adapter_cls.load(adapter_path)
        self._adapter = cast(BaseAdapter, adapter)
        self._adapter.to(self._target_device)

        self._transform_query = transform_query
        super().__init__(
            embed_batch_size=embed_batch_size,
            callback_manager=callback_manager,
            model_name=f"Adapter for {base_embed_model.model_name}",
        )

    @classmethod
    def class_name(cls) -> str:
        return "AdapterEmbeddingModel"

    def _get_query_embedding(self, query: str) -> List[float]:
        """获取查询嵌入。"""
        import torch

        query_embedding = self._base_embed_model._get_query_embedding(query)
        if self._transform_query:
            query_embedding_t = torch.tensor(query_embedding).to(self._target_device)
            query_embedding_t = self._adapter.forward(query_embedding_t)
            query_embedding = query_embedding_t.tolist()

        return query_embedding

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """获取查询嵌入。"""
        import torch

        query_embedding = await self._base_embed_model._aget_query_embedding(query)
        if self._transform_query:
            query_embedding_t = torch.tensor(query_embedding).to(self._target_device)
            query_embedding_t = self._adapter.forward(query_embedding_t)
            query_embedding = query_embedding_t.tolist()

        return query_embedding

    def _get_text_embedding(self, text: str) -> List[float]:
        return self._base_embed_model._get_text_embedding(text)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        return await self._base_embed_model._aget_text_embedding(text)