1
1
import logging
2
2
import os
3
3
4
- from azure .identity .aio import DefaultAzureCredential
4
+ from azure .identity import DefaultAzureCredential
5
+ from sqlalchemy import event
5
6
from sqlalchemy .ext .asyncio import AsyncEngine , create_async_engine
6
7
7
8
logger = logging .getLogger ("ragapp" )
8
9
9
10
10
11
async def create_postgres_engine (* , host , username , database , password , sslmode , azure_credential ) -> AsyncEngine :
12
+ def get_password_from_azure_credential ():
13
+ token = azure_credential .get_token ("https://ossrdbms-aad.database.windows.net/.default" )
14
+ return token .token
15
+
16
+ token_based_password = False
11
17
if host .endswith (".database.azure.com" ):
18
+ token_based_password = True
12
19
logger .info ("Authenticating to Azure Database for PostgreSQL using Azure Identity..." )
13
20
if azure_credential is None :
14
21
raise ValueError ("Azure credential must be provided for Azure Database for PostgreSQL" )
15
- token = await azure_credential .get_token ("https://ossrdbms-aad.database.windows.net/.default" )
16
- password = token .token
22
+ password = get_password_from_azure_credential ()
17
23
else :
18
24
logger .info ("Authenticating to PostgreSQL using password..." )
19
25
@@ -27,16 +33,20 @@ async def create_postgres_engine(*, host, username, database, password, sslmode,
27
33
echo = False ,
28
34
)
29
35
36
+ @event .listens_for (engine .sync_engine , "do_connect" )
37
+ def update_password_token (dialect , conn_rec , cargs , cparams ):
38
+ if token_based_password :
39
+ logger .info ("Updating password token for Azure Database for PostgreSQL" )
40
+ cparams ["password" ] = get_password_from_azure_credential ()
41
+
30
42
return engine
31
43
32
44
33
45
async def create_postgres_engine_from_env (azure_credential = None ) -> AsyncEngine :
34
- must_close = False
35
46
if azure_credential is None and os .environ ["POSTGRES_HOST" ].endswith (".database.azure.com" ):
36
47
azure_credential = DefaultAzureCredential ()
37
- must_close = True
38
48
39
- engine = await create_postgres_engine (
49
+ return await create_postgres_engine (
40
50
host = os .environ ["POSTGRES_HOST" ],
41
51
username = os .environ ["POSTGRES_USERNAME" ],
42
52
database = os .environ ["POSTGRES_DATABASE" ],
@@ -45,28 +55,16 @@ async def create_postgres_engine_from_env(azure_credential=None) -> AsyncEngine:
45
55
azure_credential = azure_credential ,
46
56
)
47
57
48
- if must_close :
49
- await azure_credential .close ()
50
-
51
- return engine
52
-
53
58
54
59
async def create_postgres_engine_from_args (args , azure_credential = None ) -> AsyncEngine :
55
- must_close = False
56
60
if azure_credential is None and args .host .endswith (".database.azure.com" ):
57
61
azure_credential = DefaultAzureCredential ()
58
- must_close = True
59
62
60
- engine = await create_postgres_engine (
63
+ return await create_postgres_engine (
61
64
host = args .host ,
62
65
username = args .username ,
63
66
database = args .database ,
64
67
password = args .password ,
65
68
sslmode = args .sslmode ,
66
69
azure_credential = azure_credential ,
67
70
)
68
-
69
- if must_close :
70
- await azure_credential .close ()
71
-
72
- return engine
0 commit comments