Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 8397f48

Browse files
Version that finally worked
1 parent 6bf0eb2 commit 8397f48

File tree

6 files changed

+152
-39
lines changed

6 files changed

+152
-39
lines changed

‎app/app.py‎

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from sqlalchemy.exc import SQLAlchemyError
1010
from starlette.authentication import AuthenticationError
1111
from starlette_context import plugins
12-
from starlette_context.middleware import ContextMiddleware
12+
from starlette_context.middleware import RawContextMiddleware
1313

1414
from app.core.config import settings
1515
from app.utils.logs_formatters import JSONRequestLogFormatter, JSONLogWebFormatter
@@ -54,7 +54,7 @@ def init_middlewares(app: FastAPI) -> None:
5454
allow_headers=["*"],
5555
)
5656
app.add_middleware(
57-
ContextMiddleware,
57+
RawContextMiddleware,
5858
plugins=(
5959
plugins.RequestIdPlugin(),
6060
plugins.CorrelationIdPlugin()

‎tests/conftest.py‎

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,43 @@
1+
from typing import Generator
2+
3+
import pytest
4+
from sqlalchemy import create_engine, text
5+
from sqlalchemy.exc import SQLAlchemyError
6+
7+
from app.core.config import settings
8+
9+
10+
# @pytest.fixture(scope="session")
11+
def setup_db() -> Generator:
12+
engine = create_engine(f"{settings.DATABASE_URI.replace('+asyncpg', '')}")
13+
conn = engine.connect()
14+
# トランザクションを一度終了させる
15+
conn.execute(text("commit"))
16+
try:
17+
conn.execute(text("drop database test"))
18+
except SQLAlchemyError:
19+
pass
20+
finally:
21+
conn.close()
22+
23+
conn = engine.connect()
24+
# トランザクションを一度終了させる
25+
conn.execute(text("commit"))
26+
conn.execute(text("create database test"))
27+
conn.close()
28+
29+
yield
30+
31+
conn = engine.connect()
32+
# トランザクションを一度終了させる
33+
conn.execute(text("commit"))
34+
try:
35+
conn.execute(text("drop database test"))
36+
except SQLAlchemyError:
37+
pass
38+
conn.close()
39+
40+
141
pytest_plugins = [
242
"tests.fixtures",
343
]

‎tests/fixtures/client.py‎

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,23 @@
11
import pytest
2-
from starlette.testclient import TestClient
2+
import pytest_asyncio
3+
from async_asgi_testclient import TestClient
4+
5+
from app.core.config import settings
6+
# from starlette.testclient import TestClient
37

48
from manage import app
59

610

7-
@pytest.fixture
8-
def client():
9-
with TestClient(app) as client:
11+
# @pytest.fixture
12+
# def client():
13+
# with TestClient(app) as client:
14+
# yield client
15+
16+
@pytest_asyncio.fixture
17+
async def client():
18+
scope = {"client": (settings.HOST, str(settings.PORT))}
19+
20+
async with TestClient(
21+
app, scope=scope
22+
) as client:
1023
yield client

‎tests/fixtures/db.py‎

Lines changed: 79 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,92 @@
1-
import contextlib
1+
fromtypingimport Generator, AsyncGenerator
22

33
import pytest
4-
from sqlalchemy import create_engine
5-
from sqlalchemy.orm import sessionmaker, Session
4+
import pytest_asyncio
5+
from sqlalchemy import create_engine, event
6+
from sqlalchemy.exc import SQLAlchemyError
7+
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
8+
from sqlalchemy.orm import Session, SessionTransaction
69

710
from app.core.config import settings
8-
from app.database import BaseModel
11+
from app.database import BaseModel, get_session
912
from manage import app
1013

11-
engine = create_engine(settings.DATABASE_URI.replace('+asyncpg', ''), pool_pre_ping=True)
12-
session_factory = sessionmaker(
13-
engine, expire_on_commit=False, autocommit=False, autoflush=False
14-
)
1514

16-
BaseModel.metadata.create_all(bind=engine)
15+
# engine = create_engine(settings.DATABASE_URI.replace('+asyncpg', ''), pool_pre_ping=True)
16+
# session_factory = sessionmaker(
17+
# engine, expire_on_commit=False, autocommit=False, autoflush=False
18+
# )
1719

20+
# @app.on_event("startup")
21+
# async def startup():
22+
# BaseModel.metadata.create_all(bind=engine)
23+
#
24+
#
25+
# @app.on_event("shutdown")
26+
# async def shutdown():
27+
# BaseModel.metadata.drop_all(bind=engine)
28+
#
29+
#
30+
# def clear_db():
31+
# with contextlib.closing(engine.connect()) as con:
32+
# with con.begin() as trans:
33+
# for table in reversed(BaseModel.metadata.sorted_tables):
34+
# con.execute(table.delete())
35+
# trans.commit()
36+
#
37+
#
38+
# @pytest.fixture
39+
# def session() -> Session:
40+
# with session_factory() as session:
41+
# print(f"{bcolors.OKCYAN.value}WITH SESSION{bcolors.ENDC.value}")
42+
# yield session
43+
# print(f"{bcolors.OKCYAN.value}CLEAR DB{bcolors.ENDC.value}")
44+
# clear_db()
1845

19-
@app.on_event("shutdown")
20-
async def shutdown():
21-
BaseModel.metadata.drop_all(bind=engine)
2246

47+
@pytest.fixture(scope="session", autouse=True)
48+
def setup_test_db() -> Generator:
49+
engine = create_engine(f"{settings.DATABASE_URI.replace('+asyncpg', '')}")
2350

24-
def clear_db():
25-
with contextlib.closing(engine.connect()) as con:
26-
trans = con.begin()
27-
for table in reversed(BaseModel.metadata.sorted_tables):
28-
con.execute(table.delete())
29-
trans.commit()
51+
with engine.begin():
52+
BaseModel.metadata.drop_all(engine)
53+
BaseModel.metadata.create_all(engine)
54+
yield
55+
BaseModel.metadata.drop_all(engine)
3056

3157

32-
@pytest.fixture
33-
def session() -> Session:
34-
with session_factory() as session:
35-
yield session
36-
# clear_db()
58+
@pytest_asyncio.fixture(autouse=True)
59+
async def session() -> AsyncGenerator:
60+
# https://github.com/sqlalchemy/sqlalchemy/issues/5811#issuecomment-756269881
61+
async_engine = create_async_engine(f"{settings.DATABASE_URI}")
62+
async with async_engine.connect() as conn:
63+
await conn.begin()
64+
await conn.begin_nested()
65+
AsyncSessionLocal = async_sessionmaker(
66+
autocommit=False,
67+
autoflush=False,
68+
bind=conn,
69+
future=True,
70+
)
71+
72+
async_session = AsyncSessionLocal()
73+
74+
@event.listens_for(async_session.sync_session, "after_transaction_end")
75+
def end_savepoint(session: Session, transaction: SessionTransaction) -> None:
76+
if conn.closed:
77+
return
78+
if not conn.in_nested_transaction():
79+
if conn.sync_connection:
80+
conn.sync_connection.begin_nested()
81+
82+
def test_get_session() -> Generator:
83+
try:
84+
yield AsyncSessionLocal()
85+
except SQLAlchemyError:
86+
pass
87+
88+
app.dependency_overrides[get_session] = test_get_session
89+
90+
yield async_session
91+
await async_session.close()
92+
await conn.rollback()

‎tests/requests/example.py‎

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from starlette.testclient import TestClient
1+
from async_asgi_testclient import TestClient
22

33

4-
def example_get_request(client: TestClient, example_id: int):
5-
return client.get(f'/example/{example_id}')
4+
asyncdef example_get_request(client: TestClient, example_id: int):
5+
return awaitclient.get(f'/example/{example_id}')

‎tests/template/test_example.py‎

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,27 @@
11
from http import HTTPStatus
22

3-
from sqlalchemy.orm import Session
4-
from starlette.testclient import TestClient
3+
import pytest
4+
from async_asgi_testclient import TestClient
5+
from sqlalchemy.ext.asyncio import AsyncSession
56

67
from app.example.models import TestTable
78
from tests.requests.example import example_get_request
89

910

10-
def test_index(client: TestClient, session: Session):
11+
@pytest.mark.asyncio
12+
async def test_index(client: TestClient, session: AsyncSession):
13+
session = session
1114
example_id = 1
1215
session.add(TestTable(test_field=example_id))
13-
session.commit()
16+
awaitsession.commit()
1417

15-
response = example_get_request(client, example_id)
18+
response = awaitexample_get_request(client, example_id)
1619
assert response.status_code == HTTPStatus.OK
1720

1821

19-
def test_index_not_found(client: TestClient, session: Session):
22+
@pytest.mark.asyncio
23+
async def test_index_not_found(client: TestClient):
2024
example_id = 69
2125

22-
response = example_get_request(client, example_id)
26+
response = awaitexample_get_request(client, example_id)
2327
assert response.status_code == HTTPStatus.NOT_FOUND

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /