Skip to content

Commit 8f6ecb1

Browse files
committed
Simplified API
* Fix mypy errors regarding incomplete type names for CDP types imported from other modules. * Update tests to match previous bullet. * Regenerate modules.
1 parent a714563 commit 8f6ecb1

39 files changed

Lines changed: 1373 additions & 1071 deletions

generator/generate.py

Lines changed: 36 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
ForwardRef = typing._ForwardRef # type: ignore
1515

1616

17+
cdp_modules = None
18+
19+
1720
def indent(text: str, count: int):
1821
''' Indent text with the specified number of spaces. '''
1922
return tw_indent(text, ' ' * count)
@@ -23,18 +26,17 @@ def main():
2326
''' Main entry point. '''
2427
root = pathlib.Path(__file__).resolve().parent.parent / 'trio_cdp' / 'generated'
2528
clean(root)
26-
modules = list()
27-
for name, module in inspect.getmembers(cdp):
28-
if name.startswith('_') or name in ('cdp', 'util'):
29-
continue
30-
modules.append(name)
29+
ignored = lambda name: name.startswith('_') or name in ('cdp', 'util')
30+
global cdp_modules
31+
cdp_modules = {n:m for n,m in inspect.getmembers(cdp) if not ignored(n)}
32+
for name, module in cdp_modules.items():
3133
generate_module(root, name, module)
3234
init = root / '__init__.py'
3335
with init.open('w') as file:
3436
file.write('# DO NOT EDIT THIS FILE!\n#\n')
3537
file.write('# This code is generated off of PyCDP modules. If you need to make\n')
3638
file.write('# changes, edit the generator and regenerate all of the modules.\n\n')
37-
for module in modules:
39+
for module in cdp_modules:
3840
file.write(f'from . import {module}\n')
3941

4042

@@ -45,17 +47,18 @@ def clean(root: pathlib.Path):
4547
path.unlink()
4648

4749

48-
def generate_module(root: pathlib.Path, module_name: str, module: types.ModuleType):
50+
def generate_module(root: pathlib.Path, module_name: str,
51+
module: types.ModuleType):
4952
''' Generate code for a module. '''
5053
print('* Generating module:', module_name)
5154
module_path = root / f'{module_name}.py'
5255
commands = list()
5356
classes = list()
5457
for name, obj in inspect.getmembers(module):
55-
if name.startswith('_') or name in ('dataclass', 'event_class'):
58+
if name.startswith('_') or name in ('dataclass', 'deprecated', 'event_class'):
5659
continue
5760
if inspect.isfunction(obj):
58-
commands.append(generate_command(module_name, name, obj))
61+
commands.append(generate_command(module, module_name, obj))
5962
elif inspect.isclass(obj):
6063
classes.append(name)
6164

