Skip to content

Clip

ClipEmbedding #

Bases: MultiModalEmbedding

CLIP嵌入模型,用于对文本和图像进行多模态编码。

该类提供了一个接口,用于使用部署在OpenAI CLIP中的模型生成嵌入。在初始化时,需要提供一个CLIP模型的名称。

注意: 需要在PYTHONPATH中可用clip包。可以使用pip install git+https://github.com/openai/CLIP.git进行安装。

Source code in llama_index/embeddings/clip/base.py
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
class ClipEmbedding(MultiModalEmbedding):
    """CLIP嵌入模型,用于对文本和图像进行多模态编码。

    该类提供了一个接口,用于使用部署在OpenAI CLIP中的模型生成嵌入。在初始化时,需要提供一个CLIP模型的名称。

    注意:
        需要在PYTHONPATH中可用`clip`包。可以使用`pip install git+https://github.com/openai/CLIP.git`进行安装。"""

    embed_batch_size: int = Field(default=DEFAULT_EMBED_BATCH_SIZE, gt=0)

    _clip: Any = PrivateAttr()
    _model: Any = PrivateAttr()
    _preprocess: Any = PrivateAttr()
    _device: Any = PrivateAttr()

    @classmethod
    def class_name(cls) -> str:
        return "ClipEmbedding"

    def __init__(
        self,
        *,
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        model_name: str = DEFAULT_CLIP_MODEL,
        **kwargs: Any,
    ):
        """初始化ClipEmbedding类。

在初始化过程中导入了`clip`包。

Args:
    embed_batch_size (int, optional): 用于生成嵌入的批处理大小。默认为10,必须大于0且小于等于100。
    model_name (str): Clip模型的模型名称。

引发:
    ImportError: 如果`clip`包在PYTHONPATH中不可用。
    ValueError: 如果无法从Open AI获取模型,或者embed_batch_size不在范围(0, 100]内。
"""
        if embed_batch_size <= 0:
            raise ValueError(f"Embed batch size {embed_batch_size}  must be > 0.")

        try:
            import clip
            import torch
        except ImportError:
            raise ImportError(
                "ClipEmbedding requires `pip install git+https://github.com/openai/CLIP.git` and torch."
            )

        super().__init__(
            embed_batch_size=embed_batch_size, model_name=model_name, **kwargs
        )

        try:
            self._device = "cuda" if torch.cuda.is_available() else "cpu"
            is_local_path = os.path.exists(self.model_name)
            if not is_local_path and self.model_name not in AVAILABLE_CLIP_MODELS:
                raise ValueError(
                    f"Model name {self.model_name} is not available in CLIP."
                )
            self._model, self._preprocess = clip.load(
                self.model_name, device=self._device
            )

        except Exception as e:
            logger.error("Error while loading clip model.")
            raise ValueError("Unable to fetch the requested embeddings model") from e

    # TEXT EMBEDDINGS

    async def _aget_query_embedding(self, query: str) -> Embedding:
        return self._get_query_embedding(query)

    def _get_text_embedding(self, text: str) -> Embedding:
        return self._get_text_embeddings([text])[0]

    def _get_text_embeddings(self, texts: List[str]) -> List[Embedding]:
        results = []
        for text in texts:
            try:
                import clip
            except ImportError:
                raise ImportError(
                    "ClipEmbedding requires `pip install git+https://github.com/openai/CLIP.git` and torch."
                )
            text_embedding = self._model.encode_text(
                clip.tokenize(text).to(self._device)
            )
            results.append(text_embedding.tolist()[0])

        return results

    def _get_query_embedding(self, query: str) -> Embedding:
        return self._get_text_embedding(query)

    # IMAGE EMBEDDINGS

    async def _aget_image_embedding(self, img_file_path: ImageType) -> Embedding:
        return self._get_image_embedding(img_file_path)

    def _get_image_embedding(self, img_file_path: ImageType) -> Embedding:
        import torch

        with torch.no_grad():
            image = (
                self._preprocess(Image.open(img_file_path))
                .unsqueeze(0)
                .to(self._device)
            )
            return self._model.encode_image(image).tolist()[0]