11import asyncio
22from contextvars import ContextVar
3- from typing import Dict , Optional , Union
3+ from typing import Dict , Optional , Type , Union
44
5- from sqlalchemy .engine import Engine
65from sqlalchemy .engine .url import URL
7- from sqlalchemy .ext .asyncio import AsyncSession , create_async_engine
6+ from sqlalchemy .ext .asyncio import AsyncEngine , AsyncSession , create_async_engine
87from starlette .middleware .base import BaseHTTPMiddleware , RequestResponseEndpoint
98from starlette .requests import Request
109from starlette .types import ASGIApp
1110
12- from fastapi_async_sqlalchemy .exceptions import MissingSessionError , SessionNotInitialisedError
11+ from fastapi_async_sqlalchemy .exceptions import (
12+ MissingSessionError ,
13+ SessionNotInitialisedError ,
14+ )
1315
1416try :
15- from sqlalchemy .ext .asyncio import async_sessionmaker # noqa: F811
17+ from sqlalchemy .ext .asyncio import async_sessionmaker
1618except ImportError :
17- from sqlalchemy .orm import sessionmaker as async_sessionmaker
19+ from sqlalchemy .orm import sessionmaker as async_sessionmaker # type: ignore
1820
1921# Try to import SQLModel's AsyncSession which has the .exec() method
2022try :
2123 from sqlmodel .ext .asyncio .session import AsyncSession as SQLModelAsyncSession
2224
23- DefaultAsyncSession = SQLModelAsyncSession
25+ DefaultAsyncSession : Type [ AsyncSession ] = SQLModelAsyncSession # type: ignore
2426except ImportError :
25- DefaultAsyncSession = AsyncSession
27+ DefaultAsyncSession : Type [ AsyncSession ] = AsyncSession # type: ignore
2628
2729
28- def create_middleware_and_session_proxy ():
30+ def create_middleware_and_session_proxy () -> tuple :
2931 _Session : Optional [async_sessionmaker ] = None
30- _session : ContextVar [Optional [DefaultAsyncSession ]] = ContextVar ("_session" , default = None )
32+ _session : ContextVar [Optional [AsyncSession ]] = ContextVar ("_session" , default = None )
3133 _multi_sessions_ctx : ContextVar [bool ] = ContextVar ("_multi_sessions_context" , default = False )
3234 _commit_on_exit_ctx : ContextVar [bool ] = ContextVar ("_commit_on_exit_ctx" , default = False )
3335 # Usage of context vars inside closures is not recommended, since they are not properly
@@ -39,9 +41,9 @@ def __init__(
3941 self ,
4042 app : ASGIApp ,
4143 db_url : Optional [Union [str , URL ]] = None ,
42- custom_engine : Optional [Engine ] = None ,
43- engine_args : Dict = None ,
44- session_args : Dict = None ,
44+ custom_engine : Optional [AsyncEngine ] = None ,
45+ engine_args : Optional [ Dict ] = None ,
46+ session_args : Optional [ Dict ] = None ,
4547 commit_on_exit : bool = False ,
4648 ):
4749 super ().__init__ (app )
@@ -52,13 +54,18 @@ def __init__(
5254 if not custom_engine and not db_url :
5355 raise ValueError ("You need to pass a db_url or a custom_engine parameter." )
5456 if not custom_engine :
57+ if db_url is None :
58+ raise ValueError ("db_url cannot be None when custom_engine is not provided" )
5559 engine = create_async_engine (db_url , ** engine_args )
5660 else :
5761 engine = custom_engine
5862
5963 nonlocal _Session
6064 _Session = async_sessionmaker (
61- engine , class_ = DefaultAsyncSession , expire_on_commit = False , ** session_args
65+ engine ,
66+ class_ = DefaultAsyncSession ,
67+ expire_on_commit = False ,
68+ ** session_args ,
6269 )
6370
6471 async def dispatch (self , request : Request , call_next : RequestResponseEndpoint ):
@@ -67,7 +74,7 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
6774
6875 class DBSessionMeta (type ):
6976 @property
70- def session (self ) -> DefaultAsyncSession :
77+ def session (self ) -> AsyncSession :
7178 """Return an instance of Session local to the current async context."""
7279 if _Session is None :
7380 raise SessionNotInitialisedError
@@ -123,7 +130,7 @@ async def cleanup():
123130 class DBSession (metaclass = DBSessionMeta ):
124131 def __init__ (
125132 self ,
126- session_args : Dict = None ,
133+ session_args : Optional [ Dict ] = None ,
127134 commit_on_exit : bool = False ,
128135 multi_sessions : bool = False ,
129136 ):
0 commit comments