Skip to content

SQL (关系型) 数据库与 Peewee (已弃用)

/// 警告 | "已弃用"

本教程已弃用,将在未来版本中移除。

///

/// 警告

如果你刚开始学习,使用 SQLAlchemy 的教程 SQL (关系型) 数据库 应该足够了。

你可以随意跳过本教程。

不推荐在 FastAPI 中使用 Peewee,因为它与任何异步 Python 代码都不兼容。有更好的替代方案。

///

/// 信息

这些文档假设使用 Pydantic v1。

由于 Peewee 与异步代码不兼容,且有更好的替代方案,我不会为 Pydantic v2 更新这些文档,它们目前仅保留历史用途。

这里的示例不再在 CI 中测试(之前是测试的)。

///

如果你从零开始一个项目,你可能最好使用 SQLAlchemy ORM(SQL (关系型) 数据库),或其他任何异步 ORM。

如果你已经有一个使用 Peewee ORM 的代码库,你可以在这里查看如何将其与 FastAPI 一起使用。

/// 警告 | "需要 Python 3.7+"

你需要 Python 3.7 或更高版本才能安全地将 Peewee 与 FastAPI 一起使用。

///

异步的 Peewee

Peewee 并非为异步框架设计,也没有考虑到异步框架。

Peewee 对其默认设置和使用方式有一些强烈的假设。

如果你正在开发一个使用旧的非异步框架的应用程序,并且可以接受其所有默认设置,它可能是一个很棒的工具

但如果你需要更改一些默认设置,支持多个预定义的数据库,使用异步框架(如 FastAPI)等,你需要添加相当多的复杂额外代码来覆盖这些默认设置。

尽管如此,这是可能的,在这里你将看到确切的代码,以便能够将 Peewee 与 FastAPI 一起使用。

/// 注意 | "技术细节"

你可以在文档中阅读更多关于 Peewee 对 Python 异步的立场 在文档中一个 issue一个 PR

///

相同的应用

我们将创建与 SQLAlchemy 教程中相同的应用程序(SQL (关系型) 数据库)。

大部分代码实际上是相同的。

因此,我们将只关注差异部分。

文件结构

假设你有一个名为 my_super_project 的目录,其中包含一个名为 sql_app 的子目录,结构如下:

.
└── sql_app
    ├── __init__.py
    ├── crud.py
    ├── database.py
    ├── main.py
    └── schemas.py

这与我们在 SQLAlchemy 教程中的结构几乎相同。

现在让我们看看每个文件/模块的作用。

创建 Peewee 部分

让我们参考文件 sql_app/database.py

标准的 Peewee 代码

首先检查所有正常的 Peewee 代码,创建一个 Peewee 数据库:

from contextvars import ContextVar

import peewee

DATABASE_NAME = "test.db"
db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None}
db_state = ContextVar("db_state", default=db_state_default.copy())


class PeeweeConnectionState(peewee._ConnectionState):
    def __init__(self, **kwargs):
        super().__setattr__("_state", db_state)
        super().__init__(**kwargs)

    def __setattr__(self, name, value):
        self._state.get()[name] = value

    def __getattr__(self, name):
        return self._state.get()[name]


db = peewee.SqliteDatabase(DATABASE_NAME, check_same_thread=False)

db._state = PeeweeConnectionState()

/// 提示

请记住,如果你想使用不同的数据库,比如 PostgreSQL,你不能只更改字符串。你需要使用不同的 Peewee 数据库类。

///

注意

参数:

check_same_thread=False

相当于 SQLAlchemy 教程中的:

connect_args={"check_same_thread": False}

...它仅在 SQLite 中需要。

/// 信息 | "技术细节"

SQL (关系型) 数据库 中的技术细节完全相同。

///

使 Peewee 兼容异步的 PeeweeConnectionState

Peewee 和 FastAPI 的主要问题是 Peewee 严重依赖 Python 的 threading.local,并且没有直接的方法来覆盖它或让你直接处理连接/会话(如在 SQLAlchemy 教程中所做的)。

threading.local 与现代 Python 的新异步特性不兼容。

/// 注意 | "技术细节"

threading.local 用于拥有一个“神奇”变量,该变量在每个线程中具有不同的值。

这在设计为每个请求只有一个线程的旧框架中很有用,不多也不少。

使用这个,每个请求都会有自己的数据库连接/会话,这实际上是最终目标。

/// 但是,使用新的异步特性的 FastAPI 可以在同一个线程上处理多个请求。同时,对于单个请求,它可以根据你使用的是 async def 还是普通的 def,在不同的线程(在线程池中)运行多个任务。这就是 FastAPI 性能提升的全部原因。

///

但是,Python 3.7 及以上版本提供了一个比 threading.local 更高级的替代方案,这个替代方案也可以在原本使用 threading.local 的地方使用,但与新的异步特性兼容。

我们将使用这个替代方案。它被称为 contextvars

