Skip to content

Commit 32897bf

Browse files
authored
Fix false negatives in walrus vs inference fallback logic (#20622)
Fixes #20606 This fixes couple edge cases where walrus interferes with type inference fallback logic. This is not a complete fix, but I think it is OK for now as these are probably rare situations. Most changes in `binder.py` are a pure performance optimization to compensate the fact that `frame_context()` will be more hot now. I also leave a TODO explaining a more proper fix for the assignment case (and other possible cases), return case should be already good (it is simpler since we don't need to apply any of the narrowing).
1 parent b5587a3 commit 32897bf

File tree

3 files changed

+123
-38
lines changed

3 files changed

+123
-38
lines changed

mypy/binder.py

Lines changed: 79 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from collections import defaultdict
44
from collections.abc import Iterator
55
from contextlib import contextmanager
6-
from typing import NamedTuple, TypeAlias as _TypeAlias
6+
from typing import Literal, NamedTuple, TypeAlias as _TypeAlias
77

88
from mypy.erasetype import remove_instance_last_known_values
99
from mypy.literals import Key, extract_var_from_literal_hash, literal, literal_hash, subkeys
@@ -83,6 +83,61 @@ def __repr__(self) -> str:
8383
Assigns = defaultdict[Expression, list[tuple[Type, Type | None]]]
8484

8585

86+
class FrameContext:
87+
"""Context manager pushing a Frame to ConditionalTypeBinder.
88+
89+
See frame_context() below for documentation on parameters. We use this class
90+
instead of @contextmanager as a mypyc-specific performance optimization.
91+
"""
92+
93+
def __init__(
94+
self,
95+
binder: ConditionalTypeBinder,
96+
can_skip: bool,
97+
fall_through: int,
98+
break_frame: int,
99+
continue_frame: int,
100+
conditional_frame: bool,
101+
try_frame: bool,
102+
discard: bool,
103+
) -> None:
104+
self.binder = binder
105+
self.can_skip = can_skip
106+
self.fall_through = fall_through
107+
self.break_frame = break_frame
108+
self.continue_frame = continue_frame
109+
self.conditional_frame = conditional_frame
110+
self.try_frame = try_frame
111+
self.discard = discard
112+
113+
def __enter__(self) -> Frame:
114+
assert len(self.binder.frames) > 1
115+
116+
if self.break_frame:
117+
self.binder.break_frames.append(len(self.binder.frames) - self.break_frame)
118+
if self.continue_frame:
119+
self.binder.continue_frames.append(len(self.binder.frames) - self.continue_frame)
120+
if self.try_frame:
121+
self.binder.try_frames.add(len(self.binder.frames) - 1)
122+
123+
new_frame = self.binder.push_frame(self.conditional_frame)
124+
if self.try_frame:
125+
# An exception may occur immediately
126+
self.binder.allow_jump(-1)
127+
return new_frame
128+
129+
def __exit__(self, exc_type: object, exc_val: object, exc_tb: object) -> Literal[False]:
130+
self.binder.pop_frame(self.can_skip, self.fall_through, discard=self.discard)
131+
132+
if self.break_frame:
133+
self.binder.break_frames.pop()
134+
if self.continue_frame:
135+
self.binder.continue_frames.pop()
136+
if self.try_frame:
137+
self.binder.try_frames.remove(len(self.binder.frames) - 1)
138+
return False
139+
140+
86141
class ConditionalTypeBinder:
87142
"""Keep track of conditional types of variables.
88143
@@ -338,10 +393,10 @@ def update_from_options(self, frames: list[Frame]) -> bool:
338393

339394
return changed
340395

341-
def pop_frame(self, can_skip: bool, fall_through: int) -> Frame:
396+
def pop_frame(self, can_skip: bool, fall_through: int, *, discard: bool = False) -> Frame:
342397
"""Pop a frame and return it.
343398
344-
See frame_context() for documentation of fall_through.
399+
See frame_context() for documentation of fall_through and discard.
345400
"""
346401

347402
if fall_through > 0:
@@ -350,6 +405,10 @@ def pop_frame(self, can_skip: bool, fall_through: int) -> Frame:
350405
result = self.frames.pop()
351406
options = self.options_on_return.pop()
352407

408+
if discard:
409+
self.last_pop_changed = False
410+
return result
411+
353412
if can_skip:
354413
options.insert(0, self.frames[-1])
355414

@@ -484,7 +543,6 @@ def handle_continue(self) -> None:
484543
self.allow_jump(self.continue_frames[-1])
485544
self.unreachable()
486545

487-
@contextmanager
488546
def frame_context(
489547
self,
490548
*,
@@ -494,53 +552,45 @@ def frame_context(
494552
continue_frame: int = 0,
495553
conditional_frame: bool = False,
496554
try_frame: bool = False,
497-
) -> Iterator[Frame]:
555+
discard: bool = False,
556+
) -> FrameContext:
498557
"""Return a context manager that pushes/pops frames on enter/exit.
499558
500559
If can_skip is True, control flow is allowed to bypass the
501560
newly-created frame.
502561
503562
If fall_through > 0, then it will allow control flow that
504563
falls off the end of the frame to escape to its ancestor
505-
`fall_through` levels higher. Otherwise control flow ends
564+
`fall_through` levels higher. Otherwise, control flow ends
506565
at the end of the frame.
507566
508567
If break_frame > 0, then 'break' statements within this frame
509568
will jump out to the frame break_frame levels higher than the
510-
frame created by this call to frame_context. Similarly for
569+
frame created by this call to frame_context. Similarly, for
511570
continue_frame and 'continue' statements.
512571
513572
If try_frame is true, then execution is allowed to jump at any
514573
point within the newly created frame (or its descendants) to
515574
its parent (i.e., to the frame that was on top before this
516575
call to frame_context).
517576
577+
If discard is True, then this is a temporary throw-away frame
578+
(used e.g. for isolation) and its effect will be discarded on pop.
579+
518580
After the context manager exits, self.last_pop_changed indicates
519581
whether any types changed in the newly-topmost frame as a result
520582
of popping this frame.
521583
"""
522-
assert len(self.frames) > 1
523-
524-
if break_frame:
525-
self.break_frames.append(len(self.frames) - break_frame)
526-
if continue_frame:
527-
self.continue_frames.append(len(self.frames) - continue_frame)
528-
if try_frame:
529-
self.try_frames.add(len(self.frames) - 1)
530-
531-
new_frame = self.push_frame(conditional_frame)
532-
if try_frame:
533-
# An exception may occur immediately
534-
self.allow_jump(-1)
535-
yield new_frame
536-
self.pop_frame(can_skip, fall_through)
537-
538-
if break_frame:
539-
self.break_frames.pop()
540-
if continue_frame:
541-
self.continue_frames.pop()
542-
if try_frame:
543-
self.try_frames.remove(len(self.frames) - 1)
584+
return FrameContext(
585+
self,
586+
can_skip=can_skip,
587+
fall_through=fall_through,
588+
break_frame=break_frame,
589+
continue_frame=continue_frame,
590+
conditional_frame=conditional_frame,
591+
try_frame=try_frame,
592+
discard=discard,
593+
)
544594

545595
@contextmanager
546596
def top_frame_context(self) -> Iterator[Frame]:

mypy/checker.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3391,7 +3391,7 @@ def check_assignment(
33913391
inferred = None
33923392

33933393
# Special case: only non-abstract non-protocol classes can be assigned to
3394-
# variables with explicit type Type[A], where A is protocol or abstract.
3394+
# variables with explicit type `Type[A]`, where A is protocol or abstract.
33953395
p_rvalue_type = get_proper_type(rvalue_type)
33963396
p_lvalue_type = get_proper_type(lvalue_type)
33973397
if (
@@ -4664,6 +4664,19 @@ def check_simple_assignment(
46644664
type_context = lvalue_type
46654665
else:
46664666
type_context = None
4667+
4668+
# TODO: make assignment checking correct in presence of walrus in r.h.s.
4669+
# Right now we can accept the r.h.s. up to four(!) times. In presence of
4670+
# walrus this can result in weird false negatives and "back action". A proper
4671+
# solution would be to:
4672+
# * Refactor the code to reduce number of times we accept the r.h.s.
4673+
# (two should be enough: empty context + l.h.s. context).
4674+
# * For each accept use binder.accumulate_type_assignments() and assign
4675+
# the types inferred for context that is ultimately used.
4676+
# For now we simply disable some logic that is known to cause problems in
4677+
# presence of walrus, see e.g. testAssignToOptionalTupleWalrus.
4678+
binder_version = self.binder.version
4679+
46674680
rvalue_type = self.expr_checker.accept(
46684681
rvalue, type_context=type_context, always_allow_any=always_allow_any
46694682
)
@@ -4711,6 +4724,7 @@ def check_simple_assignment(
47114724
# Skip literal types, as they have special logic (for better errors).
47124725
and not is_literal_type_like(rvalue_type)
47134726
and not self.simple_rvalue(rvalue)
4727+
and binder_version == self.binder.version
47144728
):
47154729
# Try re-inferring r.h.s. in empty context, and use that if it
47164730
# results in a narrower type. We don't do this always because this
@@ -4913,11 +4927,13 @@ def visit_return_stmt(self, s: ReturnStmt) -> None:
49134927
def infer_context_dependent(
49144928
self, expr: Expression, type_ctx: Type, allow_none_func_call: bool
49154929
) -> ProperType:
4916-
"""Infer type of an expression with fallback to empty type context."""
4917-
with self.msg.filter_errors(
4918-
filter_errors=True, filter_deprecated=True, save_filtered_errors=True
4919-
) as msg:
4920-
with self.local_type_map as type_map:
4930+
"""Infer type of expression with fallback to empty type context."""
4931+
with self.msg.filter_errors(filter_deprecated=True, save_filtered_errors=True) as msg:
4932+
with (
4933+
self.local_type_map as type_map,
4934+
# Prevent any narrowing (e.g. from walrus) to have effect during second accept.
4935+
self.binder.frame_context(can_skip=False, discard=True),
4936+
):
49214937
typ = get_proper_type(
49224938
self.expr_checker.accept(
49234939
expr, type_ctx, allow_none_return=allow_none_func_call
@@ -4930,9 +4946,7 @@ def infer_context_dependent(
49304946
# If there are errors with the original type context, try re-inferring in empty context.
49314947
original_messages = msg.filtered_errors()
49324948
original_type_map = type_map
4933-
with self.msg.filter_errors(
4934-
filter_errors=True, filter_deprecated=True, save_filtered_errors=True
4935-
) as msg:
4949+
with self.msg.filter_errors(filter_deprecated=True, save_filtered_errors=True) as msg:
49364950
with self.local_type_map as type_map:
49374951
alt_typ = get_proper_type(
49384952
self.expr_checker.accept(expr, None, allow_none_return=allow_none_func_call)

test-data/unit/check-python38.test

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -810,3 +810,24 @@ y: List[int]
810810
if (y := []):
811811
reveal_type(y) # N: Revealed type is "builtins.list[builtins.int]"
812812
[builtins fixtures/list.pyi]
813+
814+
[case testAssignToOptionalTupleWalrus]
815+
from typing import Optional
816+
817+
def condition() -> bool: return False
818+
819+
i: Optional[int] = 0 if condition() else None
820+
x: Optional[tuple[int, int]] = (i, (i := 1)) # E: Incompatible types in assignment (expression has type "tuple[int | None, int]", variable has type "tuple[int, int] | None")
821+
[builtins fixtures/tuple.pyi]
822+
823+
[case testReturnTupleOptionalWalrus]
824+
from typing import Optional
825+
826+
def condition() -> bool: return False
827+
828+
def fn() -> tuple[int, int]:
829+
i: Optional[int] = 0 if condition() else None
830+
return (i, (i := i + 1)) # E: Incompatible return value type (got "tuple[int | None, int]", expected "tuple[int, int]") \
831+
# E: Unsupported operand types for + ("None" and "int") \
832+
# N: Left operand is of type "int | None"
833+
[builtins fixtures/dict.pyi]

0 commit comments

Comments
 (0)