1+ from contextlib import contextmanager
12import contextvars
23
34
4- connection_context : contextvars .ContextVar = contextvars .ContextVar ('connection_context' )
5- session_context : contextvars .ContextVar = contextvars .ContextVar ('session_context' )
5+ _connection_context : contextvars .ContextVar = contextvars .ContextVar ('connection_context' )
6+ _session_context : contextvars .ContextVar = contextvars .ContextVar ('session_context' )
7+
8+
69
710
811def get_connection_context (fn_name ):
@@ -11,7 +14,7 @@ def get_connection_context(fn_name):
1114 ``RuntimeError`` with a helpful message.
1215 '''
1316 try :
14- return connection_context .get ()
17+ return _connection_context .get ()
1518 except LookupError :
1619 raise RuntimeError (f'{ fn_name } () must be called in a connection context.' )
1720
@@ -22,19 +25,41 @@ def get_session_context(fn_name):
2225 ``RuntimeError`` with a helpful message.
2326 '''
2427 try :
25- return session_context .get ()
28+ return _session_context .get ()
2629 except LookupError :
2730 raise RuntimeError (f'{ fn_name } () must be called in a session context.' )
2831
2932
33+ @contextmanager
34+ def connection_context (connection ):
35+ ''' This context manager installs ``connection`` as the session context for the current
36+ Trio task. '''
37+ token = _connection_context .set (connection )
38+ try :
39+ yield
40+ finally :
41+ _connection_context .reset (token )
42+
43+
44+ @contextmanager
45+ def session_context (session ):
46+ ''' This context manager installs ``session`` as the session context for the current
47+ Trio task. '''
48+ token = _session_context .set (session )
49+ try :
50+ yield
51+ finally :
52+ _session_context .reset (token )
53+
54+
3055def set_global_connection (connection ):
3156 '''
3257 Install ``connection`` in the root context so that it will become the default
3358 connection for all tasks. This is generally not recommended, except it may be
3459 necessary in certain use cases such as running inside Jupyter notebook.
3560 '''
36- global connection_context
37- connection_context = contextvars .ContextVar ('connection_context ' ,
61+ global _connection_context
62+ _connection_context = contextvars .ContextVar ('_connection_context ' ,
3863 default = connection )
3964
4065
@@ -44,5 +69,5 @@ def set_global_session(session):
4469 session for all tasks. This is generally not recommended, except it may be
4570 necessary in certain use cases such as running inside Jupyter notebook.
4671 '''
47- global session_context
48- session_context = contextvars .ContextVar ('session_context ' , default = session )
72+ global _session_context
73+ _session_context = contextvars .ContextVar ('_session_context ' , default = session )
0 commit comments