@@ -76,16 +79,19 @@ def generate_module(root: pathlib.Path, module_name: str, module: types.ModuleTy
7679
file.write('\n\n'.join(commands))
7780

7881

79-
def generate_command(module: str, name: str, fn: types.FunctionType):
82+
def generate_command(module: types.ModuleType, module_name: str,
83+
fn: types.FunctionType):
8084
''' Generate code for one command, i.e. one PyCDP wrapper function. '''
81-
print(f' - {name}()')
85+
fn_name = fn.__name__
86+
print(f' - {fn_name}()')
8287
sig = inspect.signature(fn)
88+
type_hints = typing.get_type_hints(fn, globalns=vars(module), localns=None)
8389

8490
# Generate the argument list.
8591
args = list()
8692
call_args = list()
8793
for param in sig.parameters.values():
88-
ann = format_annotation(param.annotation)
94+
ann = format_annotation(module, type_hints[param.name])
8995
if param.default != inspect.Parameter.empty:
9096
default_str = f' = {param.default}'
9197
else:
@@ -107,18 +113,18 @@ def generate_command(module: str, name: str, fn: types.FunctionType):
107113
doc = ''
108114

109115
# The original function returns a generator. We want to grab the return type of the
110-
# generator and set that as the return type of this wrapper function.
111-
return_type = format_annotation(sig.return_annotation.__args__[2])
116+
# generator and set that as the return type of this wrapper function.
117+
return_type = format_annotation(module, type_hints['return'].__args__[2])
112118

113119
# Format the function and return it as a string.
114-
ctx_name, ctx_fn = which_context(module, name)
115-
body = f"{ctx_name} = {ctx_fn}('{module}.{name}')\n"
116-
body += f"return await {ctx_name}.execute(cdp.{module}.{name}({call_arg_str}))"
120+
ctx_name, ctx_fn = which_context(module_name, fn_name)
121+
body = f"{ctx_name} = {ctx_fn}('{module_name}.{fn_name}')\n"
122+
body += f"return await {ctx_name}.execute(cdp.{module_name}.{fn_name}({call_arg_str}))"
117123
body = indent(body, 4)
118-
return f'async def {name}({arg_str}) -> {return_type}:\n{doc}{body}\n'
124+
return f'async def {fn_name}({arg_str}) -> {return_type}:\n{doc}{body}\n'
119125

120126

121-
def format_annotation(ann: typing.Any):
127+
def format_annotation(current_module: types.ModuleType, ann: typing.Any):
122128
'''
123129
Given a type annotation, return a stringified version.
124130
@@ -127,32 +133,31 @@ def format_annotation(ann: typing.Any):
127133
have to access private members to figure out what the specific annotation actually
128134
is.
129135
'''
130-
if isinstance(ann, str):
131-
ann_str = ann
132-
elif isinstance(ann, ForwardRef):
133-
ann_str = ann.__forward_arg__
134-
elif ann in (bool, dict, float, int, str):
135-
ann_str = ann.__name__
136-
elif ann is type(None):
136+
if ann is type(None):
137137
ann_str = 'None'
138+
elif isinstance(ann, type):
139+
if ann.__module__ not in (current_module.__name__, 'builtins'):
140+
ann_str = f'{ann.__module__}.{ann.__name__}'
141+
else:
142+
ann_str = ann.__name__
138143
elif ann._name == 'Any':
139144
ann_str = 'typing.Any'
140145
elif ann._name == 'List':
141-
nested_ann = format_annotation(ann.__args__[0])
146+
nested_ann = format_annotation(current_module, ann.__args__[0])
142147
ann_str = f'typing.List[{nested_ann}]'
143148
elif ann._name == 'Tuple':
144-
nested_anns = ', '.join(format_annotation(a) for a in ann.__args__)
149+
nested_anns = ', '.join(format_annotation(current_module, a) for a in ann.__args__)
145150
ann_str = f'typing.Tuple[{nested_anns}]'
146151
elif ann._name is None and len(ann.__args__) > 1:
147152
# For some reason union annotations don't have a name?
148153
# If the union has two members and one of them is NoneType, then it's really
149154
# a typing.Optional.
150155
if len(ann.__args__) == 2 and any(a is type(None) for a in ann.__args__):
151-
nested_ann = format_annotation(
152-
[a for a in ann.__args__ if a is not type(None)][0])
156+
opt_type = [a for a in ann.__args__ if a is not type(None)][0]
157+
nested_ann = format_annotation(current_module, opt_type)
153158
ann_str = f'typing.Optional[{nested_ann}]'
154159
else:
155-
nested_anns = ', '.join(format_annotation(a) for a in ann.__args__)
160+
nested_anns = ', '.join(format_annotation(current_module, a) for a in ann.__args__)
156161
ann_str = f'typing.Union[{nested_anns}]'
157162
else:
158163
raise Exception(f'Cannot format annotation: {repr(ann)}')

generator/test_generate.py

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,15 @@
55
from .generate import generate_command
66

77

8-
def test_generate_command():
8+
def test_dom_query_selector():
99

1010
expected = dedent("""\
1111
async def query_selector(
1212
node_id: NodeId,
1313
selector: str
1414
) -> NodeId:
1515
'''
16-
Executes `querySelector` on a given node.
16+
Executes ``querySelector`` on a given node.
1717
1818
:param node_id: Id of the node to query upon.
1919
:param selector: Selector string.
@@ -23,4 +23,45 @@ async def query_selector(
2323
return await session.execute(cdp.dom.query_selector(node_id, selector))
2424
""")
2525

26-
assert expected == generate_command('dom', 'query_selector', cdp.dom.query_selector)
26+
assert expected == generate_command(cdp.dom, 'dom', cdp.dom.query_selector)
27+
28+
29+
def test_accessibility_disable():
30+
expected = dedent("""\
31+
async def disable() -> None:
32+
'''
33+
Disables the accessibility domain.
34+
'''
35+
session = get_session_context('accessibility.disable')
36+
return await session.execute(cdp.accessibility.disable())
37+
""")
38+
39+
assert expected == generate_command(cdp.accessibility, 'accessibility',
40+
cdp.accessibility.disable)
41+
42+
43+
def test_accessibility_get_partial_ax_tree():
44+
expected = dedent("""\
45+
async def get_partial_ax_tree(
46+
node_id: typing.Optional[cdp.dom.NodeId] = None,
47+
backend_node_id: typing.Optional[cdp.dom.BackendNodeId] = None,
48+
object_id: typing.Optional[cdp.runtime.RemoteObjectId] = None,
49+
fetch_relatives: typing.Optional[bool] = None
50+
) -> typing.List[AXNode]:
51+
'''
52+
Fetches the accessibility node and partial accessibility tree for this DOM node, if it exists.
53+
54+
**EXPERIMENTAL**
55+
56+
:param node_id: *(Optional)* Identifier of the node to get the partial accessibility tree for.
57+
:param backend_node_id: *(Optional)* Identifier of the backend node to get the partial accessibility tree for.
58+
:param object_id: *(Optional)* JavaScript object id of the node wrapper to get the partial accessibility tree for.
59+
:param fetch_relatives: *(Optional)* Whether to fetch this nodes ancestors, siblings and children. Defaults to true.
60+
:returns: The ``Accessibility.AXNode`` for this DOM node, if it exists, plus its ancestors, siblings and children, if requested.
61+
'''
62+
session = get_session_context('accessibility.get_partial_ax_tree')
63+
return await session.execute(cdp.accessibility.get_partial_ax_tree(node_id, backend_node_id, object_id, fetch_relatives))
64+
""")
65+
66+
assert expected == generate_command(cdp.accessibility, 'accessibility',
67+
cdp.accessibility.get_partial_ax_tree)

trio_cdp/generated/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from . import headless_experimental
2626
from . import heap_profiler
2727
from . import indexed_db
28-
from . import input
28+
from . import input_
2929
from . import inspector
3030
from . import io
3131
from . import layer_tree

trio_cdp/generated/accessibility.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ async def disable() -> None:
3333

3434
async def enable() -> None:
3535
'''
36-
Enables the accessibility domain which causes `AXNodeId`s to remain consistent between method calls.
36+
Enables the accessibility domain which causes ``AXNodeId``'s to remain consistent between method calls.
3737
This turns on accessibility for the page, which can impact performance until accessibility is disabled.
3838
'''
3939
session = get_session_context('accessibility.enable')
@@ -44,27 +44,30 @@ async def get_full_ax_tree() -> typing.List[AXNode]:
4444
'''
4545
Fetches the entire accessibility tree
4646
47+
**EXPERIMENTAL**
48+
4749
:returns:
4850
'''
4951
session = get_session_context('accessibility.get_full_ax_tree')
5052
return await session.execute(cdp.accessibility.get_full_ax_tree())
5153

5254

5355
async def get_partial_ax_tree(
54-
node_id: typing.Optional[dom.NodeId] = None,
55-
backend_node_id: typing.Optional[dom.BackendNodeId] = None,
56-
object_id: typing.Optional[runtime.RemoteObjectId] = None,
56+
node_id: typing.Optional[cdp.dom.NodeId] = None,
57+
backend_node_id: typing.Optional[cdp.dom.BackendNodeId] = None,
58+
object_id: typing.Optional[cdp.runtime.RemoteObjectId] = None,
5759
fetch_relatives: typing.Optional[bool] = None
5860
) -> typing.List[AXNode]:
5961
'''
6062
Fetches the accessibility node and partial accessibility tree for this DOM node, if it exists.
6163
62-
:param node_id: Identifier of the node to get the partial accessibility tree for.
63-
:param backend_node_id: Identifier of the backend node to get the partial accessibility tree for.
64-
:param object_id: JavaScript object id of the node wrapper to get the partial accessibility tree for.
65-
:param fetch_relatives: Whether to fetch this nodes ancestors, siblings and children. Defaults to true.
66-
:returns: The ``Accessibility.AXNode`` for this DOM node, if it exists, plus its ancestors, siblings and
67-
children, if requested.
64+
**EXPERIMENTAL**
65+
66+
:param node_id: *(Optional)* Identifier of the node to get the partial accessibility tree for.
67+
:param backend_node_id: *(Optional)* Identifier of the backend node to get the partial accessibility tree for.
68+
:param object_id: *(Optional)* JavaScript object id of the node wrapper to get the partial accessibility tree for.
69+
:param fetch_relatives: *(Optional)* Whether to fetch this nodes ancestors, siblings and children. Defaults to true.
70+
:returns: The ``Accessibility.AXNode`` for this DOM node, if it exists, plus its ancestors, siblings and children, if requested.
6871
'''
6972
session = get_session_context('accessibility.get_partial_ax_tree')
7073
return await session.execute(cdp.accessibility.get_partial_ax_tree(node_id, backend_node_id, object_id, fetch_relatives))

trio_cdp/generated/animation.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,16 +37,16 @@ async def enable() -> None:
3737

3838

3939
async def get_current_time(
40-
id: str
40+
id_: str
4141
) -> float:
4242
'''
4343
Returns the current time of the an animation.
4444
45-
:param id: Id of animation.
45+
:param id_: Id of animation.
4646
:returns: Current time of the page.
4747
'''
4848
session = get_session_context('animation.get_current_time')
49-
return await session.execute(cdp.animation.get_current_time(id))
49+
return await session.execute(cdp.animation.get_current_time(id_))
5050

5151

5252
async def get_playback_rate() -> float:
@@ -73,7 +73,7 @@ async def release_animations(
7373

7474
async def resolve_animation(
7575
animation_id: str
76-
) -> runtime.RemoteObject:
76+
) -> cdp.runtime.RemoteObject:
7777
'''
7878
Gets the remote object of the Animation.
7979

trio_cdp/generated/application_cache.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ async def enable() -> None:
2727

2828

2929
async def get_application_cache_for_frame(
30-
frame_id: page.FrameId
30+
frame_id: cdp.page.FrameId
3131
) -> ApplicationCache:
3232
'''
3333
Returns relevant application cache data for the document in given frame.
@@ -44,15 +44,14 @@ async def get_frames_with_manifests() -> typing.List[FrameWithManifest]:
4444
Returns array of frame identifiers with manifest urls for each frame containing a document
4545
associated with some application cache.
4646
47-
:returns: Array of frame identifiers with manifest urls for each frame containing a document
48-
associated with some application cache.
47+
:returns: Array of frame identifiers with manifest urls for each frame containing a document associated with some application cache.
4948
'''
5049
session = get_session_context('application_cache.get_frames_with_manifests')
5150
return await session.execute(cdp.application_cache.get_frames_with_manifests())
5251

5352

5453
async def get_manifest_for_frame(
55-
frame_id: page.FrameId
54+
frame_id: cdp.page.FrameId
5655
) -> str:
5756
'''
5857
Returns manifest URL for document in the given frame.

trio_cdp/generated/audits.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import cdp.audits
1212

1313
async def get_encoded_response(
14-
request_id: network.RequestId,
14+
request_id: cdp.network.RequestId,
1515
encoding: str,
1616
quality: typing.Optional[float] = None,
1717
size_only: typing.Optional[bool] = None
@@ -22,12 +22,13 @@ async def get_encoded_response(
2222
2323
:param request_id: Identifier of the network request to get content for.
2424
:param encoding: The encoding to use.
25-
:param quality: The quality of the encoding (0-1). (defaults to 1)
26-
:param size_only: Whether to only return the size information (defaults to false).
27-
:returns: a tuple with the following items:
28-
0. body: (Optional) The encoded body as a base64 string. Omitted if sizeOnly is true.
29-
1. originalSize: Size before re-encoding.
30-
2. encodedSize: Size after re-encoding.
25+
:param quality: *(Optional)* The quality of the encoding (0-1). (defaults to 1)
26+
:param size_only: *(Optional)* Whether to only return the size information (defaults to false).
27+
:returns: A tuple with the following items:
28+
29+
0. **body** – *(Optional)* The encoded body as a base64 string. Omitted if sizeOnly is true.
30+
1. **originalSize** – Size before re-encoding.
31+
2. **encodedSize** – Size after re-encoding.
3132
'''
3233
session = get_session_context('audits.get_encoded_response')
3334
return await session.execute(cdp.audits.get_encoded_response(request_id, encoding, quality, size_only))

0 commit comments

Comments
 (0)