FastAPI + SQLAlchemy example
This example shows how to use Dependency Injector with FastAPI and
SQLAlchemy.
The source code is available on GitHub.
Thanks to @ShvetsovYura for providing the initial example: FastAPI_DI_SqlAlchemy.
Application structure
The application has the following structure:
./
├── webapp/
│   ├── __init__.py
│   ├── application.py
│   ├── containers.py
│   ├── database.py
│   ├── endpoints.py
│   ├── models.py
│   ├── repositories.py
│   ├── services.py
│   └── tests.py
├── config.yml
├── docker-compose.yml
├── Dockerfile
└── requirements.txt
Application factory
The application factory creates the container, wires it with the endpoints module, creates the
FastAPI app, and sets up the routes.
The application factory also creates the database tables if they do not exist.
Listing of webapp/application.py:
"""Application module.""" fromfastapiimport FastAPI from.containersimport Container from.import endpoints defcreate_app() -> FastAPI: container = Container() db = container.db() db.create_database() app = FastAPI() app.container = container app.include_router(endpoints.router) return app app = create_app()
Endpoints
The endpoints module contains example endpoints. The endpoints depend on the user service.
The user service is injected using the Wiring feature. See webapp/endpoints.py:
"""Endpoints module.""" fromtypingimport Annotated fromfastapiimport APIRouter, Depends, Response, status fromdependency_injector.wiringimport Provide, inject from.containersimport Container from.repositoriesimport NotFoundError from.servicesimport UserService router = APIRouter() @router.get("/users") @inject defget_list( user_service: Annotated[UserService, Depends(Provide[Container.user_service])], ): return user_service.get_users() @router.get("/users/{user_id}") @inject defget_by_id( user_id: int, user_service: Annotated[UserService, Depends(Provide[Container.user_service])], ): try: return user_service.get_user_by_id(user_id) except NotFoundError: return Response(status_code=status.HTTP_404_NOT_FOUND) @router.post("/users", status_code=status.HTTP_201_CREATED) @inject defadd( user_service: Annotated[UserService, Depends(Provide[Container.user_service])], ): return user_service.create_user() @router.delete("/users/{user_id}", status_code=status.HTTP_204_NO_CONTENT) @inject defremove( user_id: int, user_service: Annotated[UserService, Depends(Provide[Container.user_service])], ) -> Response: try: user_service.delete_user_by_id(user_id) except NotFoundError: return Response(status_code=status.HTTP_404_NOT_FOUND) else: return Response(status_code=status.HTTP_204_NO_CONTENT) @router.get("/status") defget_status(): return {"status": "OK"}
Container
The declarative container wires the example user service, the user repository, and the utility database class.
See webapp/containers.py:
"""Containers module.""" fromdependency_injectorimport containers, providers from.databaseimport Database from.repositoriesimport UserRepository from.servicesimport UserService classContainer(containers.DeclarativeContainer): wiring_config = containers.WiringConfiguration(modules=[".endpoints"]) config = providers.Configuration(yaml_files=["config.yml"]) db = providers.Singleton(Database, db_url=config.db.url) user_repository = providers.Factory( UserRepository, session_factory=db.provided.session, ) user_service = providers.Factory( UserService, user_repository=user_repository, )
Services
The services module contains the example user service. See webapp/services.py:
"""Services module.""" fromuuidimport uuid4 fromtypingimport Iterator from.repositoriesimport UserRepository from.modelsimport User classUserService: def__init__(self, user_repository: UserRepository) -> None: self._repository: UserRepository = user_repository defget_users(self) -> Iterator[User]: return self._repository.get_all() defget_user_by_id(self, user_id: int) -> User: return self._repository.get_by_id(user_id) defcreate_user(self) -> User: uid = uuid4() return self._repository.add(email=f"{uid}@email.com", password="pwd") defdelete_user_by_id(self, user_id: int) -> None: return self._repository.delete_by_id(user_id)
Repositories
The repositories module contains the example user repository. See webapp/repositories.py:
"""Repositories module.""" fromcontextlibimport AbstractContextManager fromtypingimport Callable, Iterator fromsqlalchemy.ormimport Session from.modelsimport User classUserRepository: def__init__(self, session_factory: Callable[..., AbstractContextManager[Session]]) -> None: self.session_factory = session_factory defget_all(self) -> Iterator[User]: with self.session_factory() as session: return session.query(User).all() defget_by_id(self, user_id: int) -> User: with self.session_factory() as session: user = session.query(User).filter(User.id == user_id).first() if not user: raise UserNotFoundError(user_id) return user defadd(self, email: str, password: str, is_active: bool = True) -> User: with self.session_factory() as session: user = User(email=email, hashed_password=password, is_active=is_active) session.add(user) session.commit() session.refresh(user) return user defdelete_by_id(self, user_id: int) -> None: with self.session_factory() as session: entity: User = session.query(User).filter(User.id == user_id).first() if not entity: raise UserNotFoundError(user_id) session.delete(entity) session.commit() classNotFoundError(Exception): entity_name: str def__init__(self, entity_id): super().__init__(f"{self.entity_name} not found, id: {entity_id}") classUserNotFoundError(NotFoundError): entity_name: str = "User"
Models
The models module contains the example SQLAlchemy user model. See webapp/models.py:
"""Models module.""" fromsqlalchemyimport Column, String, Boolean, Integer from.databaseimport Base classUser(Base): __tablename__ = "users" id = Column(Integer, primary_key=True) email = Column(String, unique=True) hashed_password = Column(String) is_active = Column(Boolean, default=True) def__repr__(self): return f"<User(id={self.id}, " \ f"email=\"{self.email}\", " \ f"hashed_password=\"{self.hashed_password}\", " \ f"is_active={self.is_active})>"
Database
The database module defines the declarative base and a utility class with the engine and session factory.
See webapp/database.py:
"""Database module.""" fromcontextlibimport contextmanager, AbstractContextManager fromtypingimport Callable importlogging fromsqlalchemyimport create_engine, orm fromsqlalchemy.ext.declarativeimport declarative_base fromsqlalchemy.ormimport Session logger = logging.getLogger(__name__) Base = declarative_base() classDatabase: def__init__(self, db_url: str) -> None: self._engine = create_engine(db_url, echo=True) self._session_factory = orm.scoped_session( orm.sessionmaker( autocommit=False, autoflush=False, bind=self._engine, ), ) defcreate_database(self) -> None: Base.metadata.create_all(self._engine) @contextmanager defsession(self) -> Callable[..., AbstractContextManager[Session]]: session: Session = self._session_factory() try: yield session except Exception: logger.exception("Session rollback because of exception") session.rollback() raise finally: session.close()
Tests
The tests use the Provider overriding feature to replace the repository with a mock. See webapp/tests.py:
"""Tests module.""" fromunittestimport mock importpytest fromfastapi.testclientimport TestClient from.repositoriesimport UserRepository, UserNotFoundError from.modelsimport User from.applicationimport app @pytest.fixture defclient(): yield TestClient(app) deftest_get_list(client): repository_mock = mock.Mock(spec=UserRepository) repository_mock.get_all.return_value = [ User(id=1, email="test1@email.com", hashed_password="pwd", is_active=True), User(id=2, email="test2@email.com", hashed_password="pwd", is_active=False), ] with app.container.user_repository.override(repository_mock): response = client.get("/users") assert response.status_code == 200 data = response.json() assert data == [ {"id": 1, "email": "test1@email.com", "hashed_password": "pwd", "is_active": True}, {"id": 2, "email": "test2@email.com", "hashed_password": "pwd", "is_active": False}, ] deftest_get_by_id(client): repository_mock = mock.Mock(spec=UserRepository) repository_mock.get_by_id.return_value = User( id=1, email="xyz@email.com", hashed_password="pwd", is_active=True, ) with app.container.user_repository.override(repository_mock): response = client.get("/users/1") assert response.status_code == 200 data = response.json() assert data == {"id": 1, "email": "xyz@email.com", "hashed_password": "pwd", "is_active": True} repository_mock.get_by_id.assert_called_once_with(1) deftest_get_by_id_404(client): repository_mock = mock.Mock(spec=UserRepository) repository_mock.get_by_id.side_effect = UserNotFoundError(1) with app.container.user_repository.override(repository_mock): response = client.get("/users/1") assert response.status_code == 404 @mock.patch("webapp.services.uuid4", return_value="xyz") deftest_add(_, client): repository_mock = mock.Mock(spec=UserRepository) repository_mock.add.return_value = User( id=1, email="xyz@email.com", hashed_password="pwd", is_active=True, ) with app.container.user_repository.override(repository_mock): response = client.post("/users") assert response.status_code 
== 201 data = response.json() assert data == {"id": 1, "email": "xyz@email.com", "hashed_password": "pwd", "is_active": True} repository_mock.add.assert_called_once_with(email="xyz@email.com", password="pwd") deftest_remove(client): repository_mock = mock.Mock(spec=UserRepository) with app.container.user_repository.override(repository_mock): response = client.delete("/users/1") assert response.status_code == 204 repository_mock.delete_by_id.assert_called_once_with(1) deftest_remove_404(client): repository_mock = mock.Mock(spec=UserRepository) repository_mock.delete_by_id.side_effect = UserNotFoundError(1) with app.container.user_repository.override(repository_mock): response = client.delete("/users/1") assert response.status_code == 404 deftest_status(client): response = client.get("/status") assert response.status_code == 200 data = response.json() assert data == {"status": "OK"}
Sources
The source code is available on GitHub.
Sponsor the project on GitHub:
[frame]