我们将覆盖 Peewee 内部使用 threading.local 的部分,并用 contextvars 替换它们,并进行相应的更新。

这可能看起来有点复杂(实际上确实如此),你并不需要完全理解它的工作原理才能使用它。

我们将创建一个 PeeweeConnectionState

from contextvars import ContextVar

import peewee

DATABASE_NAME = "test.db"
db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None}
db_state = ContextVar("db_state", default=db_state_default.copy())


class PeeweeConnectionState(peewee._ConnectionState):
    def __init__(self, **kwargs):
        super().__setattr__("_state", db_state)
        super().__init__(**kwargs)

    def __setattr__(self, name, value):
        self._state.get()[name] = value

    def __getattr__(self, name):
        return self._state.get()[name]


db = peewee.SqliteDatabase(DATABASE_NAME, check_same_thread=False)

db._state = PeeweeConnectionState()

这个类继承自 Peewee 使用的一个特殊的内部类。

它包含了所有使 Peewee 使用 contextvars 而不是 threading.local 的逻辑。

contextvars 的工作方式与 threading.local 略有不同。但 Peewee 的其余内部代码假设这个类使用 threading.local

因此,我们需要做一些额外的技巧,使其看起来就像只是使用了 threading.local__init____setattr____getattr__ 实现了所有必要的技巧,以便 Peewee 可以在不知道它现在与 FastAPI 兼容的情况下使用它。

Tip

这将只是使 Peewee 在使用 FastAPI 时正确地工作。不会随机打开或关闭正在使用的连接,不会创建错误等。

但它并没有赋予 Peewee 异步的超能力。你仍然应该使用普通的 def 函数,而不是 async def

使用自定义的 PeeweeConnectionState

现在,使用新的 PeeweeConnectionState 覆盖 Peewee 数据库 db 对象的内部属性 ._state

from contextvars import ContextVar

import peewee

DATABASE_NAME = "test.db"
db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None}
db_state = ContextVar("db_state", default=db_state_default.copy())


class PeeweeConnectionState(peewee._ConnectionState):
    def __init__(self, **kwargs):
        super().__setattr__("_state", db_state)
        super().__init__(**kwargs)

    def __setattr__(self, name, value):
        self._state.get()[name] = value

    def __getattr__(self, name):
        return self._state.get()[name]


db = peewee.SqliteDatabase(DATABASE_NAME, check_same_thread=False)

db._state = PeeweeConnectionState()

Tip

确保你在创建 db 之后覆盖 db._state

Tip

对于其他任何 Peewee 数据库,包括 PostgresqlDatabaseMySQLDatabase 等,你也需要做同样的操作。

创建数据库模型

现在让我们看看 sql_app/models.py 文件。

为我们的数据创建 Peewee 模型

现在为 UserItem 创建 Peewee 模型(类)。

这与你在遵循 Peewee 教程并更新模型以拥有与 SQLAlchemy 教程中相同的数据时所做的相同。

Tip

Peewee 也使用术语“模型”来指代这些与数据库交互的类和实例。

但 Pydantic 也使用术语“模型”来指代不同的东西,即数据验证、转换和文档类和实例。

database(上面的 database.py 文件)导入 db 并在这里使用它。

import peewee

from .database import db


class User(peewee.Model):
    email = peewee.CharField(unique=True, index=True)
    hashed_password = peewee.CharField()
    is_active = peewee.BooleanField(default=True)

    class Meta:
        database = db


class Item(peewee.Model):
    title = peewee.CharField(index=True)
    description = peewee.CharField(index=True)
    owner = peewee.ForeignKeyField(User, backref="items")

    class Meta:
        database = db

Tip

Peewee 创建了几个魔法属性。

它会自动添加一个 id 属性作为整数,作为主键。

它会根据类名选择表名。

对于 Item,它会创建一个 owner_id 属性,其中包含 User 的整数 ID。但我们并没有在任何地方声明它。

创建 Pydantic 模型

现在让我们检查 sql_app/schemas.py 文件。

Tip

为了避免 Peewee 模型 和 Pydantic 模型 之间的混淆,我们将有包含 Peewee 模型的 models.py 文件,以及包含 Pydantic 模型的 schemas.py 文件。

这些 Pydantic 模型或多或少定义了一个“模式”(有效的数据形状)。

因此,这将帮助我们在使用两者时避免混淆。

创建 Pydantic 模型 / 模式

创建与 SQLAlchemy 教程中相同的所有 Pydantic 模型:

from typing import Any, List, Union

import peewee
from pydantic import BaseModel
from pydantic.utils import GetterDict


class PeeweeGetterDict(GetterDict):
    def get(self, key: Any, default: Any = None):
        res = getattr(self._obj, key, default)
        if isinstance(res, peewee.ModelSelect):
            return list(res)
        return res


class ItemBase(BaseModel):
    title: str
    description: Union[str, None] = None


class ItemCreate(ItemBase):
    pass


