行为定制
所有Vanna函数都继承自VannaBase类。这是一个抽象基类,为所有Vanna函数提供基本功能。根据您选择的配置的具体情况,实现位于继承自该基类的类中。
您可以选择通过创建一个直接从VannaBase继承或从其继承的类之一继承的类来自定义Vanna的行为。如果您想更改特定函数的行为或添加新功能,这将非常有用。
类实例化
这是一个示例,展示了当配置为使用OpenAI API和ChromaDB向量存储时,类是如何实例化的。
要自定义行为的具体细节,您可以在实例化MyVanna类时覆盖基类中的任何方法。
from vanna.openai.openai_chat import OpenAI_Chat
from vanna.chromadb.chromadb_vector import ChromaDB_VectorStore
class MyVanna(ChromaDB_VectorStore, OpenAI_Chat):
def __init__(self, config=None):
ChromaDB_VectorStore.__init__(self, config=config)
OpenAI_Chat.__init__(self, config=config)
vn = MyVanna(config={'api_key': 'sk-...', 'model': 'gpt-4-...'})
重写特定函数
这里有一个如何重写 is_sql_valid 函数的示例。
from vanna.openai.openai_chat import OpenAI_Chat
from vanna.chromadb.chromadb_vector import ChromaDB_VectorStore
class MyVanna(ChromaDB_VectorStore, OpenAI_Chat):
def __init__(self, config=None):
ChromaDB_VectorStore.__init__(self, config=config)
OpenAI_Chat.__init__(self, config=config)
def is_sql_valid(self, sql: str) -> bool:
# Your implementation here
return False
vn = MyVanna(config={'api_key': 'sk-...', 'model': 'gpt-4-...'})
# Example usage
is_valid = vn.is_sql_valid("SELECT user_name, user_email FROM users WHERE user_id = 123")
print(f"Is the SQL valid? {is_valid}")
添加一个基于LLM的附加功能
如果你想添加一个使用LLM的新功能,你可以通过向类中添加一个新方法来实现。假设你想“解释”一个SQL查询。你可以向类中添加一个新方法,使用LLM生成SQL查询的解释。
from vanna.openai.openai_chat import OpenAI_Chat
from vanna.chromadb.chromadb_vector import ChromaDB_VectorStore
class MyVanna(ChromaDB_VectorStore, OpenAI_Chat):
def __init__(self, config=None):
ChromaDB_VectorStore.__init__(self, config=config)
OpenAI_Chat.__init__(self, config=config)
def generate_query_explanation(self, sql: str):
my_prompt = [
self.system_message("You are a helpful assistant that will explain a SQL query"),
self.user_message("Explain this SQL query: " + sql),
]
return self.submit_prompt(prompt=my_prompt)
vn = MyVanna(config={'api_key': 'sk-...', 'model': 'gpt-3.5-turbo'})
vn.generate_query_explanation("SELECT user_name, user_email FROM users WHERE user_id = 123")
输出: '这个SQL查询是从users
表中选择user_name
和user_email
列。它使用WHERE
子句指定了一个条件,其中user_id
列必须等于123
。换句话说,它正在检索user_id
为123
的用户的user_name
和user_email
。'