Skip to content

Commit 6e07717

Browse files
committed
Refactor connection/session context a bit
Add helper functions to trio_cdp.context. These can be used in the main module but are also helpful in downstream projects where sessions might be reused across multiple tasks.
1 parent c22511b commit 6e07717

2 files changed

Lines changed: 36 additions & 15 deletions

File tree

trio_cdp/__init__.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -239,11 +239,8 @@ async def open_session(self, target_id: cdp.target.TargetID) -> \
239239
and it will execute on the current session automatically.
240240
'''
241241
session = await self.connect_session(target_id)
242-
token = session_context.set(session)
243-
try:
242+
with session_context(session):
244243
yield session
245-
finally:
246-
session_context.reset(token)
247244

248245
async def connect_session(self, target_id: cdp.target.TargetID) -> 'CdpSession':
249246
'''
@@ -369,11 +366,10 @@ async def open_cdp(url) -> typing.AsyncIterator[CdpConnection]:
369366
'''
370367
async with trio.open_nursery() as nursery:
371368
conn = await connect_cdp(nursery, url)
372-
token = connection_context.set(conn)
373369
try:
374-
yield conn
370+
with connection_context(conn):
371+
yield conn
375372
finally:
376-
connection_context.reset(token)
377373
await conn.aclose()
378374

379375

trio_cdp/context.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1+
from contextlib import contextmanager
12
import contextvars
23

34

4-
connection_context: contextvars.ContextVar = contextvars.ContextVar('connection_context')
5-
session_context: contextvars.ContextVar = contextvars.ContextVar('session_context')
5+
_connection_context: contextvars.ContextVar = contextvars.ContextVar('connection_context')
6+
_session_context: contextvars.ContextVar = contextvars.ContextVar('session_context')
7+
8+
69

710

811
def get_connection_context(fn_name):
@@ -11,7 +14,7 @@ def get_connection_context(fn_name):
1114
``RuntimeError`` with a helpful message.
1215
'''
1316
try:
14-
return connection_context.get()
17+
return _connection_context.get()
1518
except LookupError:
1619
raise RuntimeError(f'{fn_name}() must be called in a connection context.')
1720

@@ -22,19 +25,41 @@ def get_session_context(fn_name):
2225
``RuntimeError`` with a helpful message.
2326
'''
2427
try:
25-
return session_context.get()
28+
return _session_context.get()
2629
except LookupError:
2730
raise RuntimeError(f'{fn_name}() must be called in a session context.')
2831

2932

33+
@contextmanager
34+
def connection_context(connection):
35+
''' This context manager installs ``connection`` as the session context for the current
36+
Trio task. '''
37+
token = _connection_context.set(connection)
38+
try:
39+
yield
40+
finally:
41+
_connection_context.reset(token)
42+
43+
44+
@contextmanager
45+
def session_context(session):
46+
''' This context manager installs ``session`` as the session context for the current
47+
Trio task. '''
48+
token = _session_context.set(session)
49+
try:
50+
yield
51+
finally:
52+
_session_context.reset(token)
53+
54+
3055
def set_global_connection(connection):
3156
'''
3257
Install ``connection`` in the root context so that it will become the default
3358
connection for all tasks. This is generally not recommended, except it may be
3459
necessary in certain use cases such as running inside Jupyter notebook.
3560
'''
36-
global connection_context
37-
connection_context = contextvars.ContextVar('connection_context',
61+
global _connection_context
62+
_connection_context = contextvars.ContextVar('_connection_context',
3863
default=connection)
3964

4065

@@ -44,5 +69,5 @@ def set_global_session(session):
4469
session for all tasks. This is generally not recommended, except it may be
4570
necessary in certain use cases such as running inside Jupyter notebook.
4671
'''
47-
global session_context
48-
session_context = contextvars.ContextVar('session_context', default=session)
72+
global _session_context
73+
_session_context = contextvars.ContextVar('_session_context', default=session)

0 commit comments

Comments
 (0)