Skip to content

[mypyc] Fix exception swallowing in async try/finally blocks with await #19353

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 135 additions & 2 deletions mypyc/irbuild/statement.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from collections.abc import Sequence
from typing import Callable

import mypy.nodes
from mypy.nodes import (
ARG_NAMED,
ARG_POS,
Expand Down Expand Up @@ -101,6 +102,7 @@
get_exc_info_op,
get_exc_value_op,
keep_propagating_op,
no_err_occurred_op,
raise_exception_op,
reraise_exception_op,
restore_exc_info_op,
Expand Down Expand Up @@ -679,7 +681,7 @@ def try_finally_resolve_control(


def transform_try_finally_stmt(
builder: IRBuilder, try_body: GenFunc, finally_body: GenFunc
builder: IRBuilder, try_body: GenFunc, finally_body: GenFunc, line: int = -1
) -> None:
"""Generalized try/finally handling that takes functions to gen the bodies.

Expand Down Expand Up @@ -715,6 +717,118 @@ def transform_try_finally_stmt(
builder.activate_block(out_block)


def transform_try_finally_stmt_async(
builder: IRBuilder, try_body: GenFunc, finally_body: GenFunc, line: int = -1
) -> None:
"""Async-aware try/finally handling for when finally contains await.

This version uses a modified approach that preserves exceptions across await."""

# We need to handle returns properly, so we'll use TryFinallyNonlocalControl
# to track return values, similar to the regular try/finally implementation

err_handler, main_entry, return_entry, finally_entry = (
BasicBlock(),
BasicBlock(),
BasicBlock(),
BasicBlock(),
)

# Track if we're returning from the try block
control = TryFinallyNonlocalControl(return_entry)
builder.builder.push_error_handler(err_handler)
builder.nonlocal_control.append(control)
builder.goto_and_activate(BasicBlock())
try_body()
builder.goto(main_entry)
builder.nonlocal_control.pop()
builder.builder.pop_error_handler()
ret_reg = control.ret_reg

# Normal case - no exception or return
builder.activate_block(main_entry)
builder.goto(finally_entry)

# Return case
builder.activate_block(return_entry)
builder.goto(finally_entry)

# Exception case - need to catch to clear the error indicator
builder.activate_block(err_handler)
# Catch the error to clear Python's error indicator
builder.call_c(error_catch_op, [], line)
# We're not going to use old_exc since it won't survive await
# The exception is now in sys.exc_info()
builder.goto(finally_entry)

# Finally block
builder.activate_block(finally_entry)

# Execute finally body
finally_body()

# After finally, we need to handle exceptions carefully:
# 1. If finally raised a new exception, it's in the error indicator - let it propagate
# 2. If finally didn't raise, check if we need to reraise the original from sys.exc_info()
# 3. If there was a return, return that value
# 4. Otherwise, normal exit

# First, check if there's a current exception in the error indicator
# (this would be from the finally block)
no_current_exc = builder.call_c(no_err_occurred_op, [], line)
finally_raised = BasicBlock()
check_original = BasicBlock()
builder.add(Branch(no_current_exc, check_original, finally_raised, Branch.BOOL))

# Finally raised an exception - let it propagate naturally
builder.activate_block(finally_raised)
builder.call_c(keep_propagating_op, [], NO_TRACEBACK_LINE_NO)
builder.add(Unreachable())

# No exception from finally, check if we need to handle return or original exception
builder.activate_block(check_original)

# Check if we have a return value
if ret_reg:
return_block, check_old_exc = BasicBlock(), BasicBlock()
builder.add(Branch(builder.read(ret_reg), check_old_exc, return_block, Branch.IS_ERROR))

builder.activate_block(return_block)
builder.nonlocal_control[-1].gen_return(builder, builder.read(ret_reg), -1)

builder.activate_block(check_old_exc)

# Check if we need to reraise the original exception from sys.exc_info
exc_info = builder.call_c(get_exc_info_op, [], line)
exc_type = builder.add(TupleGet(exc_info, 0, line))

# Check if exc_type is None
none_obj = builder.none_object()
has_exc = builder.binary_op(exc_type, none_obj, "is not", line)

reraise_block, exit_block = BasicBlock(), BasicBlock()
builder.add(Branch(has_exc, reraise_block, exit_block, Branch.BOOL))

# Reraise the original exception
builder.activate_block(reraise_block)
builder.call_c(reraise_exception_op, [], NO_TRACEBACK_LINE_NO)
builder.add(Unreachable())

# Normal exit
builder.activate_block(exit_block)


# A simple visitor to detect await expressions
class AwaitDetector(mypy.traverser.TraverserVisitor):
def __init__(self) -> None:
super().__init__()
self.has_await = False

def visit_await_expr(self, o: mypy.nodes.AwaitExpr) -> None:
self.has_await = True
super().visit_await_expr(o)


def transform_try_stmt(builder: IRBuilder, t: TryStmt) -> None:
# Our compilation strategy for try/except/else/finally is to
# treat try/except/else and try/finally as separate language
Expand All @@ -723,6 +837,17 @@ def transform_try_stmt(builder: IRBuilder, t: TryStmt) -> None:
# body of a try/finally block.
if t.is_star:
builder.error("Exception groups and except* cannot be compiled yet", t.line)

# Check if we're in an async function with a finally block that contains await
use_async_version = False
if t.finally_body and builder.fn_info.is_coroutine:
detector = AwaitDetector()
t.finally_body.accept(detector)

if detector.has_await:
# Use the async version that handles exceptions correctly
use_async_version = True

if t.finally_body:

def transform_try_body() -> None:
Expand All @@ -733,7 +858,14 @@ def transform_try_body() -> None:

body = t.finally_body

transform_try_finally_stmt(builder, transform_try_body, lambda: builder.accept(body))
if use_async_version:
transform_try_finally_stmt_async(
builder, transform_try_body, lambda: builder.accept(body), t.line
)
else:
transform_try_finally_stmt(
builder, transform_try_body, lambda: builder.accept(body), t.line
)
else:
transform_try_except_stmt(builder, t)

Expand Down Expand Up @@ -824,6 +956,7 @@ def finally_body() -> None:
builder,
lambda: transform_try_except(builder, try_body, [(None, None, except_body)], None, line),
finally_body,
line,
)


Expand Down
Loading