Skip to content

测试数据库

/// 信息

这些文档即将更新。🎉

当前版本假设使用 Pydantic v1 和 SQLAlchemy 版本低于 2.0。

新文档将包括 Pydantic v2,并将使用 SQLModel(它也基于 SQLAlchemy),一旦它更新为使用 Pydantic v2。

///

你可以使用与使用覆盖测试依赖项中相同的依赖项覆盖来更改测试数据库。

你可能希望为测试设置一个不同的数据库,在测试后回滚数据,预填充一些测试数据等。

主要思想与你在前一章中看到的完全相同。

为 SQL 应用添加测试

让我们从SQL(关系型)数据库中的示例更新为使用测试数据库。

所有应用代码都相同,你可以回到那一章查看它是如何实现的。

这里唯一的更改是在新的测试文件中。

你的正常依赖项 get_db() 会返回一个数据库会话。

在测试中,你可以使用依赖项覆盖来返回你的*自定义*数据库会话,而不是通常使用的会话。

在这个示例中,我们将仅为测试创建一个临时数据库。

文件结构

我们在 sql_app/tests/test_sql_app.py 创建一个新文件。

因此,新的文件结构如下:

.
└── sql_app
    ├── __init__.py
    ├── crud.py
    ├── database.py
    ├── main.py
    ├── models.py
    ├── schemas.py
    └── tests
        ├── __init__.py
        └── test_sql_app.py

创建新的数据库会话

首先,我们使用新数据库创建一个新的数据库会话。

我们将使用一个在测试期间持久化的内存数据库,而不是本地文件 sql_app.db

但会话代码的其余部分大致相同,我们只是复制它。

from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool

from ..database import Base
from ..main import app, get_db

SQLALCHEMY_DATABASE_URL = "sqlite://"

engine = create_engine(
    SQLALCHEMY_DATABASE_URL,
    connect_args={"check_same_thread": False},
    poolclass=StaticPool,
)
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)


Base.metadata.create_all(bind=engine)


def override_get_db():
    try:
        db = TestingSessionLocal()
        yield db
    finally:
        db.close()


app.dependency_overrides[get_db] = override_get_db

client = TestClient(app)


def test_create_user():
    response = client.post(
        "/users/",
        json={"email": "deadpool@example.com", "password": "chimichangas4life"},
    )
    assert response.status_code == 200, response.text
    data = response.json()
    assert data["email"] == "deadpool@example.com"
    assert "id" in data
    user_id = data["id"]

    response = client.get(f"/users/{user_id}")
    assert response.status_code == 200, response.text
    data = response.json()
    assert data["email"] == "deadpool@example.com"
    assert data["id"] == user_id

/// 提示

你可以通过将该代码放入一个函数中,并从 database.pytests/test_sql_app.py 中使用它来减少代码重复。

为了简单起见,并专注于特定的测试代码,我们只是复制了它。

///

创建数据库

因为现在我们要在一个新文件中使用一个新数据库,我们需要确保使用以下命令创建数据库:

Base.metadata.create_all(bind=engine)

这通常在 main.py 中调用,但 main.py 中的那一行使用数据库文件 sql_app.db,我们需要确保为测试创建 test.db

因此,我们在这里添加了那一行,使用新文件。

from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool

from ..database import Base
from ..main import app, get_db

SQLALCHEMY_DATABASE_URL = "sqlite://"

engine = create_engine(
    SQLALCHEMY_DATABASE_URL,
    connect_args={"check_same_thread": False},
    poolclass=StaticPool,
)
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)


Base.metadata.create_all(bind=engine)


def override_get_db():
    try:
        db = TestingSessionLocal()
        yield db
    finally:
        db.close()


app.dependency_overrides[get_db] = override_get_db

client = TestClient(app)


def test_create_user():
    response = client.post(
        "/users/",
        json={"email": "deadpool@example.com", "password": "chimichangas4life"},
    )
    assert response.status_code == 200, response.text
    data = response.json()
    assert data["email"] == "deadpool@example.com"
    assert "id" in data
    user_id = data["id"]

    response = client.get(f"/users/{user_id}")
    assert response.status_code == 200, response.text
    data = response.json()
    assert data["email"] == "deadpool@example.com"
    assert data["id"] == user_id

依赖项覆盖

现在我们创建依赖项覆盖,并将其添加到应用的覆盖中。

from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool

from ..database import Base
from ..main import app, get_db

SQLALCHEMY_DATABASE_URL = "sqlite://"

engine = create_engine(
    SQLALCHEMY_DATABASE_URL,
    connect_args={"check_same_thread": False},
    poolclass=StaticPool,
)
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)


Base.metadata.create_all(bind=engine)


def override_get_db():
    try:
        db = TestingSessionLocal()
        yield db
    finally:
        db.close()


app.dependency_overrides[get_db] = override_get_db

client = TestClient(app)


def test_create_user():
    response = client.post(
        "/users/",
        json={"email": "deadpool@example.com", "password": "chimichangas4life"},
    )
    assert response.status_code == 200, response.text
    data = response.json()
    assert data["email"] == "deadpool@example.com"
    assert "id" in data
    user_id = data["id"]

    response = client.get(f"/users/{user_id}")
    assert response.status_code == 200, response.text
    data = response.json()
    assert data["email"] == "deadpool@example.com"
    assert data["id"] == user_id

/// 提示

override_get_db() 的代码几乎与 get_db() 完全相同,但在 override_get_db() 中,我们使用 TestingSessionLocal 作为测试数据库。

///

测试应用

然后我们可以像往常一样测试应用。

from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool

from ..database import Base
from ..main import app, get_db

SQLALCHEMY_DATABASE_URL = "sqlite://"

engine = create_engine(
    SQLALCHEMY_DATABASE_URL,
    connect_args={"check_same_thread": False},
    poolclass=StaticPool,
)
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)


Base.metadata.create_all(bind=engine)


def override_get_db():
    try:
        db = TestingSessionLocal()
        yield db
    finally:
        db.close()


app.dependency_overrides[get_db] = override_get_db

client = TestClient(app)


def test_create_user():
    response = client.post(
        "/users/",
        json={"email": "deadpool@example.com", "password": "chimichangas4life"},
    )
    assert response.status_code == 200, response.text
    data = response.json()
    assert data["email"] == "deadpool@example.com"
    assert "id" in data
    user_id = data["id"]

    response = client.get(f"/users/{user_id}")
    assert response.status_code == 200, response.text
    data = response.json()
    assert data["email"] == "deadpool@example.com"
    assert data["id"] == user_id

我们在测试期间对数据库所做的所有修改都将位于 test.db 数据库中,而不是主 sql_app.db 中。