Skip to content

Commit fb0845d

Browse files
authored
Improve narrowing logic for Enum int and str subclasses (#20609)
Fixes some of the primer regressions from #20492 Fixes #19753 This should also happen to fix the case mentioned by cdce8p in #20492 (comment)
1 parent 639fcde commit fb0845d

File tree

2 files changed

+67
-42
lines changed

2 files changed

+67
-42
lines changed

mypy/checker.py

Lines changed: 44 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -6734,6 +6734,7 @@ def narrow_type_by_equality(
67346734
is_target_for_value_narrowing = is_singleton_identity_type
67356735
should_coerce_literals = True
67366736
should_narrow_by_identity_equality = True
6737+
enum_comparison_is_ambiguous = False
67376738

67386739
elif operator in {"==", "!="}:
67396740
is_target_for_value_narrowing = is_singleton_equality_type
@@ -6746,9 +6747,8 @@ def narrow_type_by_equality(
67466747
break
67476748

67486749
expr_types = [operand_types[i] for i in expr_indices]
6749-
should_narrow_by_identity_equality = not any(
6750-
map(has_custom_eq_checks, expr_types)
6751-
) and not is_ambiguous_mix_of_enums(expr_types)
6750+
should_narrow_by_identity_equality = not any(map(has_custom_eq_checks, expr_types))
6751+
enum_comparison_is_ambiguous = True
67526752
else:
67536753
raise AssertionError
67546754

@@ -6781,11 +6781,18 @@ def narrow_type_by_equality(
67816781
for i in expr_indices:
67826782
if i not in narrowable_indices:
67836783
continue
6784+
expr_type = coerce_to_literal(operand_types[i])
6785+
expr_type = try_expanding_sum_type_to_union(expr_type, None)
6786+
expr_enum_keys = ambiguous_enum_equality_keys(expr_type)
67846787
for j, target in value_targets:
67856788
if i == j:
67866789
continue
6787-
expr_type = coerce_to_literal(operand_types[i])
6788-
expr_type = try_expanding_sum_type_to_union(expr_type, None)
6790+
if (
6791+
# See comments in ambiguous_enum_equality_keys
6792+
enum_comparison_is_ambiguous
6793+
and len(expr_enum_keys | ambiguous_enum_equality_keys(target.item)) > 1
6794+
):
6795+
continue
67896796
if_map, else_map = conditional_types_to_typemaps(
67906797
operands[i], *conditional_types(expr_type, [target])
67916798
)
@@ -6795,10 +6802,10 @@ def narrow_type_by_equality(
67956802
for i in expr_indices:
67966803
if i not in narrowable_indices:
67976804
continue
6805+
expr_type = operand_types[i]
67986806
for j, target in type_targets:
67996807
if i == j:
68006808
continue
6801-
expr_type = operand_types[i]
68026809
if_map, else_map = conditional_types_to_typemaps(
68036810
operands[i], *conditional_types(expr_type, [target])
68046811
)
@@ -9387,47 +9394,44 @@ def visit_starred_pattern(self, p: StarredPattern) -> None:
93879394
self.lvalue = False
93889395

93899396

9390-
def is_ambiguous_mix_of_enums(types: list[Type]) -> bool:
9391-
"""Do types have IntEnum/StrEnum types that are potentially overlapping with other types?
9397+
def ambiguous_enum_equality_keys(t: Type) -> set[str]:
9398+
"""
9399+
Used when narrowing types based on equality.
93929400
9393-
If True, we shouldn't attempt type narrowing based on enum values, as it gets
9394-
too ambiguous.
9401+
Certain kinds of enums can compare equal to values of other types, so doing type math
9402+
the way `conditional_types` does will be misleading if you expect it to correspond to
9403+
conditions based on equality comparisons.
93959404
9396-
For example, return True if there's an 'int' type together with an IntEnum literal.
9397-
However, IntEnum together with a literal of the same IntEnum type is not ambiguous.
9405+
For example, StrEnum classes can compare equal to str values. So if we see
9406+
`val: StrEnum; if val == "foo": ...` we currently avoid narrowing.
9407+
Note that we do wish to continue narrowing for `if val == StrEnum.MEMBER: ...`
93989408
"""
93999409
# We need these things for this to be ambiguous:
9400-
# (1) an IntEnum or StrEnum type
9410+
# (1) an IntEnum or StrEnum type or enum subclass of int or str
94019411
# (2) either a different IntEnum/StrEnum type or a non-enum type ("<other>")
9402-
#
9403-
# It would be slightly more correct to calculate this separately for IntEnum and
9404-
# StrEnum related types, as an IntEnum can't be confused with a StrEnum.
9405-
return len(_ambiguous_enum_variants(types)) > 1
9406-
9407-
9408-
def _ambiguous_enum_variants(types: list[Type]) -> set[str]:
94099412
result = set()
9410-
for t in types:
9411-
t = get_proper_type(t)
9412-
if isinstance(t, UnionType):
9413-
result.update(_ambiguous_enum_variants(t.items))
9414-
elif isinstance(t, Instance):
9415-
if t.last_known_value:
9416-
result.update(_ambiguous_enum_variants([t.last_known_value]))
9417-
elif t.type.is_enum and any(
9418-
base.fullname in ("enum.IntEnum", "enum.StrEnum") for base in t.type.mro
9419-
):
9420-
result.add(t.type.fullname)
9421-
elif not t.type.is_enum:
9422-
# These might compare equal to IntEnum/StrEnum types (e.g. Decimal), so
9423-
# let's be conservative
9424-
result.add("<other>")
9425-
elif isinstance(t, LiteralType):
9426-
result.update(_ambiguous_enum_variants([t.fallback]))
9427-
elif isinstance(t, NoneType):
9428-
pass
9429-
else:
9413+
t = get_proper_type(t)
9414+
if isinstance(t, UnionType):
9415+
for item in t.items:
9416+
result.update(ambiguous_enum_equality_keys(item))
9417+
elif isinstance(t, Instance):
9418+
if t.last_known_value:
9419+
result.update(ambiguous_enum_equality_keys(t.last_known_value))
9420+
elif t.type.is_enum and any(
9421+
base.fullname in ("enum.IntEnum", "enum.StrEnum", "builtins.str", "builtins.int")
9422+
for base in t.type.mro
9423+
):
9424+
result.add(t.type.fullname)
9425+
elif not t.type.is_enum:
9426+
# These might compare equal to IntEnum/StrEnum types (e.g. Decimal), so
9427+
# let's be conservative
94309428
result.add("<other>")
9429+
elif isinstance(t, LiteralType):
9430+
result.update(ambiguous_enum_equality_keys(t.fallback))
9431+
elif isinstance(t, NoneType):
9432+
pass
9433+
else:
9434+
result.add("<other>")
94319435
return result
94329436

94339437

test-data/unit/check-narrowing.test

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2124,7 +2124,7 @@ else:
21242124
[builtins fixtures/ops.pyi]
21252125

21262126
[case testNarrowingWithIntEnum]
2127-
# mypy: strict-equality
2127+
# flags: --strict-equality --warn-unreachable
21282128
from __future__ import annotations
21292129
from typing import Any
21302130
from enum import IntEnum
@@ -2179,7 +2179,7 @@ def f6(x: IE) -> None:
21792179
[builtins fixtures/primitives.pyi]
21802180

21812181
[case testNarrowingWithIntEnum2]
2182-
# mypy: strict-equality
2182+
# flags: --strict-equality --warn-unreachable
21832183
from __future__ import annotations
21842184
from typing import Any
21852185
from enum import IntEnum, Enum
@@ -2284,6 +2284,27 @@ def f4(x: SE) -> None:
22842284
reveal_type(x) # N: Revealed type is "Literal[__main__.SE.B]"
22852285
[builtins fixtures/primitives.pyi]
22862286

2287+
[case testNarrowingWithEnumStrSubclass]
2288+
# flags: --strict-equality --warn-unreachable
2289+
from enum import Enum
2290+
2291+
class ParameterLocation(str, Enum):
2292+
QUERY = "query"
2293+
HEADER = "header"
2294+
PATH = "path"
2295+
2296+
def foo(location: ParameterLocation):
2297+
if location == "path":
2298+
reveal_type(location) # N: Revealed type is "__main__.ParameterLocation"
2299+
else:
2300+
reveal_type(location) # N: Revealed type is "__main__.ParameterLocation"
2301+
2302+
if location == ParameterLocation.PATH:
2303+
reveal_type(location) # N: Revealed type is "Literal[__main__.ParameterLocation.PATH]"
2304+
else:
2305+
reveal_type(location) # N: Revealed type is "Literal[__main__.ParameterLocation.QUERY] | Literal[__main__.ParameterLocation.HEADER]"
2306+
[builtins fixtures/primitives.pyi]
2307+
22872308
[case testConsistentNarrowingEqAndIn]
22882309
# flags: --python-version 3.10
22892310

0 commit comments

Comments
 (0)