|
1 | | -import contextlib |
2 | 1 | import logging
|
3 | 2 | import os
|
| 3 | +from collections.abc import AsyncIterator |
| 4 | +from contextlib import asynccontextmanager |
| 5 | +from typing import TypedDict |
4 | 6 |
|
5 | | -import azure.identity |
6 | 7 | from dotenv import load_dotenv
|
7 | | -from environs import Env |
8 | 8 | from fastapi import FastAPI
|
9 | | - |
10 | | -from .globals import global_storage |
11 | | -from .openai_clients import create_openai_chat_client, create_openai_embed_client |
12 | | -from .postgres_engine import create_postgres_engine_from_env |
| 9 | +from openai import AsyncAzureOpenAI, AsyncOpenAI |
| 10 | +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker |
| 11 | + |
| 12 | +from fastapi_app.dependencies import ( |
| 13 | + FastAPIAppContext, |
| 14 | + common_parameters, |
| 15 | + create_async_sessionmaker, |
| 16 | + get_azure_credentials, |
| 17 | +) |
| 18 | +from fastapi_app.openai_clients import create_openai_chat_client, create_openai_embed_client |
| 19 | +from fastapi_app.postgres_engine import create_postgres_engine_from_env |
13 | 20 |
|
14 | 21 | logger = logging.getLogger("ragapp")
|
15 | 22 |
|
16 | 23 |
|
17 | | -@contextlib.asynccontextmanager |
18 | | -async def lifespan(app: FastAPI): |
19 | | - load_dotenv(override=True) |
| 24 | +class State(TypedDict): |
| 25 | + sessionmaker: async_sessionmaker[AsyncSession] |
| 26 | + context: FastAPIAppContext |
| 27 | + chat_client: AsyncOpenAI | AsyncAzureOpenAI |
| 28 | + embed_client: AsyncOpenAI | AsyncAzureOpenAI |
20 | 29 |
|
21 | | - azure_credential = None |
22 | | - try: |
23 | | - if client_id := os.getenv("APP_IDENTITY_ID"): |
24 | | - # Authenticate using a user-assigned managed identity on Azure |
25 | | - # See web.bicep for value of APP_IDENTITY_ID |
26 | | - logger.info( |
27 | | - "Using managed identity for client ID %s", |
28 | | - client_id, |
29 | | - ) |
30 | | - azure_credential = azure.identity.ManagedIdentityCredential(client_id=client_id) |
31 | | - else: |
32 | | - azure_credential = azure.identity.DefaultAzureCredential() |
33 | | - except Exception as e: |
34 | | - logger.warning("Failed to authenticate to Azure: %s", e) |
35 | 30 |
|
| 31 | +@asynccontextmanager |
| 32 | +async def lifespan(app: FastAPI) -> AsyncIterator[State]: |
| 33 | + context = await common_parameters() |
| 34 | + azure_credential = await get_azure_credentials() |
36 | 35 | engine = await create_postgres_engine_from_env(azure_credential)
|
37 | | - global_storage.engine = engine |
38 | | - |
39 | | - openai_chat_client, openai_chat_model = await create_openai_chat_client(azure_credential) |
40 | | - global_storage.openai_chat_client = openai_chat_client |
41 | | - global_storage.openai_chat_model = openai_chat_model |
42 | | - |
43 | | - openai_embed_client, openai_embed_model, openai_embed_dimensions = await create_openai_embed_client( |
44 | | - azure_credential |
45 | | - ) |
46 | | - global_storage.openai_embed_client = openai_embed_client |
47 | | - global_storage.openai_embed_model = openai_embed_model |
48 | | - global_storage.openai_embed_dimensions = openai_embed_dimensions |
49 | | - |
50 | | - yield |
| 36 | + sessionmaker = await create_async_sessionmaker(engine) |
| 37 | + chat_client = await create_openai_chat_client(azure_credential) |
| 38 | + embed_client = await create_openai_embed_client(azure_credential) |
51 | 39 |
|
| 40 | + yield {"sessionmaker": sessionmaker, "context": context, "chat_client": chat_client, "embed_client": embed_client} |
52 | 41 | await engine.dispose()
|
53 | 42 |
|
54 | 43 |
|
55 | | -def create_app(): |
56 | | - env = Env() |
57 | | - |
58 | | - if not os.getenv("RUNNING_IN_PRODUCTION"): |
59 | | - env.read_env(".env") |
60 | | - logging.basicConfig(level=logging.INFO) |
61 | | - else: |
| 44 | +def create_app(testing: bool = False): |
| 45 | + if os.getenv("RUNNING_IN_PRODUCTION"): |
62 | 46 | logging.basicConfig(level=logging.WARNING)
|
| 47 | + else: |
| 48 | + if not testing: |
| 49 | + load_dotenv(override=True) |
| 50 | + logging.basicConfig(level=logging.INFO) |
63 | 51 |
|
64 | 52 | app = FastAPI(docs_url="/docs", lifespan=lifespan)
|
65 | 53 |
|
66 | | - from . import api_routes # noqa |
67 | | - from . import frontend_routes # noqa |
| 54 | + from fastapi_app.routes import api_routes, frontend_routes |
68 | 55 |
|
69 | 56 | app.include_router(api_routes.router)
|
70 | 57 | app.mount("/", frontend_routes.router)
|
|
0 commit comments