Skip to content

Commit 1e96fb7

Browse files
authored
Better narrowing for enums and other types with known equality (#21281)
Fixes #21187 and fully fixes this comment: #9003 (comment), which was previously improved on in the 1.20 narrowing changes.

This diff adds general functionality that replaces a few different pieces of ad hoc logic:

- The conservative enum handling is now generalised.
- We can get rid of the special-cased logic for None when narrowing containment (`in`) checks; it is fully subsumed by the more general logic.
- We no longer rely on promotion behaviour to narrow numeric types correctly.
- Similarly, we now narrow bytes consistently, regardless of `--strict-bytes`. We also no longer have to mark memoryview and bytearray as having custom equality implementations. That is, it is a long-term fix for #20701.

Co-authored-by: Codex
1 parent aaf6002 commit 1e96fb7

4 files changed

Lines changed: 346 additions & 115 deletions

File tree

mypy/checker.py

Lines changed: 196 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def __init__(self) -> None:
108108
from mypy.expandtype import expand_type
109109
from mypy.literals import Key, extract_var_from_literal_hash, literal, literal_hash
110110
from mypy.maptype import map_instance_to_supertype
111-
from mypy.meet import is_overlapping_erased_types, is_overlapping_types, meet_types
111+
from mypy.meet import is_overlapping_types, meet_types
112112
from mypy.message_registry import ErrorMessage
113113
from mypy.messages import (
114114
SUGGESTED_TEST_FIXTURES,
@@ -6761,22 +6761,6 @@ def comparison_type_narrowing_helper(self, node: ComparisonExpr) -> tuple[TypeMa
67616761
narrowable_indices={0},
67626762
)
67636763

6764-
# TODO: This remove_optional code should no longer be needed. The only
6765-
# thing it does is paper over a pre-existing deficiency in equality
6766-
# narrowing w.r.t to enums.
6767-
# We only try and narrow away 'None' for now
6768-
if (
6769-
not is_unreachable_map(if_map)
6770-
and is_overlapping_none(item_type)
6771-
and not is_overlapping_none(collection_item_type)
6772-
and not (
6773-
isinstance(collection_item_type, Instance)
6774-
and collection_item_type.type.fullname == "builtins.object"
6775-
)
6776-
and is_overlapping_erased_types(item_type, collection_item_type)
6777-
):
6778-
if_map[operands[left_index]] = remove_optional(item_type)
6779-
67806764
if right_index in narrowable_operand_index_to_hash:
67816765
if_type, else_type = self.conditional_types_for_iterable(
67826766
item_type, iterable_type
@@ -6861,17 +6845,15 @@ def narrow_type_by_identity_equality(
68616845
# have to be more careful about what narrowing we can conclude from a successful comparison
68626846
custom_eq_indices: set[int]
68636847

6864-
# enum_comparison_is_ambiguous:
6865-
# `if x is Fruits.APPLE` we know `x` is `Fruits.APPLE`, but `if x == Fruits.APPLE: ...`
6866-
# it could e.g. be an int or str if Fruits is an IntEnum or StrEnum.
6867-
# See ambiguous_enum_equality_keys for more details
6868-
enum_comparison_is_ambiguous: bool
6848+
# Equality can use value semantics, so `if x == Fruits.APPLE: ...` may also
6849+
# match non-enum values for IntEnum/StrEnum-like enums. Identity checks don't
6850+
# have this ambiguity.
6851+
is_identity_comparison = operator in {"is", "is not"}
68696852

6870-
if operator in {"is", "is not"}:
6853+
if is_identity_comparison:
68716854
is_target_for_value_narrowing = is_singleton_identity_type
68726855
should_coerce_literals = True
68736856
custom_eq_indices = set()
6874-
enum_comparison_is_ambiguous = False
68756857

68766858
elif operator in {"==", "!="}:
68776859
is_target_for_value_narrowing = is_singleton_equality_type
@@ -6884,7 +6866,6 @@ def narrow_type_by_identity_equality(
68846866
break
68856867

68866868
custom_eq_indices = {i for i in expr_indices if has_custom_eq_checks(operand_types[i])}
6887-
enum_comparison_is_ambiguous = True
68886869
else:
68896870
raise AssertionError
68906871

@@ -6900,8 +6881,6 @@ def narrow_type_by_identity_equality(
69006881
continue
69016882

69026883
expr_type = operand_types[i]
6903-
expr_enum_keys = ambiguous_enum_equality_keys(expr_type)
6904-
expr_type = try_expanding_sum_type_to_union(coerce_to_literal(expr_type), None)
69056884
for j in expr_indices:
69066885
if i == j:
69076886
continue
@@ -6913,18 +6892,30 @@ def narrow_type_by_identity_equality(
69136892
if should_coerce_literals:
69146893
target_type = coerce_to_literal(target_type)
69156894

6916-
if (
6917-
# See comments in ambiguous_enum_equality_keys
6918-
enum_comparison_is_ambiguous
6919-
and len(expr_enum_keys | ambiguous_enum_equality_keys(target_type)) > 1
6920-
):
6921-
continue
6895+
narrowable_expr_type, ambiguous_expr_type = partition_equality_ambiguous_types(
6896+
expr_type, target_type, is_identity=is_identity_comparison
6897+
)
69226898

6923-
target = TypeRange(target_type, is_upper_bound=False)
6899+
if narrowable_expr_type is None:
6900+
if_type = else_type = ambiguous_expr_type
6901+
else:
6902+
narrowable_expr_type = try_expanding_sum_type_to_union(
6903+
coerce_to_literal(narrowable_expr_type), None
6904+
)
6905+
if_type, else_type = conditional_types(
6906+
narrowable_expr_type,
6907+
[TypeRange(target_type, is_upper_bound=False)],
6908+
from_equality=True,
6909+
)
6910+
if ambiguous_expr_type is not None:
6911+
if_type = make_simplified_union(
6912+
[if_type or narrowable_expr_type, ambiguous_expr_type]
6913+
)
6914+
else_type = make_simplified_union(
6915+
[else_type or narrowable_expr_type, ambiguous_expr_type]
6916+
)
69246917

6925-
if_map, else_map = conditional_types_to_typemaps(
6926-
operands[i], *conditional_types(expr_type, [target], from_equality=True)
6927-
)
6918+
if_map, else_map = conditional_types_to_typemaps(operands[i], if_type, else_type)
69286919
if is_target_for_value_narrowing(get_proper_type(target_type)):
69296920
all_if_maps.append(if_map)
69306921
all_else_maps.append(else_map)
@@ -7005,13 +6996,29 @@ def narrow_type_by_identity_equality(
70056996
target_type = operand_types[j]
70066997
if should_coerce_literals:
70076998
target_type = coerce_to_literal(target_type)
7008-
target = TypeRange(target_type, is_upper_bound=False)
6999+
7000+
narrowable_expr_type, ambiguous_expr_type = partition_equality_ambiguous_types(
7001+
expr_type, target_type, is_identity=is_identity_comparison
7002+
)
7003+
7004+
if narrowable_expr_type is None:
7005+
if_type = else_type = ambiguous_expr_type
7006+
else:
7007+
narrowable_expr_type = coerce_to_literal(
7008+
try_expanding_sum_type_to_union(narrowable_expr_type, None)
7009+
)
7010+
if_type, else_type = conditional_types(
7011+
narrowable_expr_type,
7012+
[TypeRange(target_type, is_upper_bound=False)],
7013+
default=narrowable_expr_type,
7014+
from_equality=True,
7015+
)
7016+
if ambiguous_expr_type is not None:
7017+
if_type = make_simplified_union([if_type, ambiguous_expr_type])
7018+
else_type = make_simplified_union([else_type, ambiguous_expr_type])
70097019

70107020
if_map, else_map = conditional_types_to_typemaps(
7011-
operands[i],
7012-
*conditional_types(
7013-
expr_type, [target], default=expr_type, from_equality=True
7014-
),
7021+
operands[i], if_type, else_type
70157022
)
70167023
or_if_maps.append(if_map)
70177024
if is_target_for_value_narrowing(get_proper_type(target_type)):
@@ -8618,17 +8625,10 @@ def conditional_types(
86188625
# We erase generic args because values with different generic types can compare equal
86198626
# For instance, cast(list[str], []) and cast(list[int], [])
86208627
proposed_type = shallow_erase_type_for_equality(proposed_type)
8621-
if not is_overlapping_types(current_type, proposed_type, ignore_promotions=False):
8622-
# Equality narrowing is one of the places at runtime where subtyping with promotion
8623-
# does happen to match runtime semantics
8624-
# Expression is never of any type in proposed_type_ranges
8625-
return UninhabitedType(), default
8626-
if not is_overlapping_types(current_type, proposed_type, ignore_promotions=True):
8627-
return default, default
8628-
else:
8629-
if not is_overlapping_types(current_type, proposed_type, ignore_promotions=True):
8630-
# Expression is never of any type in proposed_type_ranges
8631-
return UninhabitedType(), default
8628+
8629+
if not is_overlapping_types(current_type, proposed_type, ignore_promotions=True):
8630+
# Expression is never of any type in proposed_type_ranges
8631+
return UninhabitedType(), default
86328632

86338633
# we can only restrict when the type is precise, not bounded
86348634
proposed_precise_type = UnionType.make_union(
@@ -8898,8 +8898,6 @@ def reduce_and_conditional_type_maps(ms: list[TypeMap], *, use_meet: bool) -> Ty
88988898

88998899

89008900
BUILTINS_CUSTOM_EQ_CHECKS: Final = {
8901-
"builtins.bytearray",
8902-
"builtins.memoryview",
89038901
"builtins.frozenset",
89048902
"_collections_abc.dict_keys",
89058903
"_collections_abc.dict_items",
@@ -8911,9 +8909,8 @@ def has_custom_eq_checks(t: Type) -> bool:
89118909
custom_special_method(t, "__eq__", check_all=False)
89128910
or custom_special_method(t, "__ne__", check_all=False)
89138911
# custom_special_method has special casing for builtins.* and typing.* that make the
8914-
# above always return False. So here we return True if the a value of a builtin type
8915-
# will ever compare equal to value of another type, e.g. a bytes value can compare equal
8916-
# to a bytearray value.
8912+
# above always return False. Some builtin collections still have equality behavior that
8913+
# crosses nominal type boundaries and isn't captured by VALUE_EQUALITY_DOMAINS.
89178914
or (
89188915
isinstance(pt := get_proper_type(t), Instance)
89198916
and pt.type.fullname in BUILTINS_CUSTOM_EQ_CHECKS
@@ -9691,45 +9688,152 @@ def visit_starred_pattern(self, p: StarredPattern) -> None:
96919688
self.lvalue = False
96929689

96939690

9694-
def ambiguous_enum_equality_keys(t: Type) -> set[str]:
9695-
"""
9696-
Used when narrowing types based on equality.
9691+
# Open domains also block cross-type narrowing for known domain members, but they
9692+
# don't provide an exhaustive union to narrow top types to.
9693+
OPEN_VALUE_EQUALITY_DOMAINS: Final = {
9694+
"builtins.str": "builtins.str",
9695+
"builtins.bool": "builtins.numeric",
9696+
"builtins.int": "builtins.numeric",
9697+
"builtins.float": "builtins.numeric",
9698+
"builtins.complex": "builtins.numeric",
9699+
}
9700+
OPEN_VALUE_EQUALITY_DOMAIN_NAMES: Final = frozenset(OPEN_VALUE_EQUALITY_DOMAINS.values())
9701+
9702+
# Closed domains also block ordinary cross-type narrowing within the domain.
9703+
CLOSED_VALUE_EQUALITY_DOMAINS: Final = {
9704+
"builtins.bytes": "builtins.bytes",
9705+
"builtins.bytearray": "builtins.bytes",
9706+
"builtins.memoryview": "builtins.bytes",
9707+
}
9708+
9709+
VALUE_EQUALITY_DOMAINS: Final = {**OPEN_VALUE_EQUALITY_DOMAINS, **CLOSED_VALUE_EQUALITY_DOMAINS}
96979710

9698-
Certain kinds of enums can compare equal to values of other types, so doing type math
9699-
the way `conditional_types` does will be misleading if you expect it to correspond to
9700-
conditions based on equality comparisons.
97019711

9702-
For example, StrEnum classes can compare equal to str values. So if we see
9703-
`val: StrEnum; if val == "foo": ...` we currently avoid narrowing.
9704-
Note that we do wish to continue narrowing for `if val == StrEnum.MEMBER: ...`
9712+
class EqualityDomainInfo:
9713+
def __init__(self, type_names: set[str], enum_type_names: set[str]) -> None:
9714+
self.type_names = type_names
9715+
self.enum_type_names = enum_type_names
9716+
9717+
9718+
class EqualityValueInfo:
9719+
def __init__(self, domains: dict[str, EqualityDomainInfo], is_top: bool) -> None:
9720+
self.domains = domains
9721+
self.is_top = is_top
9722+
9723+
9724+
def partition_equality_ambiguous_types(
9725+
current_type: Type, target_type: Type, *, is_identity: bool
9726+
) -> tuple[Type | None, Type | None]:
9727+
"""Split current_type into ordinary-narrowable and equality-ambiguous pieces.
9728+
9729+
Some values compare equal through a value domain broader than their nominal type. For
9730+
example, an IntEnum member can compare equal to an int, and a StrEnum member can compare
9731+
equal to a str. When narrowing `x: MyStrEnum | str` against `MyStrEnum.MEMBER`, we can
9732+
still narrow the enum portion of the union, but we must keep the str portion in both
9733+
branches.
97059734
"""
9706-
# We need these things for this to be ambiguous:
9707-
# (1) an IntEnum or StrEnum type or enum subclass of int or str
9708-
# (2) either a different IntEnum/StrEnum type or a non-enum type ("<other>")
9709-
result = set()
9735+
if is_identity:
9736+
return current_type, None
9737+
9738+
typ = get_proper_type(current_type)
9739+
items = typ.relevant_items() if isinstance(typ, UnionType) else [current_type]
9740+
narrowable_items = []
9741+
ambiguous_items = []
9742+
for item in items:
9743+
if is_equality_ambiguous_for_narrowing(item, target_type):
9744+
ambiguous_items.append(item)
9745+
else:
9746+
narrowable_items.append(item)
9747+
return (
9748+
UnionType.make_union(narrowable_items) if narrowable_items else None,
9749+
UnionType.make_union(ambiguous_items) if ambiguous_items else None,
9750+
)
9751+
9752+
9753+
def is_equality_ambiguous_for_narrowing(left: Type, right: Type) -> bool:
9754+
"""Can left compare equal to right through a value domain outside nominal overlap?"""
9755+
left_info = equality_value_info(left)
9756+
right_info = equality_value_info(right)
9757+
9758+
if left_info.is_top or right_info.is_top:
9759+
# Only open-domain enum values can make a top-like type ambiguous.
9760+
# Closed domains can be narrowed to their complete known set instead.
9761+
other_info = right_info if left_info.is_top else left_info
9762+
return any(
9763+
domain in OPEN_VALUE_EQUALITY_DOMAIN_NAMES and domain_info.enum_type_names
9764+
for domain, domain_info in other_info.domains.items()
9765+
)
9766+
9767+
shared_domains = left_info.domains.keys() & right_info.domains.keys()
9768+
if not shared_domains:
9769+
return False
9770+
9771+
for domain in shared_domains:
9772+
left_domain = left_info.domains[domain]
9773+
right_domain = right_info.domains[domain]
9774+
# Equality between two values from the same enum can still narrow by literal member.
9775+
if (
9776+
left_domain.enum_type_names
9777+
and left_domain.enum_type_names == right_domain.enum_type_names
9778+
and left_domain.type_names == left_domain.enum_type_names
9779+
and right_domain.type_names == right_domain.enum_type_names
9780+
):
9781+
continue
9782+
# Different domain-member types may compare equal, but nominal narrowing would
9783+
# otherwise treat them as disjoint.
9784+
if left_domain.type_names != right_domain.type_names:
9785+
return True
9786+
# Same domain-member types are only ambiguous if an enum value may compare equal to
9787+
# its underlying value type.
9788+
if left_domain.enum_type_names or right_domain.enum_type_names:
9789+
return True
9790+
9791+
return False
9792+
9793+
9794+
def equality_value_info(t: Type) -> EqualityValueInfo:
97109795
t = get_proper_type(t)
97119796
if isinstance(t, UnionType):
9712-
for item in t.items:
9713-
result.update(ambiguous_enum_equality_keys(item))
9714-
elif isinstance(t, Instance):
9715-
if t.last_known_value:
9716-
result.update(ambiguous_enum_equality_keys(t.last_known_value))
9717-
elif t.type.is_enum and any(
9718-
base.fullname in ("enum.IntEnum", "enum.StrEnum", "builtins.str", "builtins.int")
9719-
for base in t.type.mro
9720-
):
9721-
result.add(t.type.fullname)
9722-
elif not t.type.is_enum:
9723-
# These might compare equal to IntEnum/StrEnum types (e.g. Decimal), so
9724-
# let's be conservative
9725-
result.add("<other>")
9726-
elif isinstance(t, LiteralType):
9727-
result.update(ambiguous_enum_equality_keys(t.fallback))
9728-
elif isinstance(t, NoneType):
9729-
pass
9730-
else:
9731-
result.add("<other>")
9732-
return result
9797+
return combine_equality_value_info(equality_value_info(item) for item in t.items)
9798+
if isinstance(t, TypeVarType):
9799+
if t.values:
9800+
return combine_equality_value_info(equality_value_info(item) for item in t.values)
9801+
return equality_value_info(t.upper_bound)
9802+
if isinstance(t, Instance) and t.last_known_value is not None:
9803+
return equality_value_info(t.last_known_value)
9804+
if isinstance(t, LiteralType):
9805+
return equality_value_info(t.fallback)
9806+
if isinstance(t, Instance):
9807+
if t.type.fullname == "builtins.object":
9808+
return EqualityValueInfo({}, is_top=True)
9809+
9810+
enum_type_names = {t.type.fullname} if t.type.is_enum else set()
9811+
domains = {}
9812+
for base in t.type.mro:
9813+
if domain := VALUE_EQUALITY_DOMAINS.get(base.fullname):
9814+
domains[domain] = EqualityDomainInfo({t.type.fullname}, enum_type_names)
9815+
9816+
return EqualityValueInfo(domains, is_top=False)
9817+
if isinstance(t, AnyType):
9818+
return EqualityValueInfo({}, is_top=True)
9819+
return EqualityValueInfo({}, is_top=False)
9820+
9821+
9822+
def combine_equality_value_info(infos: Iterable[EqualityValueInfo]) -> EqualityValueInfo:
9823+
domains: dict[str, EqualityDomainInfo] = {}
9824+
is_top = False
9825+
for info in infos:
9826+
for domain, domain_info in info.domains.items():
9827+
existing_domain_info = domains.get(domain)
9828+
if existing_domain_info is None:
9829+
domains[domain] = EqualityDomainInfo(
9830+
set(domain_info.type_names), set(domain_info.enum_type_names)
9831+
)
9832+
else:
9833+
existing_domain_info.type_names.update(domain_info.type_names)
9834+
existing_domain_info.enum_type_names.update(domain_info.enum_type_names)
9835+
is_top = is_top or info.is_top
9836+
return EqualityValueInfo(domains, is_top)
97339837

97349838

97359839
def is_typeddict_type_context(lvalue_type: Type) -> bool:

0 commit comments

Comments
 (0)