class Item(ItemBase):
    id: int
    owner_id: int

    class Config:
        orm_mode = True
        getter_dict = PeeweeGetterDict


class UserBase(BaseModel):
    email: str


class UserCreate(UserBase):
    password: str


class User(UserBase):
    id: int
    is_active: bool
    items: List[Item] = []

    class Config:
        orm_mode = True
        getter_dict = PeeweeGetterDict

Tip

在这里,我们使用 id 创建模型。

我们没有在 Peewee 模型中显式指定 id 属性,但 Peewee 会自动添加一个。

我们还在 Item 中添加了魔法 owner_id 属性。

为 Pydantic 模型 / 模式创建 PeeweeGetterDict

当你访问 Peewee 对象中的关系时,比如在 some_user.items 中,Peewee 不会提供一个 Itemlist

它提供了一个特殊定制的 ModelSelect 类对象。

你可以使用 list(some_user.items) 创建其项目的 list。 但这个对象本身并不是一个 list。它也不是一个实际的 Python 生成器。因此,Pydantic 默认情况下不知道如何将其转换为 Pydantic 模型 / 模式列表。

但 Pydantic 的最新版本允许提供一个继承自 pydantic.utils.GetterDict 的自定义类,以在使用 orm_mode = True 时提供检索 ORM 模型属性值的功能。

我们将创建一个自定义的 PeeweeGetterDict 类,并在所有使用 orm_mode 的 Pydantic 模型 / 模式中使用它:

from typing import Any, List, Union

import peewee
from pydantic import BaseModel
from pydantic.utils import GetterDict


class PeeweeGetterDict(GetterDict):
    def get(self, key: Any, default: Any = None):
        res = getattr(self._obj, key, default)
        if isinstance(res, peewee.ModelSelect):
            return list(res)
        return res


class ItemBase(BaseModel):
    title: str
    description: Union[str, None] = None


class ItemCreate(ItemBase):
    pass


class Item(ItemBase):
    id: int
    owner_id: int

    class Config:
        orm_mode = True
        getter_dict = PeeweeGetterDict


class UserBase(BaseModel):
    email: str


class UserCreate(UserBase):
    password: str


class User(UserBase):
    id: int
    is_active: bool
    items: List[Item] = []

    class Config:
        orm_mode = True
        getter_dict = PeeweeGetterDict

在这里,我们检查正在访问的属性(例如 some_user.items 中的 .items)是否是 peewee.ModelSelect 的实例。

如果是这种情况,只需返回一个包含它的 list

然后,我们在使用 orm_mode = True 的 Pydantic 模型 / 模式中使用它,配置变量为 getter_dict = PeeweeGetterDict

Tip

我们只需要创建一个 PeeweeGetterDict 类,就可以在所有 Pydantic 模型 / 模式中使用它。

CRUD 工具

现在让我们看看 sql_app/crud.py 文件。

创建所有 CRUD 工具

创建与 SQLAlchemy 教程中相同的所有 CRUD 工具,所有代码都非常相似:

from . import models, schemas


def get_user(user_id: int):
    return models.User.filter(models.User.id == user_id).first()


def get_user_by_email(email: str):
    return models.User.filter(models.User.email == email).first()


def get_users(skip: int = 0, limit: int = 100):
    return list(models.User.select().offset(skip).limit(limit))


def create_user(user: schemas.UserCreate):
    fake_hashed_password = user.password + "notreallyhashed"
    db_user = models.User(email=user.email, hashed_password=fake_hashed_password)
    db_user.save()
    return db_user


def get_items(skip: int = 0, limit: int = 100):
    return list(models.Item.select().offset(skip).limit(limit))


def create_user_item(item: schemas.ItemCreate, user_id: int):
    db_item = models.Item(**item.dict(), owner_id=user_id)
    db_item.save()
    return db_item

与 SQLAlchemy 教程的代码有一些不同。

我们没有传递 db 属性。相反,我们直接使用模型。这是因为 db 对象是一个包含所有连接逻辑的全局对象。这就是为什么我们必须进行上述所有 contextvars 更新的原因。

此外,当返回多个对象时,例如在 get_users 中,我们直接调用 list,如下所示:

list(models.User.select())

这是出于与我们必须创建自定义 PeeweeGetterDict 相同的原因。但通过返回已经是一个 list 而不是 peewee.ModelSelectresponse_model 在带有 List[models.User]路径操作 中(我们稍后会看到)将正常工作。

FastAPI 应用

现在在 sql_app/main.py 文件中,让我们集成并使用之前创建的所有其他部分。

创建数据库表

以非常简单的方式创建数据库表:

import time
from typing import List

from fastapi import Depends, FastAPI, HTTPException

from . import crud, database, models, schemas
from .database import db_state_default

database.db.connect()
database.db.create_tables([models.User, models.Item])
database.db.close()

app = FastAPI()

sleep_time = 10


async def reset_db_state():
    database.db._state._state.set(db_state_default.copy())
    database.db._state.reset()


