@@ -108,7 +108,7 @@ def __init__(self) -> None:
108108from mypy .expandtype import expand_type
109109from mypy .literals import Key , extract_var_from_literal_hash , literal , literal_hash
110110from mypy .maptype import map_instance_to_supertype
111- from mypy .meet import is_overlapping_erased_types , is_overlapping_types , meet_types
111+ from mypy .meet import is_overlapping_types , meet_types
112112from mypy .message_registry import ErrorMessage
113113from mypy .messages import (
114114 SUGGESTED_TEST_FIXTURES ,
@@ -6761,22 +6761,6 @@ def comparison_type_narrowing_helper(self, node: ComparisonExpr) -> tuple[TypeMa
67616761 narrowable_indices = {0 },
67626762 )
67636763
6764- # TODO: This remove_optional code should no longer be needed. The only
6765- # thing it does is paper over a pre-existing deficiency in equality
6766- # narrowing w.r.t to enums.
6767- # We only try and narrow away 'None' for now
6768- if (
6769- not is_unreachable_map (if_map )
6770- and is_overlapping_none (item_type )
6771- and not is_overlapping_none (collection_item_type )
6772- and not (
6773- isinstance (collection_item_type , Instance )
6774- and collection_item_type .type .fullname == "builtins.object"
6775- )
6776- and is_overlapping_erased_types (item_type , collection_item_type )
6777- ):
6778- if_map [operands [left_index ]] = remove_optional (item_type )
6779-
67806764 if right_index in narrowable_operand_index_to_hash :
67816765 if_type , else_type = self .conditional_types_for_iterable (
67826766 item_type , iterable_type
@@ -6861,17 +6845,15 @@ def narrow_type_by_identity_equality(
68616845 # have to be more careful about what narrowing we can conclude from a successful comparison
68626846 custom_eq_indices : set [int ]
68636847
6864- # enum_comparison_is_ambiguous:
6865- # `if x is Fruits.APPLE` we know `x` is `Fruits.APPLE`, but `if x == Fruits.APPLE: ...`
6866- # it could e.g. be an int or str if Fruits is an IntEnum or StrEnum.
6867- # See ambiguous_enum_equality_keys for more details
6868- enum_comparison_is_ambiguous : bool
6848+ # Equality can use value semantics, so `if x == Fruits.APPLE: ...` may also
6849+ # match non-enum values for IntEnum/StrEnum-like enums. Identity checks don't
6850+ # have this ambiguity.
6851+ is_identity_comparison = operator in {"is" , "is not" }
68696852
6870- if operator in { "is" , "is not" } :
6853+ if is_identity_comparison :
68716854 is_target_for_value_narrowing = is_singleton_identity_type
68726855 should_coerce_literals = True
68736856 custom_eq_indices = set ()
6874- enum_comparison_is_ambiguous = False
68756857
68766858 elif operator in {"==" , "!=" }:
68776859 is_target_for_value_narrowing = is_singleton_equality_type
@@ -6884,7 +6866,6 @@ def narrow_type_by_identity_equality(
68846866 break
68856867
68866868 custom_eq_indices = {i for i in expr_indices if has_custom_eq_checks (operand_types [i ])}
6887- enum_comparison_is_ambiguous = True
68886869 else :
68896870 raise AssertionError
68906871
@@ -6900,8 +6881,6 @@ def narrow_type_by_identity_equality(
69006881 continue
69016882
69026883 expr_type = operand_types [i ]
6903- expr_enum_keys = ambiguous_enum_equality_keys (expr_type )
6904- expr_type = try_expanding_sum_type_to_union (coerce_to_literal (expr_type ), None )
69056884 for j in expr_indices :
69066885 if i == j :
69076886 continue
@@ -6913,18 +6892,30 @@ def narrow_type_by_identity_equality(
69136892 if should_coerce_literals :
69146893 target_type = coerce_to_literal (target_type )
69156894
6916- if (
6917- # See comments in ambiguous_enum_equality_keys
6918- enum_comparison_is_ambiguous
6919- and len (expr_enum_keys | ambiguous_enum_equality_keys (target_type )) > 1
6920- ):
6921- continue
6895+ narrowable_expr_type , ambiguous_expr_type = partition_equality_ambiguous_types (
6896+ expr_type , target_type , is_identity = is_identity_comparison
6897+ )
69226898
6923- target = TypeRange (target_type , is_upper_bound = False )
6899+ if narrowable_expr_type is None :
6900+ if_type = else_type = ambiguous_expr_type
6901+ else :
6902+ narrowable_expr_type = try_expanding_sum_type_to_union (
6903+ coerce_to_literal (narrowable_expr_type ), None
6904+ )
6905+ if_type , else_type = conditional_types (
6906+ narrowable_expr_type ,
6907+ [TypeRange (target_type , is_upper_bound = False )],
6908+ from_equality = True ,
6909+ )
6910+ if ambiguous_expr_type is not None :
6911+ if_type = make_simplified_union (
6912+ [if_type or narrowable_expr_type , ambiguous_expr_type ]
6913+ )
6914+ else_type = make_simplified_union (
6915+ [else_type or narrowable_expr_type , ambiguous_expr_type ]
6916+ )
69246917
6925- if_map , else_map = conditional_types_to_typemaps (
6926- operands [i ], * conditional_types (expr_type , [target ], from_equality = True )
6927- )
6918+ if_map , else_map = conditional_types_to_typemaps (operands [i ], if_type , else_type )
69286919 if is_target_for_value_narrowing (get_proper_type (target_type )):
69296920 all_if_maps .append (if_map )
69306921 all_else_maps .append (else_map )
@@ -7005,13 +6996,29 @@ def narrow_type_by_identity_equality(
70056996 target_type = operand_types [j ]
70066997 if should_coerce_literals :
70076998 target_type = coerce_to_literal (target_type )
7008- target = TypeRange (target_type , is_upper_bound = False )
6999+
7000+ narrowable_expr_type , ambiguous_expr_type = partition_equality_ambiguous_types (
7001+ expr_type , target_type , is_identity = is_identity_comparison
7002+ )
7003+
7004+ if narrowable_expr_type is None :
7005+ if_type = else_type = ambiguous_expr_type
7006+ else :
7007+ narrowable_expr_type = coerce_to_literal (
7008+ try_expanding_sum_type_to_union (narrowable_expr_type , None )
7009+ )
7010+ if_type , else_type = conditional_types (
7011+ narrowable_expr_type ,
7012+ [TypeRange (target_type , is_upper_bound = False )],
7013+ default = narrowable_expr_type ,
7014+ from_equality = True ,
7015+ )
7016+ if ambiguous_expr_type is not None :
7017+ if_type = make_simplified_union ([if_type , ambiguous_expr_type ])
7018+ else_type = make_simplified_union ([else_type , ambiguous_expr_type ])
70097019
70107020 if_map , else_map = conditional_types_to_typemaps (
7011- operands [i ],
7012- * conditional_types (
7013- expr_type , [target ], default = expr_type , from_equality = True
7014- ),
7021+ operands [i ], if_type , else_type
70157022 )
70167023 or_if_maps .append (if_map )
70177024 if is_target_for_value_narrowing (get_proper_type (target_type )):
@@ -8618,17 +8625,10 @@ def conditional_types(
86188625 # We erase generic args because values with different generic types can compare equal
86198626 # For instance, cast(list[str], []) and cast(list[int], [])
86208627 proposed_type = shallow_erase_type_for_equality (proposed_type )
8621- if not is_overlapping_types (current_type , proposed_type , ignore_promotions = False ):
8622- # Equality narrowing is one of the places at runtime where subtyping with promotion
8623- # does happen to match runtime semantics
8624- # Expression is never of any type in proposed_type_ranges
8625- return UninhabitedType (), default
8626- if not is_overlapping_types (current_type , proposed_type , ignore_promotions = True ):
8627- return default , default
8628- else :
8629- if not is_overlapping_types (current_type , proposed_type , ignore_promotions = True ):
8630- # Expression is never of any type in proposed_type_ranges
8631- return UninhabitedType (), default
8628+
8629+ if not is_overlapping_types (current_type , proposed_type , ignore_promotions = True ):
8630+ # Expression is never of any type in proposed_type_ranges
8631+ return UninhabitedType (), default
86328632
86338633 # we can only restrict when the type is precise, not bounded
86348634 proposed_precise_type = UnionType .make_union (
@@ -8898,8 +8898,6 @@ def reduce_and_conditional_type_maps(ms: list[TypeMap], *, use_meet: bool) -> Ty
88988898
88998899
89008900BUILTINS_CUSTOM_EQ_CHECKS : Final = {
8901- "builtins.bytearray" ,
8902- "builtins.memoryview" ,
89038901 "builtins.frozenset" ,
89048902 "_collections_abc.dict_keys" ,
89058903 "_collections_abc.dict_items" ,
@@ -8911,9 +8909,8 @@ def has_custom_eq_checks(t: Type) -> bool:
89118909 custom_special_method (t , "__eq__" , check_all = False )
89128910 or custom_special_method (t , "__ne__" , check_all = False )
89138911 # custom_special_method has special casing for builtins.* and typing.* that make the
8914- # above always return False. So here we return True if the a value of a builtin type
8915- # will ever compare equal to value of another type, e.g. a bytes value can compare equal
8916- # to a bytearray value.
8912+ # above always return False. Some builtin collections still have equality behavior that
8913+ # crosses nominal type boundaries and isn't captured by VALUE_EQUALITY_TYPE_DOMAINS.
89178914 or (
89188915 isinstance (pt := get_proper_type (t ), Instance )
89198916 and pt .type .fullname in BUILTINS_CUSTOM_EQ_CHECKS
@@ -9691,45 +9688,152 @@ def visit_starred_pattern(self, p: StarredPattern) -> None:
96919688 self .lvalue = False
96929689
96939690
9694- def ambiguous_enum_equality_keys (t : Type ) -> set [str ]:
9695- """
9696- Used when narrowing types based on equality.
9691+ # Open domains also block cross-type narrowing for known domain members, but they
9692+ # don't provide an exhaustive union to narrow top types to.
9693+ OPEN_VALUE_EQUALITY_DOMAINS : Final = {
9694+ "builtins.str" : "builtins.str" ,
9695+ "builtins.bool" : "builtins.numeric" ,
9696+ "builtins.int" : "builtins.numeric" ,
9697+ "builtins.float" : "builtins.numeric" ,
9698+ "builtins.complex" : "builtins.numeric" ,
9699+ }
9700+ OPEN_VALUE_EQUALITY_DOMAIN_NAMES : Final = frozenset (OPEN_VALUE_EQUALITY_DOMAINS .values ())
9701+
9702+ # Closed domains also block ordinary cross-type narrowing within the domain.
9703+ CLOSED_VALUE_EQUALITY_DOMAINS : Final = {
9704+ "builtins.bytes" : "builtins.bytes" ,
9705+ "builtins.bytearray" : "builtins.bytes" ,
9706+ "builtins.memoryview" : "builtins.bytes" ,
9707+ }
9708+
9709+ VALUE_EQUALITY_DOMAINS : Final = {** OPEN_VALUE_EQUALITY_DOMAINS , ** CLOSED_VALUE_EQUALITY_DOMAINS }
96979710
9698- Certain kinds of enums can compare equal to values of other types, so doing type math
9699- the way `conditional_types` does will be misleading if you expect it to correspond to
9700- conditions based on equality comparisons.
97019711
9702- For example, StrEnum classes can compare equal to str values. So if we see
9703- `val: StrEnum; if val == "foo": ...` we currently avoid narrowing.
9704- Note that we do wish to continue narrowing for `if val == StrEnum.MEMBER: ...`
9712+ class EqualityDomainInfo :
9713+ def __init__ (self , type_names : set [str ], enum_type_names : set [str ]) -> None :
9714+ self .type_names = type_names
9715+ self .enum_type_names = enum_type_names
9716+
9717+
9718+ class EqualityValueInfo :
9719+ def __init__ (self , domains : dict [str , EqualityDomainInfo ], is_top : bool ) -> None :
9720+ self .domains = domains
9721+ self .is_top = is_top
9722+
9723+
9724+ def partition_equality_ambiguous_types (
9725+ current_type : Type , target_type : Type , * , is_identity : bool
9726+ ) -> tuple [Type | None , Type | None ]:
9727+ """Split current_type into ordinary-narrowable and equality-ambiguous pieces.
9728+
9729+ Some values compare equal through a value domain broader than their nominal type. For
9730+ example, an IntEnum member can compare equal to an int, and a StrEnum member can compare
9731+ equal to a str. When narrowing `x: MyStrEnum | str` against `MyStrEnum.MEMBER`, we can
9732+ still narrow the enum portion of the union, but we must keep the str portion in both
9733+ branches.
97059734 """
9706- # We need these things for this to be ambiguous:
9707- # (1) an IntEnum or StrEnum type or enum subclass of int or str
9708- # (2) either a different IntEnum/StrEnum type or a non-enum type ("<other>")
9709- result = set ()
9735+ if is_identity :
9736+ return current_type , None
9737+
9738+ typ = get_proper_type (current_type )
9739+ items = typ .relevant_items () if isinstance (typ , UnionType ) else [current_type ]
9740+ narrowable_items = []
9741+ ambiguous_items = []
9742+ for item in items :
9743+ if is_equality_ambiguous_for_narrowing (item , target_type ):
9744+ ambiguous_items .append (item )
9745+ else :
9746+ narrowable_items .append (item )
9747+ return (
9748+ UnionType .make_union (narrowable_items ) if narrowable_items else None ,
9749+ UnionType .make_union (ambiguous_items ) if ambiguous_items else None ,
9750+ )
9751+
9752+
9753+ def is_equality_ambiguous_for_narrowing (left : Type , right : Type ) -> bool :
9754+ """Can left compare equal to right through a value domain outside nominal overlap?"""
9755+ left_info = equality_value_info (left )
9756+ right_info = equality_value_info (right )
9757+
9758+ if left_info .is_top or right_info .is_top :
9759+ # Only open-domain enum values can make a top-like type ambiguous.
9760+ # Closed domains can be narrowed to their complete known set instead.
9761+ other_info = right_info if left_info .is_top else left_info
9762+ return any (
9763+ domain in OPEN_VALUE_EQUALITY_DOMAIN_NAMES and domain_info .enum_type_names
9764+ for domain , domain_info in other_info .domains .items ()
9765+ )
9766+
9767+ shared_domains = left_info .domains .keys () & right_info .domains .keys ()
9768+ if not shared_domains :
9769+ return False
9770+
9771+ for domain in shared_domains :
9772+ left_domain = left_info .domains [domain ]
9773+ right_domain = right_info .domains [domain ]
9774+ # Equality between two values from the same enum can still narrow by literal member.
9775+ if (
9776+ left_domain .enum_type_names
9777+ and left_domain .enum_type_names == right_domain .enum_type_names
9778+ and left_domain .type_names == left_domain .enum_type_names
9779+ and right_domain .type_names == right_domain .enum_type_names
9780+ ):
9781+ continue
9782+ # Different domain-member types may compare equal, but nominal narrowing would
9783+ # otherwise treat them as disjoint.
9784+ if left_domain .type_names != right_domain .type_names :
9785+ return True
9786+ # Same domain-member types are only ambiguous if an enum value may compare equal to
9787+ # its underlying value type.
9788+ if left_domain .enum_type_names or right_domain .enum_type_names :
9789+ return True
9790+
9791+ return False
9792+
9793+
9794+ def equality_value_info (t : Type ) -> EqualityValueInfo :
97109795 t = get_proper_type (t )
97119796 if isinstance (t , UnionType ):
9712- for item in t .items :
9713- result .update (ambiguous_enum_equality_keys (item ))
9714- elif isinstance (t , Instance ):
9715- if t .last_known_value :
9716- result .update (ambiguous_enum_equality_keys (t .last_known_value ))
9717- elif t .type .is_enum and any (
9718- base .fullname in ("enum.IntEnum" , "enum.StrEnum" , "builtins.str" , "builtins.int" )
9719- for base in t .type .mro
9720- ):
9721- result .add (t .type .fullname )
9722- elif not t .type .is_enum :
9723- # These might compare equal to IntEnum/StrEnum types (e.g. Decimal), so
9724- # let's be conservative
9725- result .add ("<other>" )
9726- elif isinstance (t , LiteralType ):
9727- result .update (ambiguous_enum_equality_keys (t .fallback ))
9728- elif isinstance (t , NoneType ):
9729- pass
9730- else :
9731- result .add ("<other>" )
9732- return result
9797+ return combine_equality_value_info (equality_value_info (item ) for item in t .items )
9798+ if isinstance (t , TypeVarType ):
9799+ if t .values :
9800+ return combine_equality_value_info (equality_value_info (item ) for item in t .values )
9801+ return equality_value_info (t .upper_bound )
9802+ if isinstance (t , Instance ) and t .last_known_value is not None :
9803+ return equality_value_info (t .last_known_value )
9804+ if isinstance (t , LiteralType ):
9805+ return equality_value_info (t .fallback )
9806+ if isinstance (t , Instance ):
9807+ if t .type .fullname == "builtins.object" :
9808+ return EqualityValueInfo ({}, is_top = True )
9809+
9810+ enum_type_names = {t .type .fullname } if t .type .is_enum else set ()
9811+ domains = {}
9812+ for base in t .type .mro :
9813+ if domain := VALUE_EQUALITY_DOMAINS .get (base .fullname ):
9814+ domains [domain ] = EqualityDomainInfo ({t .type .fullname }, enum_type_names )
9815+
9816+ return EqualityValueInfo (domains , is_top = False )
9817+ if isinstance (t , AnyType ):
9818+ return EqualityValueInfo ({}, is_top = True )
9819+ return EqualityValueInfo ({}, is_top = False )
9820+
9821+
9822+ def combine_equality_value_info (infos : Iterable [EqualityValueInfo ]) -> EqualityValueInfo :
9823+ domains : dict [str , EqualityDomainInfo ] = {}
9824+ is_top = False
9825+ for info in infos :
9826+ for domain , domain_info in info .domains .items ():
9827+ existing_domain_info = domains .get (domain )
9828+ if existing_domain_info is None :
9829+ domains [domain ] = EqualityDomainInfo (
9830+ set (domain_info .type_names ), set (domain_info .enum_type_names )
9831+ )
9832+ else :
9833+ existing_domain_info .type_names .update (domain_info .type_names )
9834+ existing_domain_info .enum_type_names .update (domain_info .enum_type_names )
9835+ is_top = is_top or info .is_top
9836+ return EqualityValueInfo (domains , is_top )
97339837
97349838
97359839def is_typeddict_type_context (lvalue_type : Type ) -> bool :
0 commit comments