Skip to content

Ollama

OllamaEmbedding

Bases: BaseEmbedding

Ollama嵌入的类。

Source code in llama_index/embeddings/ollama/base.py
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
class OllamaEmbedding(BaseEmbedding):
    """Embedding model served by an Ollama server.

    Every text is embedded with one HTTP POST to
    ``{base_url}/api/embeddings``; batches are processed one text at a time.
    """

    base_url: str = Field(description="Base url the model is hosted by Ollama")
    model_name: str = Field(description="The Ollama model to use.")
    embed_batch_size: int = Field(
        default=DEFAULT_EMBED_BATCH_SIZE,
        description="The batch size for embedding calls.",
        gt=0,
        # pydantic spells "less than or equal" as ``le``; ``lte`` is not a
        # recognized constraint, so the upper bound was previously ignored.
        le=2048,
    )
    ollama_additional_kwargs: Dict[str, Any] = Field(
        default_factory=dict, description="Additional kwargs for the Ollama API."
    )

    def __init__(
        self,
        model_name: str,
        base_url: str = "http://localhost:11434",
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        ollama_additional_kwargs: Optional[Dict[str, Any]] = None,
        callback_manager: Optional[CallbackManager] = None,
    ) -> None:
        super().__init__(
            model_name=model_name,
            base_url=base_url,
            embed_batch_size=embed_batch_size,
            # Normalize None to a fresh dict so it can be sent as "options".
            ollama_additional_kwargs=ollama_additional_kwargs or {},
            callback_manager=callback_manager,
        )

    @classmethod
    def class_name(cls) -> str:
        return "OllamaEmbedding"

    async def _run_in_executor(self, func: Any, *args: Any) -> Any:
        """Run a blocking callable in the event loop's default executor.

        The HTTP client used here (``requests``) is synchronous; running it
        off-loop keeps the async API from stalling other coroutines.
        """
        import asyncio

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, func, *args)

    def _get_query_embedding(self, query: str) -> List[float]:
        """Return the embedding of a query string."""
        return self.get_general_text_embedding(query)

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Async version of ``_get_query_embedding`` (non-blocking)."""
        return await self._run_in_executor(self.get_general_text_embedding, query)

    def _get_text_embedding(self, text: str) -> List[float]:
        """Return the embedding of a single text."""
        return self.get_general_text_embedding(text)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Async version of ``_get_text_embedding`` (non-blocking)."""
        return await self._run_in_executor(self.get_general_text_embedding, text)

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Return embeddings for ``texts``, one request per text, in order."""
        return [self.get_general_text_embedding(text) for text in texts]

    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Async version of ``_get_text_embeddings`` (non-blocking)."""
        return await self._run_in_executor(self._get_text_embeddings, texts)

    def get_general_text_embedding(self, prompt: str) -> List[float]:
        """Embed ``prompt`` with one call to Ollama's ``/api/embeddings``.

        Args:
            prompt: Text to embed.

        Returns:
            The value of the ``"embedding"`` field of the JSON response.

        Raises:
            ImportError: If the ``requests`` package is not installed.
            ValueError: If the server answers with a non-200 status, or the
                response body is not JSON or lacks an ``"embedding"`` field.
        """
        try:
            import requests
        except ImportError as e:
            raise ImportError(
                "Could not import requests library. "
                "Please install requests with `pip install requests`"
            ) from e

        ollama_request_body = {
            "prompt": prompt,
            "model": self.model_name,
            "options": self.ollama_additional_kwargs,
        }

        response = requests.post(
            url=f"{self.base_url}/api/embeddings",
            headers={"Content-Type": "application/json"},
            json=ollama_request_body,
        )
        response.encoding = "utf-8"
        if response.status_code != 200:
            # Error bodies are not guaranteed to be a JSON object (e.g. an HTML
            # page from a proxy); fall back to the raw text so the HTTP status
            # is still reported instead of a decoding error.
            try:
                error_payload = response.json()
            except ValueError:
                optional_detail = response.text
            else:
                optional_detail = (
                    error_payload.get("error")
                    if isinstance(error_payload, dict)
                    else error_payload
                )
            raise ValueError(
                f"Ollama call failed with status code {response.status_code}."
                f" Details: {optional_detail}"
            )

        # requests' JSON decode errors subclass ValueError in every version,
        # unlike requests.exceptions.JSONDecodeError (added in requests 2.27).
        try:
            payload = response.json()
        except ValueError as e:
            raise ValueError(
                f"Error raised for Ollama Call: {e}.\nResponse: {response.text}"
            ) from e

        if not isinstance(payload, dict) or "embedding" not in payload:
            raise ValueError(
                "Ollama response did not contain an 'embedding' field."
                f"\nResponse: {response.text}"
            )
        return payload["embedding"]

get_general_text_embedding

get_general_text_embedding(prompt: str) -> List[float]

获取Ollama嵌入。

Source code in llama_index/embeddings/ollama/base.py
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
def get_general_text_embedding(self, prompt: str) -> List[float]:
    """Embed ``prompt`` with one call to Ollama's ``/api/embeddings``.

    ``self`` must provide ``model_name``, ``base_url`` and
    ``ollama_additional_kwargs`` (as ``OllamaEmbedding`` does).

    Args:
        prompt: Text to embed.

    Returns:
        The value of the ``"embedding"`` field of the JSON response.

    Raises:
        ImportError: If the ``requests`` package is not installed.
        ValueError: If the server answers with a non-200 status, or the
            response body is not JSON or lacks an ``"embedding"`` field.
    """
    try:
        import requests
    except ImportError as e:
        raise ImportError(
            "Could not import requests library. "
            "Please install requests with `pip install requests`"
        ) from e

    ollama_request_body = {
        "prompt": prompt,
        "model": self.model_name,
        "options": self.ollama_additional_kwargs,
    }

    response = requests.post(
        url=f"{self.base_url}/api/embeddings",
        headers={"Content-Type": "application/json"},
        json=ollama_request_body,
    )
    response.encoding = "utf-8"
    if response.status_code != 200:
        # Error bodies are not guaranteed to be a JSON object; fall back to the
        # raw text so the HTTP status is still reported.
        try:
            error_payload = response.json()
        except ValueError:
            optional_detail = response.text
        else:
            optional_detail = (
                error_payload.get("error")
                if isinstance(error_payload, dict)
                else error_payload
            )
        raise ValueError(
            f"Ollama call failed with status code {response.status_code}."
            f" Details: {optional_detail}"
        )

    # requests' JSON decode errors subclass ValueError in every version.
    try:
        payload = response.json()
    except ValueError as e:
        raise ValueError(
            f"Error raised for Ollama Call: {e}.\nResponse: {response.text}"
        ) from e

    if not isinstance(payload, dict) or "embedding" not in payload:
        raise ValueError(
            "Ollama response did not contain an 'embedding' field."
            f"\nResponse: {response.text}"
        )
    return payload["embedding"]