def get_db(db_state=Depends(reset_db_state)):
    try:
        database.db.connect()
        yield
    finally:
        if not database.db.is_closed():
            database.db.close()


@app.post("/users/", response_model=schemas.User, dependencies=[Depends(get_db)])
def create_user(user: schemas.UserCreate):
    db_user = crud.get_user_by_email(email=user.email)
    if db_user:
        raise HTTPException(status_code=400, detail="Email already registered")
    return crud.create_user(user=user)


@app.get("/users/", response_model=List[schemas.User], dependencies=[Depends(get_db)])
def read_users(skip: int = 0, limit: int = 100):
    users = crud.get_users(skip=skip, limit=limit)
    return users


@app.get(
    "/users/{user_id}", response_model=schemas.User, dependencies=[Depends(get_db)]
)
def read_user(user_id: int):
    db_user = crud.get_user(user_id=user_id)
    if db_user is None:
        raise HTTPException(status_code=404, detail="User not found")
    return db_user


@app.post(
    "/users/{user_id}/items/",
    response_model=schemas.Item,
    dependencies=[Depends(get_db)],
)
def create_item_for_user(user_id: int, item: schemas.ItemCreate):
    return crud.create_user_item(item=item, user_id=user_id)


@app.get("/items/", response_model=List[schemas.Item], dependencies=[Depends(get_db)])
def read_items(skip: int = 0, limit: int = 100):
    items = crud.get_items(skip=skip, limit=limit)
    return items


@app.get(
    "/slowusers/", response_model=List[schemas.User], dependencies=[Depends(get_db)]
)
def read_slow_users(skip: int = 0, limit: int = 100):
    global sleep_time
    sleep_time = max(0, sleep_time - 1)
    time.sleep(sleep_time)  # Fake long processing request
    users = crud.get_users(skip=skip, limit=limit)
    return users

创建依赖项

创建一个依赖项,该依赖项将在请求开始时连接数据库,并在结束时断开连接:

import time
from typing import List

from fastapi import Depends, FastAPI, HTTPException

from . import crud, database, models, schemas
from .database import db_state_default

database.db.connect()
database.db.create_tables([models.User, models.Item])
database.db.close()

app = FastAPI()

sleep_time = 10


async def reset_db_state():
    database.db._state._state.set(db_state_default.copy())
    database.db._state.reset()


def get_db(db_state=Depends(reset_db_state)):
    try:
        database.db.connect()
        yield
    finally:
        if not database.db.is_closed():
            database.db.close()


@app.post("/users/", response_model=schemas.User, dependencies=[Depends(get_db)])
def create_user(user: schemas.UserCreate):
    db_user = crud.get_user_by_email(email=user.email)
    if db_user:
        raise HTTPException(status_code=400, detail="Email already registered")
    return crud.create_user(user=user)


@app.get("/users/", response_model=List[schemas.User], dependencies=[Depends(get_db)])
def read_users(skip: int = 0, limit: int = 100):
    users = crud.get_users(skip=skip, limit=limit)
    return users


@app.get(
    "/users/{user_id}", response_model=schemas.User, dependencies=[Depends(get_db)]
)
def read_user(user_id: int):
    db_user = crud.get_user(user_id=user_id)
    if db_user is None:
        raise HTTPException(status_code=404, detail="User not found")
    return db_user


@app.post(
    "/users/{user_id}/items/",
    response_model=schemas.Item,
    dependencies=[Depends(get_db)],
)
def create_item_for_user(user_id: int, item: schemas.ItemCreate):
    return crud.create_user_item(item=item, user_id=user_id)


@app.get("/items/", response_model=List[schemas.Item], dependencies=[Depends(get_db)])
def read_items(skip: int = 0, limit: int = 100):
    items = crud.get_items(skip=skip, limit=limit)
    return items


@app.get(
    "/slowusers/", response_model=List[schemas.User], dependencies=[Depends(get_db)]
)
def read_slow_users(skip: int = 0, limit: int = 100):
    global sleep_time
    sleep_time = max(0, sleep_time - 1)
    time.sleep(sleep_time)  # Fake long processing request
    users = crud.get_users(skip=skip, limit=limit)
    return users

这里我们有一个空的 yield,因为我们实际上没有直接使用数据库对象。

它连接到数据库,并将连接数据存储在每个请求独立的内部变量中(使用上述的 contextvars 技巧)。

由于数据库连接可能是 I/O 阻塞的,因此此依赖项是使用普通的 def 函数创建的。

然后,在每个需要访问数据库的 路径操作函数 中,我们将其作为依赖项添加。

但我们没有使用此依赖项提供的值(实际上它没有提供任何值,因为它有一个空的 yield)。因此,我们没有将其添加到 路径操作函数 中,而是添加到 路径操作装饰器dependencies 参数中:

import time
from typing import List

from fastapi import Depends, FastAPI, HTTPException

