@@ -6734,6 +6734,7 @@ def narrow_type_by_equality(
67346734 is_target_for_value_narrowing = is_singleton_identity_type
67356735 should_coerce_literals = True
67366736 should_narrow_by_identity_equality = True
6737+ enum_comparison_is_ambiguous = False
67376738
67386739 elif operator in {"==" , "!=" }:
67396740 is_target_for_value_narrowing = is_singleton_equality_type
@@ -6746,9 +6747,8 @@ def narrow_type_by_equality(
67466747 break
67476748
67486749 expr_types = [operand_types [i ] for i in expr_indices ]
6749- should_narrow_by_identity_equality = not any (
6750- map (has_custom_eq_checks , expr_types )
6751- ) and not is_ambiguous_mix_of_enums (expr_types )
6750+ should_narrow_by_identity_equality = not any (map (has_custom_eq_checks , expr_types ))
6751+ enum_comparison_is_ambiguous = True
67526752 else :
67536753 raise AssertionError
67546754
@@ -6781,11 +6781,18 @@ def narrow_type_by_equality(
67816781 for i in expr_indices :
67826782 if i not in narrowable_indices :
67836783 continue
6784+ expr_type = coerce_to_literal (operand_types [i ])
6785+ expr_type = try_expanding_sum_type_to_union (expr_type , None )
6786+ expr_enum_keys = ambiguous_enum_equality_keys (expr_type )
67846787 for j , target in value_targets :
67856788 if i == j :
67866789 continue
6787- expr_type = coerce_to_literal (operand_types [i ])
6788- expr_type = try_expanding_sum_type_to_union (expr_type , None )
6790+ if (
6791+ # See comments in ambiguous_enum_equality_keys
6792+ enum_comparison_is_ambiguous
6793+ and len (expr_enum_keys | ambiguous_enum_equality_keys (target .item )) > 1
6794+ ):
6795+ continue
67896796 if_map , else_map = conditional_types_to_typemaps (
67906797 operands [i ], * conditional_types (expr_type , [target ])
67916798 )
@@ -6795,10 +6802,10 @@ def narrow_type_by_equality(
67956802 for i in expr_indices :
67966803 if i not in narrowable_indices :
67976804 continue
6805+ expr_type = operand_types [i ]
67986806 for j , target in type_targets :
67996807 if i == j :
68006808 continue
6801- expr_type = operand_types [i ]
68026809 if_map , else_map = conditional_types_to_typemaps (
68036810 operands [i ], * conditional_types (expr_type , [target ])
68046811 )
@@ -9387,47 +9394,44 @@ def visit_starred_pattern(self, p: StarredPattern) -> None:
93879394 self .lvalue = False
93889395
93899396
9390- def is_ambiguous_mix_of_enums (types : list [Type ]) -> bool :
9391- """Do types have IntEnum/StrEnum types that are potentially overlapping with other types?
9397+ def ambiguous_enum_equality_keys (t : Type ) -> set [str ]:
9398+ """
9399+ Used when narrowing types based on equality.
93929400
9393- If True, we shouldn't attempt type narrowing based on enum values, as it gets
9394- too ambiguous.
9401+ Certain kinds of enums can compare equal to values of other types, so doing type math
9402+ the way `conditional_types` does will be misleading if you expect it to correspond to
9403+ conditions based on equality comparisons.
93959404
9396- For example, return True if there's an 'int' type together with an IntEnum literal.
9397- However, IntEnum together with a literal of the same IntEnum type is not ambiguous.
9405+ For example, StrEnum classes can compare equal to str values. So if we see
9406+ `val: StrEnum; if val == "foo": ...` we currently avoid narrowing.
9407+ Note that we do wish to continue narrowing for `if val == StrEnum.MEMBER: ...`
93989408 """
93999409 # We need these things for this to be ambiguous:
9400- # (1) an IntEnum or StrEnum type
9410+ # (1) an IntEnum or StrEnum type or enum subclass of int or str
94019411 # (2) either a different IntEnum/StrEnum type or a non-enum type ("<other>")
9402- #
9403- # It would be slightly more correct to calculate this separately for IntEnum and
9404- # StrEnum related types, as an IntEnum can't be confused with a StrEnum.
9405- return len (_ambiguous_enum_variants (types )) > 1
9406-
9407-
9408- def _ambiguous_enum_variants (types : list [Type ]) -> set [str ]:
94099412 result = set ()
9410- for t in types :
9411- t = get_proper_type (t )
9412- if isinstance (t , UnionType ):
9413- result .update (_ambiguous_enum_variants (t .items ))
9414- elif isinstance (t , Instance ):
9415- if t .last_known_value :
9416- result .update (_ambiguous_enum_variants ([t .last_known_value ]))
9417- elif t .type .is_enum and any (
9418- base .fullname in ("enum.IntEnum" , "enum.StrEnum" ) for base in t .type .mro
9419- ):
9420- result .add (t .type .fullname )
9421- elif not t .type .is_enum :
9422- # These might compare equal to IntEnum/StrEnum types (e.g. Decimal), so
9423- # let's be conservative
9424- result .add ("<other>" )
9425- elif isinstance (t , LiteralType ):
9426- result .update (_ambiguous_enum_variants ([t .fallback ]))
9427- elif isinstance (t , NoneType ):
9428- pass
9429- else :
9413+ t = get_proper_type (t )
9414+ if isinstance (t , UnionType ):
9415+ for item in t .items :
9416+ result .update (ambiguous_enum_equality_keys (item ))
9417+ elif isinstance (t , Instance ):
9418+ if t .last_known_value :
9419+ result .update (ambiguous_enum_equality_keys (t .last_known_value ))
9420+ elif t .type .is_enum and any (
9421+ base .fullname in ("enum.IntEnum" , "enum.StrEnum" , "builtins.str" , "builtins.int" )
9422+ for base in t .type .mro
9423+ ):
9424+ result .add (t .type .fullname )
9425+ elif not t .type .is_enum :
9426+ # These might compare equal to IntEnum/StrEnum types (e.g. Decimal), so
9427+ # let's be conservative
94309428 result .add ("<other>" )
9429+ elif isinstance (t , LiteralType ):
9430+ result .update (ambiguous_enum_equality_keys (t .fallback ))
9431+ elif isinstance (t , NoneType ):
9432+ pass
9433+ else :
9434+ result .add ("<other>" )
94319435 return result
94329436
94339437
0 commit comments