# Source code for langchain_community.document_loaders.athena
from __future__ import annotations
import io
import json
import time
from typing import Any, Dict, Iterator, List, Optional, Tuple
from langchain_core.documents import Document
from langchain_community.document_loaders.base import BaseLoader
class AthenaLoader(BaseLoader):
    """Load documents from `AWS Athena`.

    Each document represents one row of the query result.

    - By default, all columns are written into the document's `page_content`
      and no columns are written into the document's `metadata`.
    - If `metadata_columns` are provided, those columns are written into the
      document's `metadata` and the remaining columns are written into the
      document's `page_content`.

    To authenticate, the AWS client uses boto3's standard credential loading:
    https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html

    If a specific credential profile should be used, pass the name of that
    profile from the `~/.aws/credentials` file as `profile_name`.

    Make sure the credentials / role used have the policies required to run
    queries in Amazon Athena and to read the query results from the Amazon S3
    output location.
    """

    def __init__(
        self,
        query: str,
        database: str,
        s3_output_uri: str,
        profile_name: Optional[str] = None,
        metadata_columns: Optional[List[str]] = None,
    ):
        """Initialize the Athena document loader.

        Args:
            query: The query to run in Athena.
            database: The Athena database to run the query against.
            s3_output_uri: Athena output location, e.g. ``s3://bucket/prefix/``.
                Query results are read back from this location.
            profile_name: Optional. Name of the AWS credentials profile to use.
                If ``None``, a default ``boto3.Session()`` is created.
            metadata_columns: Optional. Columns written to the document's
                `metadata` instead of its `page_content`.

        Raises:
            ImportError: If the ``boto3`` package is not installed.
            ValueError: If a boto3 session cannot be created (e.g. the
                given profile cannot be loaded).
        """
        self.query = query
        self.database = database
        self.s3_output_uri = s3_output_uri
        self.metadata_columns = metadata_columns if metadata_columns is not None else []
        try:
            import boto3
        except ImportError as e:
            raise ImportError(
                "Could not import boto3 python package. "
                "Please install it with `pip install boto3`."
            ) from e
        try:
            session = (
                boto3.Session(profile_name=profile_name)
                if profile_name is not None
                else boto3.Session()
            )
        except Exception as e:
            raise ValueError(
                "Could not load credentials to authenticate with AWS client. "
                "Please check that credentials in the specified "
                "profile name are valid."
            ) from e
        self.athena_client = session.client("athena")
        self.s3_client = session.client("s3")

    def _execute_query(self) -> List[Dict[str, Any]]:
        """Run ``self.query`` in Athena, wait for completion and return its rows.

        Returns:
            One dict per result row, mapping column name to value. Missing
            values in the downloaded CSV are returned as ``None``.

        Raises:
            Exception: If the query ends in the ``FAILED`` or ``CANCELLED`` state.
        """
        response = self.athena_client.start_query_execution(
            QueryString=self.query,
            QueryExecutionContext={"Database": self.database},
            ResultConfiguration={"OutputLocation": self.s3_output_uri},
        )
        query_execution_id = response["QueryExecutionId"]
        # Poll once per second; any state other than the three terminal ones
        # handled below (e.g. QUEUED, RUNNING) keeps waiting.
        while True:
            response = self.athena_client.get_query_execution(
                QueryExecutionId=query_execution_id
            )
            status = response["QueryExecution"]["Status"]
            state = status["State"]
            if state == "SUCCEEDED":
                break
            if state == "FAILED":
                # StateChangeReason may be absent from the status; don't let a
                # KeyError mask the real query failure.
                state_change_reason = status.get("StateChangeReason", "unknown reason")
                raise Exception(f"Query Failed: {state_change_reason}")
            if state == "CANCELLED":
                raise Exception("Query was cancelled by the user.")
            time.sleep(1)
        result_set = self._get_result_set(query_execution_id)
        # Round-trip through JSON so NaN becomes None and numpy scalar types
        # become plain Python values.
        return json.loads(result_set.to_json(orient="records"))

    def _remove_suffix(self, input_string: str, suffix: str) -> str:
        """Return `input_string` without a trailing `suffix` (if present)."""
        if suffix and input_string.endswith(suffix):
            return input_string[: -len(suffix)]
        return input_string

    def _remove_prefix(self, input_string: str, prefix: str) -> str:
        """Return `input_string` without a leading `prefix` (if present)."""
        if prefix and input_string.startswith(prefix):
            return input_string[len(prefix) :]
        return input_string

    def _get_result_set(self, query_execution_id: str) -> Any:
        """Download and parse the CSV result file of a finished query.

        The object key is built as ``<s3_output_uri>/<query_execution_id>.csv``.
        NOTE(review): this assumes Athena wrote results to the configured
        `s3_output_uri`; a workgroup that overrides the output location would
        place the file elsewhere.

        Returns:
            A pandas DataFrame with the query results.

        Raises:
            ImportError: If the ``pandas`` package is not installed.
        """
        try:
            import pandas as pd
        except ImportError as e:
            raise ImportError(
                "Could not import pandas python package. "
                "Please install it with `pip install pandas`."
            ) from e
        output_uri = self.s3_output_uri
        # "s3://bucket/some/prefix/" -> ["bucket", "some", "prefix"]
        tokens = self._remove_prefix(
            self._remove_suffix(output_uri, "/"), "s3://"
        ).split("/")
        bucket = tokens[0]
        key = "/".join(tokens[1:] + [query_execution_id]) + ".csv"
        obj = self.s3_client.get_object(Bucket=bucket, Key=key)
        df = pd.read_csv(io.BytesIO(obj["Body"].read()), encoding="utf8")
        return df

    def _get_columns(
        self, query_result: List[Dict[str, Any]]
    ) -> Tuple[List[str], List[str]]:
        """Split the result's columns into content columns and metadata columns.

        Column order follows the first row. An empty result yields two empty
        lists instead of raising ``IndexError``.
        """
        if not query_result:
            return [], []
        all_columns = list(query_result[0].keys())
        wanted_metadata = set(self.metadata_columns)
        content_columns = [c for c in all_columns if c not in wanted_metadata]
        metadata_columns = [c for c in all_columns if c in wanted_metadata]
        return content_columns, metadata_columns

    def lazy_load(self) -> Iterator[Document]:
        """Run the query and yield one Document per result row.

        Content columns are rendered as ``"column: value"`` lines joined by
        newlines; metadata columns whose value is ``None`` are omitted.
        """
        query_result = self._execute_query()
        content_columns, metadata_columns = self._get_columns(query_result)
        # Sets give O(1) membership tests in the per-row loops below.
        content_keys = set(content_columns)
        metadata_keys = set(metadata_columns)
        for row in query_result:
            page_content = "\n".join(
                f"{k}: {v}" for k, v in row.items() if k in content_keys
            )
            metadata = {
                k: v for k, v in row.items() if k in metadata_keys and v is not None
            }
            yield Document(page_content=page_content, metadata=metadata)