from . import crud, database, models, schemas
from .database import db_state_default

database.db.connect()
database.db.create_tables([models.User, models.Item])
database.db.close()

app = FastAPI()

sleep_time = 10


async def reset_db_state():
    database.db._state._state.set(db_state_default.copy())
    database.db._state.reset()


def get_db(db_state=Depends(reset_db_state)):
    try:
        database.db.connect()
        yield
    finally:
        if not database.db.is_closed():
            database.db.close()


@app.post("/users/", response_model=schemas.User, dependencies=[Depends(get_db)])
def create_user(user: schemas.UserCreate):
    db_user = crud.get_user_by_email(email=user.email)
    if db_user:
        raise HTTPException(status_code=400, detail="Email already registered")
    return crud.create_user(user=user)


@app.get("/users/", response_model=List[schemas.User], dependencies=[Depends(get_db)])
def read_users(skip: int = 0, limit: int = 100):
    users = crud.get_users(skip=skip, limit=limit)
    return users


@app.get(
    "/users/{user_id}", response_model=schemas.User, dependencies=[Depends(get_db)]
)
def read_user(user_id: int):
    db_user = crud.get_user(user_id=user_id)
    if db_user is None:
        raise HTTPException(status_code=404, detail="User not found")
    return db_user


@app.post(
    "/users/{user_id}/items/",
    response_model=schemas.Item,
    dependencies=[Depends(get_db)],
)
def create_item_for_user(user_id: int, item: schemas.ItemCreate):
    return crud.create_user_item(item=item, user_id=user_id)


@app.get("/items/", response_model=List[schemas.Item], dependencies=[Depends(get_db)])
def read_items(skip: int = 0, limit: int = 100):
    items = crud.get_items(skip=skip, limit=limit)
    return items


@app.get(
    "/slowusers/", response_model=List[schemas.User], dependencies=[Depends(get_db)]
)
def read_slow_users(skip: int = 0, limit: int = 100):
    global sleep_time
    sleep_time = max(0, sleep_time - 1)
    time.sleep(sleep_time)  # Fake long processing request
    users = crud.get_users(skip=skip, limit=limit)
    return users

上下文变量子依赖项

为了使所有 contextvars 部分正常工作,我们需要确保每个使用数据库的请求在 ContextVar 中都有一个独立值,并且该值将用作整个请求的数据库状态(连接、事务等)。

为此,我们需要创建另一个 async 依赖项 reset_db_state(),它作为 get_db() 中的子依赖项使用。它将为上下文变量设置值(仅使用默认的 dict),该值将用作整个请求的数据库状态。然后,依赖项 get_db() 将存储数据库状态(连接、事务等)。

import time
from typing import List

from fastapi import Depends, FastAPI, HTTPException

from . import crud, database, models, schemas
from .database import db_state_default

database.db.connect()
database.db.create_tables([models.User, models.Item])
database.db.close()

app = FastAPI()

sleep_time = 10


async def reset_db_state():
    database.db._state._state.set(db_state_default.copy())
    database.db._state.reset()


def get_db(db_state=Depends(reset_db_state)):
    try:
        database.db.connect()
        yield
    finally:
        if not database.db.is_closed():
            database.db.close()


@app.post("/users/", response_model=schemas.User, dependencies=[Depends(get_db)])
def create_user(user: schemas.UserCreate):
    db_user = crud.get_user_by_email(email=user.email)
    if db_user:
        raise HTTPException(status_code=400, detail="Email already registered")
    return crud.create_user(user=user)


@app.get("/users/", response_model=List[schemas.User], dependencies=[Depends(get_db)])
def read_users(skip: int = 0, limit: int = 100):
    users = crud.get_users(skip=skip, limit=limit)
    return users


@app.get(
    "/users/{user_id}", response_model=schemas.User, dependencies=[Depends(get_db)]
)
def read_user(user_id: int):
    db_user = crud.get_user(user_id=user_id)
    if db_user is None:
        raise HTTPException(status_code=404, detail="User not found")
    return db_user


@app.post(
    "/users/{user_id}/items/",
    response_model=schemas.Item,
    dependencies=[Depends(get_db)],
)
def create_item_for_user(user_id: int, item: schemas.ItemCreate):
    return crud.create_user_item(item=item, user_id=user_id)


@app.get("/items/", response_model=List[schemas.Item], dependencies=[Depends(get_db)])
def read_items(skip: int = 0, limit: int = 100):
    items = crud.get_items(skip=skip, limit=limit)
    return items


@app.get(
    "/slowusers/", response_model=List[schemas.User], dependencies=[Depends(get_db)]
)
def read_slow_users(skip: int = 0, limit: int = 100):
    global sleep_time
    sleep_time = max(0, sleep_time - 1)
    time.sleep(sleep_time)  # Fake long processing request
    users = crud.get_users(skip=skip, limit=limit)
    return users

