Skip to content

行为定制

所有Vanna函数都继承自VannaBase类。这是一个抽象基类,为所有Vanna函数提供基本功能。根据您选择的配置的具体情况,实现位于继承自该基类的类中。

您可以选择通过创建一个直接从VannaBase继承或从其继承的类之一继承的类来自定义Vanna的行为。如果您想更改特定函数的行为或添加新功能,这将非常有用。

类实例化

这是一个示例,展示了当配置为使用OpenAI API和ChromaDB向量存储时,类是如何实例化的。

要自定义行为的具体细节,您可以在实例化MyVanna类时覆盖基类中的任何方法。

from vanna.openai.openai_chat import OpenAI_Chat
from vanna.chromadb.chromadb_vector import ChromaDB_VectorStore

class MyVanna(ChromaDB_VectorStore, OpenAI_Chat):
    def __init__(self, config=None):
        ChromaDB_VectorStore.__init__(self, config=config)
        OpenAI_Chat.__init__(self, config=config)

vn = MyVanna(config={'api_key': 'sk-...', 'model': 'gpt-4-...'})

重写特定函数

这里有一个如何重写 is_sql_valid 函数的示例。

from vanna.openai.openai_chat import OpenAI_Chat
from vanna.chromadb.chromadb_vector import ChromaDB_VectorStore

class MyVanna(ChromaDB_VectorStore, OpenAI_Chat):
    def __init__(self, config=None):
        ChromaDB_VectorStore.__init__(self, config=config)
        OpenAI_Chat.__init__(self, config=config)

    def is_sql_valid(self, sql: str) -> bool:
        # Your implementation here

        return False

vn = MyVanna(config={'api_key': 'sk-...', 'model': 'gpt-4-...'})

# Example usage
is_valid = vn.is_sql_valid("SELECT user_name, user_email FROM users WHERE user_id = 123")
print(f"Is the SQL valid? {is_valid}")

添加一个基于LLM的附加功能

如果你想添加一个使用LLM的新功能,你可以通过向类中添加一个新方法来实现。假设你想“解释”一个SQL查询。你可以向类中添加一个新方法,使用LLM生成SQL查询的解释。

from vanna.openai.openai_chat import OpenAI_Chat
from vanna.chromadb.chromadb_vector import ChromaDB_VectorStore

class MyVanna(ChromaDB_VectorStore, OpenAI_Chat):
    def __init__(self, config=None):
        ChromaDB_VectorStore.__init__(self, config=config)
        OpenAI_Chat.__init__(self, config=config)

    def generate_query_explanation(self, sql: str):
        my_prompt = [
            self.system_message("You are a helpful assistant that will explain a SQL query"),
            self.user_message("Explain this SQL query: " + sql),
        ]

        return self.submit_prompt(prompt=my_prompt)

vn = MyVanna(config={'api_key': 'sk-...', 'model': 'gpt-3.5-turbo'})

vn.generate_query_explanation("SELECT user_name, user_email FROM users WHERE user_id = 123")

输出: '这个SQL查询是从users表中选择user_nameuser_email列。它使用WHERE子句指定了一个条件,其中user_id列必须等于123。换句话说,它正在检索user_id123的用户的user_nameuser_email。'