对于**下一个请求**,由于我们将在 async 依赖项 reset_db_state() 中再次重置该上下文变量,然后在 get_db() 依赖项中创建新连接,因此该新请求将拥有自己的数据库状态(连接、事务等)。

Tip

由于 FastAPI 是一个异步框架,一个请求可能开始处理,在完成之前,另一个请求可能被接收并开始处理,所有这些都可能在同一线程中处理。 但是上下文变量能够感知这些异步特性,因此,在 async 依赖 reset_db_state() 中设置的 Peewee 数据库状态将在整个请求过程中保持其数据。

同时,其他并发请求将拥有自己的数据库状态,该状态在整个请求过程中是独立的。

Peewee 代理

如果你使用的是 Peewee Proxy,实际的数据库位于 db.obj

因此,你可以通过以下方式重置它:

async def reset_db_state():
    database.db.obj._state._state.set(db_state_default.copy())
    database.db.obj._state.reset()

创建你的 FastAPI 路径操作

现在,最终,这里是标准的 FastAPI 路径操作 代码。

import time
from typing import List

from fastapi import Depends, FastAPI, HTTPException

from . import crud, database, models, schemas
from .database import db_state_default

database.db.connect()
database.db.create_tables([models.User, models.Item])
database.db.close()

app = FastAPI()

sleep_time = 10


async def reset_db_state():
    database.db._state._state.set(db_state_default.copy())
    database.db._state.reset()


def get_db(db_state=Depends(reset_db_state)):
    try:
        database.db.connect()
        yield
    finally:
        if not database.db.is_closed():
            database.db.close()


@app.post("/users/", response_model=schemas.User, dependencies=[Depends(get_db)])
def create_user(user: schemas.UserCreate):
    db_user = crud.get_user_by_email(email=user.email)
    if db_user:
        raise HTTPException(status_code=400, detail="Email already registered")
    return crud.create_user(user=user)


@app.get("/users/", response_model=List[schemas.User], dependencies=[Depends(get_db)])
def read_users(skip: int = 0, limit: int = 100):
    users = crud.get_users(skip=skip, limit=limit)
    return users


@app.get(
    "/users/{user_id}", response_model=schemas.User, dependencies=[Depends(get_db)]
)
def read_user(user_id: int):
    db_user = crud.get_user(user_id=user_id)
    if db_user is None:
        raise HTTPException(status_code=404, detail="User not found")
    return db_user


@app.post(
    "/users/{user_id}/items/",
    response_model=schemas.Item,
    dependencies=[Depends(get_db)],
)
def create_item_for_user(user_id: int, item: schemas.ItemCreate):
    return crud.create_user_item(item=item, user_id=user_id)


@app.get("/items/", response_model=List[schemas.Item], dependencies=[Depends(get_db)])
def read_items(skip: int = 0, limit: int = 100):
    items = crud.get_items(skip=skip, limit=limit)
    return items


@app.get(
    "/slowusers/", response_model=List[schemas.User], dependencies=[Depends(get_db)]
)
def read_slow_users(skip: int = 0, limit: int = 100):
    global sleep_time
    sleep_time = max(0, sleep_time - 1)
    time.sleep(sleep_time)  # Fake long processing request
    users = crud.get_users(skip=skip, limit=limit)
    return users

关于 defasync def

与 SQLAlchemy 一样,我们不会做类似的事情:

user = await models.User.select().first()

...而是使用:

user = models.User.select().first()

因此,再次强调,我们应该使用普通的 def 声明 路径操作函数 和依赖,而不是 async def,例如:

# 这里有一些代码
def read_users(skip: int = 0, limit: int = 100):
    # 这里有一些代码

使用异步测试 Peewee

这个示例包含一个额外的 路径操作,它模拟了一个长时间处理的请求,使用 time.sleep(sleep_time)

它将在请求开始时打开数据库连接,并在回复之前等待几秒钟。每个新的请求将等待的时间减少一秒。

这将很容易让你测试你的应用在处理线程相关内容时是否正确运行。

如果你想检查如果不做修改直接使用 Peewee 会如何破坏你的应用,请转到 sql_app/database.py 文件并注释掉以下行:

# db._state = PeeweeConnectionState()

然后在 sql_app/main.py 文件中,注释掉 async 依赖 reset_db_state() 的主体,并用 pass 替换它:

async def reset_db_state():
#     database.db._state._state.set(db_state_default.copy())
#     database.db._state.reset()
    pass

然后使用 Uvicorn 运行你的应用:

$ uvicorn sql_app.main:app --reload

<span style="color: green;">INFO</span>:     Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit)

在浏览器中打开 http://127.0.0.1:8000/docs 并创建几个用户。

然后在 http://127.0.0.1:8000/docs#/default/read_slow_users_slowusers__get 同时打开 10 个标签页。

在所有标签页中转到 路径操作 "Get /slowusers/"。使用 "Try it out" 按钮并依次执行每个标签页中的请求。

标签页将等待一段时间,然后其中一些将显示 Internal Server Error

发生了什么

第一个标签页将使你的应用创建一个数据库连接,并在回复之前等待几秒钟,然后关闭数据库连接。

然后,对于下一个标签页中的请求,你的应用将等待的时间减少一秒,依此类推。

这意味着最终某些最后标签页的请求会比之前的请求更早完成。

然后,等待时间较短的最后一个请求将尝试打开数据库连接,但由于其他标签页的请求可能与第一个请求在同一个线程中处理,它将使用已经打开的相同数据库连接,Peewee 将抛出错误,你将在终端中看到它,并且响应将包含 Internal Server Error

这可能会发生在多个标签页中。

如果有多个客户端同时与你的应用通信,这可能会发生。

随着你的应用同时处理的客户端越来越多,单个请求中的等待时间需要越来越短才能触发错误。

使用 FastAPI 修复 Peewee

现在回到 sql_app/database.py 文件,并取消注释以下行:

db._state = PeeweeConnectionState()

然后在 sql_app/main.py 文件中,取消注释 async 依赖 reset_db_state() 的主体:

async def reset_db_state():
    database.db._state._state.set(db_state_default.copy())
    database.db._state.reset()

终止正在运行的应用并重新启动它。

重复相同的 10 个标签页的过程。这次所有标签页都将等待,并且你将获得所有结果而不会出错。

...你修复了它!

查看所有文件

记住,你应该有一个名为 my_super_project(或任何你想要的名称)的目录,其中包含一个名为 sql_app 的子目录。

sql_app 应该包含以下文件:

  • sql_app/__init__.py:这是一个空文件。

  • sql_app/database.py

from contextvars import ContextVar

import peewee

DATABASE_NAME = "test.db"
db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None}
db_state = ContextVar("db_state", default=db_state_default.copy())


class PeeweeConnectionState(peewee._ConnectionState):
    def __init__(self, **kwargs):
        super().__setattr__("_state", db_state)
        super().__init__(**kwargs)

    def __setattr__(self, name, value):
        self._state.get()[name] = value

    def __getattr__(self, name):
        return self._state.get()[name]


db = peewee.SqliteDatabase(DATABASE_NAME, check_same_thread=False)

db._state = PeeweeConnectionState()
  • sql_app/models.py
import peewee

from .database import db


class User(peewee.Model):
    email = peewee.CharField(unique=True, index=True)
    hashed_password = peewee.CharField()
    is_active = peewee.BooleanField(default=True)

    class Meta:
        database = db


class Item(peewee.Model):
    title = peewee.CharField(index=True)
    description = peewee.CharField(index=True)
    owner = peewee.ForeignKeyField(User, backref="items")

    class Meta:
        database = db
  • sql_app/schemas.py
from typing import Any, List, Union

import peewee
from pydantic import BaseModel
from pydantic.utils import GetterDict


class PeeweeGetterDict(GetterDict):
    def get(self, key: Any, default: Any = None):
        res = getattr(self._obj, key, default)
        if isinstance(res, peewee.ModelSelect):
            return list(res)
        return res


class ItemBase(BaseModel):
    title: str
    description: Union[str, None] = None


class ItemCreate(ItemBase):
    pass


class Item(ItemBase):
    id: int
    owner_id: int

    class Config:
        orm_mode = True
        getter_dict = PeeweeGetterDict


class UserBase(BaseModel):
    email: str


class UserCreate(UserBase):
    password: str


class User(UserBase):
    id: int
    is_active: bool
    items: List[Item] = []

    class Config:
        orm_mode = True
        getter_dict = PeeweeGetterDict
  • sql_app/crud.py
from . import models, schemas


def get_user(user_id: int):
    return models.User.filter(models.User.id == user_id).first()


def get_user_by_email(email: str):
    return models.User.filter(models.User.email == email).first()


def get_users(skip: int = 0, limit: int = 100):
    return list(models.User.select().offset(skip).limit(limit))


def create_user(user: schemas.UserCreate):
    fake_hashed_password = user.password + "notreallyhashed"
    db_user = models.User(email=user.email, hashed_password=fake_hashed_password)
    db_user.save()
    return db_user


def get_items(skip: int = 0, limit: int = 100):
    return list(models.Item.select().offset(skip).limit(limit))


def create_user_item(item: schemas.ItemCreate, user_id: int):
    db_item = models.Item(**item.dict(), owner_id=user_id)
    db_item.save()
    return db_item
  • sql_app/main.py
import time
from typing import List

from fastapi import Depends, FastAPI, HTTPException

from . import crud, database, models, schemas
from .database import db_state_default

database.db.connect()
database.db.create_tables([models.User, models.Item])
database.db.close()

app = FastAPI()

sleep_time = 10


async def reset_db_state():
    database.db._state._state.set(db_state_default.copy())
    database.db._state.reset()


def get_db(db_state=Depends(reset_db_state)):
    try:
        database.db.connect()
        yield
    finally:
        if not database.db.is_closed():
            database.db.close()


@app.post("/users/", response_model=schemas.User, dependencies=[Depends(get_db)])
def create_user(user: schemas.UserCreate):
    db_user = crud.get_user_by_email(email=user.email)
    if db_user:
        raise HTTPException(status_code=400, detail="Email already registered")
    return crud.create_user(user=user)


@app.get("/users/", response_model=List[schemas.User], dependencies=[Depends(get_db)])
def read_users(skip: int = 0, limit: int = 100):
    users = crud.get_users(skip=skip, limit=limit)
    return users


@app.get(
    "/users/{user_id}", response_model=schemas.User, dependencies=[Depends(get_db)]
)
def read_user(user_id: int):
    db_user = crud.get_user(user_id=user_id)
    if db_user is None:
        raise HTTPException(status_code=404, detail="User not found")
    return db_user


@app.post(
    "/users/{user_id}/items/",
    response_model=schemas.Item,
    dependencies=[Depends(get_db)],
)
def create_item_for_user(user_id: int, item: schemas.ItemCreate):
    return crud.create_user_item(item=item, user_id=user_id)


@app.get("/items/", response_model=List[schemas.Item], dependencies=[Depends(get_db)])
def read_items(skip: int = 0, limit: int = 100):
    items = crud.get_items(skip=skip, limit=limit)
    return items


@app.get(
    "/slowusers/", response_model=List[schemas.User], dependencies=[Depends(get_db)]
)
def read_slow_users(skip: int = 0, limit: int = 100):
    global sleep_time
    sleep_time = max(0, sleep_time - 1)
    time.sleep(sleep_time)  # Fake long processing request
    users = crud.get_users(skip=skip, limit=limit)
    return users

技术细节

/// 警告

这些是非常技术性的细节,你可能不需要了解。

///

问题

Peewee 默认使用 threading.local 来存储其数据库“状态”数据(连接、事务等)。

threading.local 为当前线程创建一个独占的值,但异步框架会在同一个线程中运行所有代码(例如每个请求的代码),并且可能不会按顺序执行。

除此之外,异步框架可能会在线程池中运行一些同步代码(使用 asyncio.run_in_executor),但这些代码属于同一个请求。

这意味着,在 Peewee 的当前实现中,多个任务可能会使用相同的 threading.local 变量,并最终共享相同的连接和数据(它们不应该共享),同时,如果它们在线程池中执行同步的 I/O 阻塞代码(就像在 FastAPI 中使用普通 def 函数的路径操作和依赖项一样),该代码将无法访问数据库状态变量,即使它属于同一个请求,并且应该能够访问相同的数据库状态。

上下文变量

Python 3.7 引入了 contextvars,它可以创建一个类似于 threading.local 的局部变量,但同时也支持这些异步特性。

有几件事需要注意。

ContextVar 必须在模块的顶部创建,例如:

some_var = ContextVar("some_var", default="默认值")

要在当前“上下文”中设置一个值(例如当前请求),请使用:

some_var.set("新值")

要在上下文中的任何地方获取值(例如在处理当前请求的任何部分),请使用:

some_var.get()

async 依赖项 reset_db_state() 中设置上下文变量

如果异步代码的某些部分使用 some_var.set("在函数中更新") 设置了值(例如像 async 依赖项那样),那么该代码的其余部分以及之后的代码(包括使用 await 调用的 async 函数内部的代码)将看到这个新值。

因此,在我们的例子中,如果我们在 async 依赖项中设置了 Peewee 状态变量(使用默认的 dict),我们应用程序中的所有其余内部代码都将看到这个值,并且能够为整个请求重用它。

并且上下文变量将为下一个请求再次设置,即使它们是并发的。

在依赖项 get_db() 中设置数据库状态

由于 get_db() 是一个普通的 def 函数,FastAPI 将使其在线程池中运行,并带有“上下文”的副本,该副本包含上下文变量的相同值(重置数据库状态的 dict)。然后它可以向该 dict 添加数据库状态,例如连接等。

但如果上下文变量的值(默认的 dict)是在这个普通的 def 函数中设置的,它将创建一个新值,该值将仅保留在该线程池的线程中,其余代码(例如路径操作函数)将无法访问它。在 get_db() 中,我们只能设置 dict 中的值,但不能设置整个 dict 本身。

因此,我们需要 async 依赖项 reset_db_state() 来在上下文变量中设置 dict。这样,所有代码都可以访问同一个 dict,用于单个请求的数据库状态。

在依赖项 get_db() 中连接和断开数据库

接下来要问的问题是,为什么不直接在 async 依赖项本身中连接和断开数据库,而不是在 get_db() 中?

async 依赖项必须为 async,以便上下文变量在请求的其余部分中保持不变,但创建和关闭数据库连接可能是阻塞的,因此如果将其放在那里,可能会降低性能。

因此,我们还需要普通的 def 依赖项 get_db()