From 7b0abac8462a472d69714e316cba40ab66da6511 Mon Sep 17 00:00:00 2001 From: Jorrit Jongma Date: Mon, 30 Jun 2025 12:02:28 +0200 Subject: [PATCH 1/3] [mypyc] Fix AttributeError in async try/finally with mixed return paths Async functions with try/finally blocks were raising AttributeError when: - Some paths in the try block return while others don't - The non-return path is executed at runtime - No further await calls are needed This occurred because mypyc's IR requires all control flow paths to assign to spill targets (temporary variables stored as generator attributes). The non-return path assigns NULL to maintain this invariant, but reading NULL attributes raises AttributeError in Python. Created a new IR operation `GetAttrNullable` that can read NULL attributes without raising AttributeError. This operation is used specifically in try/finally resolution when reading spill targets. - Added `GetAttrNullable` class to mypyc/ir/ops.py with error_kind=ERR_NEVER - Added `read_nullable_attr` method to IRBuilder for creating these operations - Modified `try_finally_resolve_control` in statement.py to use GetAttrNullable only for spill targets (attributes starting with '__mypyc_temp__') - Implemented C code generation in emitfunc.py that reads attributes without NULL checks and only increments reference count if not NULL - Added visitor implementations to all required files: - ir/pprint.py (pretty printing) - analysis/dataflow.py (dataflow analysis) - analysis/ircheck.py (IR validation) - analysis/selfleaks.py (self leak analysis) - transform/ir_transform.py (IR transformation) 1. **Separate operation vs flag**: Created a new operation instead of adding a flag to GetAttr for better performance - avoids runtime flag checks on every attribute access. 2. **Targeted fix**: Only applied to spill targets in try/finally resolution, not a general replacement for GetAttr. This minimizes risk and maintains existing behavior for all other attribute access. 3. **No initialization changes**: Initially tried initializing spill targets to Py_None instead of NULL, but this would incorrectly make try/finally blocks return None instead of falling through to subsequent code. Added two test cases to mypyc/test-data/run-async.test: 1. **testAsyncTryFinallyMixedReturn**: Tests the basic issue with async try/finally blocks containing mixed return/non-return paths. 2. **testAsyncWithMixedReturn**: Tests async with statements (which use try/finally under the hood) to ensure the fix works for this common pattern as well. Both tests verify that the AttributeError no longer occurs when taking the non-return path through the try block. See mypyc/mypyc#1115 --- mypyc/analysis/dataflow.py | 4 + mypyc/analysis/ircheck.py | 5 + mypyc/analysis/selfleaks.py | 8 + mypyc/codegen/emitfunc.py | 19 ++ mypyc/ir/ops.py | 33 ++++ mypyc/ir/pprint.py | 4 + mypyc/irbuild/builder.py | 10 ++ mypyc/irbuild/statement.py | 10 +- mypyc/test-data/run-async.test | 303 ++++++++++++++++++++++++++++++++ mypyc/transform/ir_transform.py | 7 + 10 files changed, 401 insertions(+), 2 deletions(-) diff --git a/mypyc/analysis/dataflow.py b/mypyc/analysis/dataflow.py index db62ef1700fa..affa7d63e887 100644 --- a/mypyc/analysis/dataflow.py +++ b/mypyc/analysis/dataflow.py @@ -24,6 +24,7 @@ FloatNeg, FloatOp, GetAttr, + GetAttrNullable, GetElementPtr, Goto, IncRef, @@ -209,6 +210,9 @@ def visit_load_literal(self, op: LoadLiteral) -> GenAndKill[T]: def visit_get_attr(self, op: GetAttr) -> GenAndKill[T]: return self.visit_register_op(op) + def visit_get_attr_nullable(self, op: GetAttrNullable) -> GenAndKill[T]: + return self.visit_register_op(op) + def visit_set_attr(self, op: SetAttr) -> GenAndKill[T]: return self.visit_register_op(op) diff --git a/mypyc/analysis/ircheck.py b/mypyc/analysis/ircheck.py index 88737ac208de..d5acf5c1e27e 100644 --- a/mypyc/analysis/ircheck.py +++ b/mypyc/analysis/ircheck.py @@ -21,6 +21,7 @@ FloatNeg, FloatOp, GetAttr, + GetAttrNullable, GetElementPtr, Goto, IncRef, @@ -319,6 +320,10 @@ def visit_get_attr(self, op: GetAttr) -> None: # Nothing to do. pass + def visit_get_attr_nullable(self, op: GetAttrNullable) -> None: + # Nothing to do. + pass + def visit_set_attr(self, op: SetAttr) -> None: # Nothing to do. pass diff --git a/mypyc/analysis/selfleaks.py b/mypyc/analysis/selfleaks.py index 4d3a7c87c5d1..600a2f64c1f2 100644 --- a/mypyc/analysis/selfleaks.py +++ b/mypyc/analysis/selfleaks.py @@ -16,6 +16,7 @@ FloatNeg, FloatOp, GetAttr, + GetAttrNullable, GetElementPtr, Goto, InitStatic, @@ -114,6 +115,13 @@ def visit_get_attr(self, op: GetAttr) -> GenAndKill: return self.check_register_op(op) return CLEAN + def visit_get_attr_nullable(self, op: GetAttrNullable) -> GenAndKill: + cl = op.class_type.class_ir + if cl.get_method(op.attr): + # Property -- calls a function + return self.check_register_op(op) + return CLEAN + def visit_set_attr(self, op: SetAttr) -> GenAndKill: cl = op.class_type.class_ir if cl.get_method(op.attr): diff --git a/mypyc/codegen/emitfunc.py b/mypyc/codegen/emitfunc.py index c854516825af..cd1be7562f92 100644 --- a/mypyc/codegen/emitfunc.py +++ b/mypyc/codegen/emitfunc.py @@ -40,6 +40,7 @@ FloatNeg, FloatOp, GetAttr, + GetAttrNullable, GetElementPtr, Goto, IncRef, @@ -426,6 +427,24 @@ def visit_get_attr(self, op: GetAttr) -> None: elif not always_defined: self.emitter.emit_line("}") + def visit_get_attr_nullable(self, op: GetAttrNullable) -> None: + """Handle GetAttrNullable which allows NULL without raising AttributeError.""" + dest = self.reg(op) + obj = self.reg(op.obj) + rtype = op.class_type + cl = rtype.class_ir + attr_rtype, decl_cl = cl.attr_details(op.attr) + + # Direct struct access without NULL check + attr_expr = self.get_attr_expr(obj, op, decl_cl) + self.emitter.emit_line(f"{dest} = {attr_expr};") + + # Only emit inc_ref if not NULL + if attr_rtype.is_refcounted and not op.is_borrowed: + self.emitter.emit_line(f"if ({dest} != NULL) {{") + self.emitter.emit_inc_ref(dest, attr_rtype) + self.emitter.emit_line("}") + def next_branch(self) -> Branch | None: if self.op_index + 1 < len(self.ops): next_op = self.ops[self.op_index + 1] diff --git a/mypyc/ir/ops.py b/mypyc/ir/ops.py index eec9c34a965e..651c5a96a2b8 100644 --- a/mypyc/ir/ops.py +++ b/mypyc/ir/ops.py @@ -799,6 +799,35 @@ def accept(self, visitor: OpVisitor[T]) -> T: return visitor.visit_get_attr(self) +class GetAttrNullable(RegisterOp): + """obj.attr (for a native object) - allows NULL without raising AttributeError + + This is used for spill targets where NULL indicates the non-return path was taken. + Unlike GetAttr, this won't raise AttributeError when the attribute is NULL. + """ + + error_kind = ERR_NEVER + + def __init__(self, obj: Value, attr: str, line: int, *, borrow: bool = False) -> None: + super().__init__(line) + self.obj = obj + self.attr = attr + assert isinstance(obj.type, RInstance), "Attribute access not supported: %s" % obj.type + self.class_type = obj.type + attr_type = obj.type.attr_type(attr) + self.type = attr_type + self.is_borrowed = borrow and attr_type.is_refcounted + + def sources(self) -> list[Value]: + return [self.obj] + + def set_sources(self, new: list[Value]) -> None: + (self.obj,) = new + + def accept(self, visitor: OpVisitor[T]) -> T: + return visitor.visit_get_attr_nullable(self) + + class SetAttr(RegisterOp): """obj.attr = src (for a native object) @@ -1728,6 +1757,10 @@ def visit_load_literal(self, op: LoadLiteral) -> T: def visit_get_attr(self, op: GetAttr) -> T: raise NotImplementedError + @abstractmethod + def visit_get_attr_nullable(self, op: GetAttrNullable) -> T: + raise NotImplementedError + @abstractmethod def visit_set_attr(self, op: SetAttr) -> T: raise NotImplementedError diff --git a/mypyc/ir/pprint.py b/mypyc/ir/pprint.py index 6c96a21e473b..3d18fe5288c5 100644 --- a/mypyc/ir/pprint.py +++ b/mypyc/ir/pprint.py @@ -28,6 +28,7 @@ FloatNeg, FloatOp, GetAttr, + GetAttrNullable, GetElementPtr, Goto, IncRef, @@ -128,6 +129,9 @@ def visit_load_literal(self, op: LoadLiteral) -> str: def visit_get_attr(self, op: GetAttr) -> str: return self.format("%r = %s%r.%s", op, self.borrow_prefix(op), op.obj, op.attr) + def visit_get_attr_nullable(self, op: GetAttrNullable) -> str: + return self.format("%r = %s%r.%s?", op, self.borrow_prefix(op), op.obj, op.attr) + def borrow_prefix(self, op: Op) -> str: if op.is_borrowed: return "borrow " diff --git a/mypyc/irbuild/builder.py b/mypyc/irbuild/builder.py index 75e059a5b570..58d6d5756d89 100644 --- a/mypyc/irbuild/builder.py +++ b/mypyc/irbuild/builder.py @@ -72,6 +72,7 @@ Branch, ComparisonOp, GetAttr, + GetAttrNullable, InitStatic, Integer, IntOp, @@ -708,6 +709,15 @@ def read( assert False, "Unsupported lvalue: %r" % target + def read_nullable_attr(self, obj: Value, attr: str, line: int = -1) -> Value: + """Read an attribute that might be NULL without raising AttributeError. + + This is used for reading spill targets in try/finally blocks where NULL + indicates the non-return path was taken. + """ + assert isinstance(obj.type, RInstance) and obj.type.class_ir.is_ext_class + return self.add(GetAttrNullable(obj, attr, line)) + def assign(self, target: Register | AssignmentTarget, rvalue_reg: Value, line: int) -> None: if isinstance(target, Register): self.add(Assign(target, self.coerce_rvalue(rvalue_reg, target.type, line))) diff --git a/mypyc/irbuild/statement.py b/mypyc/irbuild/statement.py index 16a0483a8729..5c32d8f1a50c 100644 --- a/mypyc/irbuild/statement.py +++ b/mypyc/irbuild/statement.py @@ -46,6 +46,7 @@ YieldExpr, YieldFromExpr, ) +from mypyc.common import TEMP_ATTR_NAME from mypyc.ir.ops import ( NAMESPACE_MODULE, NO_TRACEBACK_LINE_NO, @@ -653,10 +654,15 @@ def try_finally_resolve_control( if ret_reg: builder.activate_block(rest) return_block, rest = BasicBlock(), BasicBlock() - builder.add(Branch(builder.read(ret_reg), rest, return_block, Branch.IS_ERROR)) + # For spill targets in try/finally, use nullable read to avoid AttributeError + if isinstance(ret_reg, AssignmentTargetAttr) and ret_reg.attr.startswith(TEMP_ATTR_NAME): + ret_val = builder.read_nullable_attr(ret_reg.obj, ret_reg.attr, -1) + else: + ret_val = builder.read(ret_reg) + builder.add(Branch(ret_val, rest, return_block, Branch.IS_ERROR)) builder.activate_block(return_block) - builder.nonlocal_control[-1].gen_return(builder, builder.read(ret_reg), -1) + builder.nonlocal_control[-1].gen_return(builder, ret_val, -1) # TODO: handle break/continue builder.activate_block(rest) diff --git a/mypyc/test-data/run-async.test b/mypyc/test-data/run-async.test index 11ce67077270..2dad720f99cd 100644 --- a/mypyc/test-data/run-async.test +++ b/mypyc/test-data/run-async.test @@ -643,3 +643,306 @@ def test_async_def_contains_two_nested_functions() -> None: [file asyncio/__init__.pyi] def run(x: object) -> object: ... + +[case testAsyncTryFinallyMixedReturn] +# This used to raise an AttributeError, when: +# - the try block contains multiple paths +# - at least one of those explicitly returns +# - at least one of those does not explicitly return +# - the non-returning path is taken at runtime + +import asyncio + + +async def test_mixed_return(b: bool) -> bool: + try: + if b: + return b + finally: + pass + return b + + +async def test_run() -> None: + # Test return path + result1 = await test_mixed_return(True) + assert result1 == True + + # Test non-return path + result2 = await test_mixed_return(False) + assert result2 == False + + +def test_async_try_finally_mixed_return() -> None: + asyncio.run(test_run()) + +[file driver.py] +from native import test_async_try_finally_mixed_return +test_async_try_finally_mixed_return() + +[file asyncio/__init__.pyi] +def run(x: object) -> object: ... + +[case testAsyncWithMixedReturn] +# This used to raise an AttributeError, related to +# testAsyncTryFinallyMixedReturn, this is essentially +# a far more extensive version of that test surfacing +# more edge cases + +import asyncio +from typing import Optional, Type, Literal + + +class AsyncContextManager: + async def __aenter__(self) -> "AsyncContextManager": + return self + + async def __aexit__( + self, + t: Optional[Type[BaseException]], + v: Optional[BaseException], + tb: object, + ) -> Literal[False]: + return False + + +# Simple async functions (generator class) +async def test_gen_1(b: bool) -> bool: + async with AsyncContextManager(): + if b: + return b + return b + + +async def test_gen_2(b: bool) -> bool: + async with AsyncContextManager(): + if b: + return b + else: + return b + + +async def test_gen_3(b: bool) -> bool: + async with AsyncContextManager(): + if b: + return b + else: + pass + return b + + +async def test_gen_4(b: bool) -> bool: + ret: bool + async with AsyncContextManager(): + if b: + ret = b + else: + ret = b + return ret + + +async def test_gen_5(i: int) -> int: + async with AsyncContextManager(): + if i == 1: + return i + elif i == 2: + pass + elif i == 3: + return i + return i + + +async def test_gen_6(i: int) -> int: + async with AsyncContextManager(): + if i == 1: + return i + elif i == 2: + return i + elif i == 3: + return i + return i + + +async def test_gen_7(i: int) -> int: + async with AsyncContextManager(): + if i == 1: + return i + elif i == 2: + return i + elif i == 3: + return i + else: + return i + + +# Async functions with nested functions (environment class) +async def test_env_1(b: bool) -> bool: + def helper() -> bool: + return True + + async with AsyncContextManager(): + if b: + return helper() + return b + + +async def test_env_2(b: bool) -> bool: + def helper() -> bool: + return True + + async with AsyncContextManager(): + if b: + return helper() + else: + return b + + +async def test_env_3(b: bool) -> bool: + def helper() -> bool: + return True + + async with AsyncContextManager(): + if b: + return helper() + else: + pass + return b + + +async def test_env_4(b: bool) -> bool: + def helper() -> bool: + return True + + ret: bool + async with AsyncContextManager(): + if b: + ret = helper() + else: + ret = b + return ret + + +async def test_env_5(i: int) -> int: + def helper() -> int: + return 1 + + async with AsyncContextManager(): + if i == 1: + return helper() + elif i == 2: + pass + elif i == 3: + return i + return i + + +async def test_env_6(i: int) -> int: + def helper() -> int: + return 1 + + async with AsyncContextManager(): + if i == 1: + return helper() + elif i == 2: + return i + elif i == 3: + return i + return i + + +async def test_env_7(i: int) -> int: + def helper() -> int: + return 1 + + async with AsyncContextManager(): + if i == 1: + return helper() + elif i == 2: + return i + elif i == 3: + return i + else: + return i + + +async def run_all_tests() -> None: + # Test simple async functions (generator class) + # test_env_1: mixed return/no-return + assert await test_gen_1(True) is True + assert await test_gen_1(False) is False + + # test_gen_2: all branches return + assert await test_gen_2(True) is True + assert await test_gen_2(False) is False + + # test_gen_3: mixed return/pass + assert await test_gen_3(True) is True + assert await test_gen_3(False) is False + + # test_gen_4: no returns in async with + assert await test_gen_4(True) is True + assert await test_gen_4(False) is False + + # test_gen_5: multiple branches, some return + assert await test_gen_5(0) == 0 + assert await test_gen_5(1) == 1 + assert await test_gen_5(2) == 2 + assert await test_gen_5(3) == 3 + + # test_gen_6: all explicit branches return, implicit fallthrough + assert await test_gen_6(0) == 0 + assert await test_gen_6(1) == 1 + assert await test_gen_6(2) == 2 + assert await test_gen_6(3) == 3 + + # test_gen_7: all branches return including else + assert await test_gen_7(0) == 0 + assert await test_gen_7(1) == 1 + assert await test_gen_7(2) == 2 + assert await test_gen_7(3) == 3 + + # Test async functions with nested functions (environment class) + # test_env_1: mixed return/no-return + assert await test_env_1(True) is True + assert await test_env_1(False) is False + + # test_env_2: all branches return + assert await test_env_2(True) is True + assert await test_env_2(False) is False + + # test_env_3: mixed return/pass + assert await test_env_3(True) is True + assert await test_env_3(False) is False + + # test_env_4: no returns in async with + assert await test_env_4(True) is True + assert await test_env_4(False) is False + + # test_env_5: multiple branches, some return + assert await test_env_5(0) == 0 + assert await test_env_5(1) == 1 + assert await test_env_5(2) == 2 + assert await test_env_5(3) == 3 + + # test_env_6: all explicit branches return, implicit fallthrough + assert await test_env_6(0) == 0 + assert await test_env_6(1) == 1 + assert await test_env_6(2) == 2 + assert await test_env_6(3) == 3 + + # test_env_7: all branches return including else + assert await test_env_7(0) == 0 + assert await test_env_7(1) == 1 + assert await test_env_7(2) == 2 + assert await test_env_7(3) == 3 + + +def test_async_with_mixed_return() -> None: + asyncio.run(run_all_tests()) + +[file driver.py] +from native import test_async_with_mixed_return +test_async_with_mixed_return() + +[file asyncio/__init__.pyi] +def run(x: object) -> object: ... diff --git a/mypyc/transform/ir_transform.py b/mypyc/transform/ir_transform.py index 326a5baca1e7..5c8406b2cf65 100644 --- a/mypyc/transform/ir_transform.py +++ b/mypyc/transform/ir_transform.py @@ -20,6 +20,7 @@ FloatNeg, FloatOp, GetAttr, + GetAttrNullable, GetElementPtr, Goto, IncRef, @@ -136,6 +137,9 @@ def visit_load_literal(self, op: LoadLiteral) -> Value | None: def visit_get_attr(self, op: GetAttr) -> Value | None: return self.add(op) + def visit_get_attr_nullable(self, op: GetAttrNullable) -> Value | None: + return self.add(op) + def visit_set_attr(self, op: SetAttr) -> Value | None: return self.add(op) @@ -268,6 +272,9 @@ def visit_load_literal(self, op: LoadLiteral) -> None: def visit_get_attr(self, op: GetAttr) -> None: op.obj = self.fix_op(op.obj) + def visit_get_attr_nullable(self, op: GetAttrNullable) -> None: + op.obj = self.fix_op(op.obj) + def visit_set_attr(self, op: SetAttr) -> None: op.obj = self.fix_op(op.obj) op.src = self.fix_op(op.src) From db7d4546d0d423dd9adecb8ca94e4395725c5699 Mon Sep 17 00:00:00 2001 From: Jorrit Jongma Date: Mon, 30 Jun 2025 12:18:57 +0200 Subject: [PATCH 2/3] Make GetAttrNullable a subclass of GetAttr --- mypyc/ir/ops.py | 19 ++++--------------- 1 file changed, 4 insertions(+), 15 deletions(-) diff --git a/mypyc/ir/ops.py b/mypyc/ir/ops.py index 651c5a96a2b8..50605dcefddf 100644 --- a/mypyc/ir/ops.py +++ b/mypyc/ir/ops.py @@ -799,7 +799,7 @@ def accept(self, visitor: OpVisitor[T]) -> T: return visitor.visit_get_attr(self) -class GetAttrNullable(RegisterOp): +class GetAttrNullable(GetAttr): """obj.attr (for a native object) - allows NULL without raising AttributeError This is used for spill targets where NULL indicates the non-return path was taken. @@ -809,20 +809,9 @@ class GetAttrNullable(RegisterOp): error_kind = ERR_NEVER def __init__(self, obj: Value, attr: str, line: int, *, borrow: bool = False) -> None: - super().__init__(line) - self.obj = obj - self.attr = attr - assert isinstance(obj.type, RInstance), "Attribute access not supported: %s" % obj.type - self.class_type = obj.type - attr_type = obj.type.attr_type(attr) - self.type = attr_type - self.is_borrowed = borrow and attr_type.is_refcounted - - def sources(self) -> list[Value]: - return [self.obj] - - def set_sources(self, new: list[Value]) -> None: - (self.obj,) = new + super().__init__(obj, attr, line, borrow=borrow) + # Override error_kind since GetAttr sets it based on attr_type.error_overlap + self.error_kind = ERR_NEVER def accept(self, visitor: OpVisitor[T]) -> T: return visitor.visit_get_attr_nullable(self) From 7601eebae8bd0fdde7a4d7a238fa4a0b793a3c81 Mon Sep 17 00:00:00 2001 From: Jorrit Jongma Date: Tue, 1 Jul 2025 21:11:10 +0200 Subject: [PATCH 3/3] Remove GetAttrNullable, use allow_null flag on GetAttr instead --- mypyc/analysis/dataflow.py | 4 ---- mypyc/analysis/ircheck.py | 5 ----- mypyc/analysis/selfleaks.py | 8 -------- mypyc/codegen/emitfunc.py | 8 +++++--- mypyc/ir/ops.py | 31 +++++++------------------------ mypyc/ir/pprint.py | 4 ---- mypyc/irbuild/builder.py | 3 +-- mypyc/transform/ir_transform.py | 7 ------- 8 files changed, 13 insertions(+), 57 deletions(-) diff --git a/mypyc/analysis/dataflow.py b/mypyc/analysis/dataflow.py index affa7d63e887..db62ef1700fa 100644 --- a/mypyc/analysis/dataflow.py +++ b/mypyc/analysis/dataflow.py @@ -24,7 +24,6 @@ FloatNeg, FloatOp, GetAttr, - GetAttrNullable, GetElementPtr, Goto, IncRef, @@ -210,9 +209,6 @@ def visit_load_literal(self, op: LoadLiteral) -> GenAndKill[T]: def visit_get_attr(self, op: GetAttr) -> GenAndKill[T]: return self.visit_register_op(op) - def visit_get_attr_nullable(self, op: GetAttrNullable) -> GenAndKill[T]: - return self.visit_register_op(op) - def visit_set_attr(self, op: SetAttr) -> GenAndKill[T]: return self.visit_register_op(op) diff --git a/mypyc/analysis/ircheck.py b/mypyc/analysis/ircheck.py index d5acf5c1e27e..88737ac208de 100644 --- a/mypyc/analysis/ircheck.py +++ b/mypyc/analysis/ircheck.py @@ -21,7 +21,6 @@ FloatNeg, FloatOp, GetAttr, - GetAttrNullable, GetElementPtr, Goto, IncRef, @@ -320,10 +319,6 @@ def visit_get_attr(self, op: GetAttr) -> None: # Nothing to do. pass - def visit_get_attr_nullable(self, op: GetAttrNullable) -> None: - # Nothing to do. - pass - def visit_set_attr(self, op: SetAttr) -> None: # Nothing to do. pass diff --git a/mypyc/analysis/selfleaks.py b/mypyc/analysis/selfleaks.py index 600a2f64c1f2..4d3a7c87c5d1 100644 --- a/mypyc/analysis/selfleaks.py +++ b/mypyc/analysis/selfleaks.py @@ -16,7 +16,6 @@ FloatNeg, FloatOp, GetAttr, - GetAttrNullable, GetElementPtr, Goto, InitStatic, @@ -115,13 +114,6 @@ def visit_get_attr(self, op: GetAttr) -> GenAndKill: return self.check_register_op(op) return CLEAN - def visit_get_attr_nullable(self, op: GetAttrNullable) -> GenAndKill: - cl = op.class_type.class_ir - if cl.get_method(op.attr): - # Property -- calls a function - return self.check_register_op(op) - return CLEAN - def visit_set_attr(self, op: SetAttr) -> GenAndKill: cl = op.class_type.class_ir if cl.get_method(op.attr): diff --git a/mypyc/codegen/emitfunc.py b/mypyc/codegen/emitfunc.py index cd1be7562f92..00c7fd56b899 100644 --- a/mypyc/codegen/emitfunc.py +++ b/mypyc/codegen/emitfunc.py @@ -40,7 +40,6 @@ FloatNeg, FloatOp, GetAttr, - GetAttrNullable, GetElementPtr, Goto, IncRef, @@ -359,6 +358,9 @@ def get_attr_expr(self, obj: str, op: GetAttr | SetAttr, decl_cl: ClassIR) -> st return f"({cast}{obj})->{self.emitter.attr(op.attr)}" def visit_get_attr(self, op: GetAttr) -> None: + if op.allow_null: + self.get_attr_with_allow_null(op) + return dest = self.reg(op) obj = self.reg(op.obj) rtype = op.class_type @@ -427,8 +429,8 @@ def visit_get_attr(self, op: GetAttr) -> None: elif not always_defined: self.emitter.emit_line("}") - def visit_get_attr_nullable(self, op: GetAttrNullable) -> None: - """Handle GetAttrNullable which allows NULL without raising AttributeError.""" + def get_attr_with_allow_null(self, op: GetAttr) -> None: + """Handle GetAttr with allow_null=True which allows NULL without raising AttributeError.""" dest = self.reg(op) obj = self.reg(op.obj) rtype = op.class_type diff --git a/mypyc/ir/ops.py b/mypyc/ir/ops.py index 50605dcefddf..9dde658231d8 100644 --- a/mypyc/ir/ops.py +++ b/mypyc/ir/ops.py @@ -777,15 +777,20 @@ class GetAttr(RegisterOp): error_kind = ERR_MAGIC - def __init__(self, obj: Value, attr: str, line: int, *, borrow: bool = False) -> None: + def __init__( + self, obj: Value, attr: str, line: int, *, borrow: bool = False, allow_null: bool = False + ) -> None: super().__init__(line) self.obj = obj self.attr = attr + self.allow_null = allow_null assert isinstance(obj.type, RInstance), "Attribute access not supported: %s" % obj.type self.class_type = obj.type attr_type = obj.type.attr_type(attr) self.type = attr_type - if attr_type.error_overlap: + if allow_null: + self.error_kind = ERR_NEVER + elif attr_type.error_overlap: self.error_kind = ERR_MAGIC_OVERLAPPING self.is_borrowed = borrow and attr_type.is_refcounted @@ -799,24 +804,6 @@ def accept(self, visitor: OpVisitor[T]) -> T: return visitor.visit_get_attr(self) -class GetAttrNullable(GetAttr): - """obj.attr (for a native object) - allows NULL without raising AttributeError - - This is used for spill targets where NULL indicates the non-return path was taken. - Unlike GetAttr, this won't raise AttributeError when the attribute is NULL. - """ - - error_kind = ERR_NEVER - - def __init__(self, obj: Value, attr: str, line: int, *, borrow: bool = False) -> None: - super().__init__(obj, attr, line, borrow=borrow) - # Override error_kind since GetAttr sets it based on attr_type.error_overlap - self.error_kind = ERR_NEVER - - def accept(self, visitor: OpVisitor[T]) -> T: - return visitor.visit_get_attr_nullable(self) - - class SetAttr(RegisterOp): """obj.attr = src (for a native object) @@ -1746,10 +1733,6 @@ def visit_load_literal(self, op: LoadLiteral) -> T: def visit_get_attr(self, op: GetAttr) -> T: raise NotImplementedError - @abstractmethod - def visit_get_attr_nullable(self, op: GetAttrNullable) -> T: - raise NotImplementedError - @abstractmethod def visit_set_attr(self, op: SetAttr) -> T: raise NotImplementedError diff --git a/mypyc/ir/pprint.py b/mypyc/ir/pprint.py index 3d18fe5288c5..6c96a21e473b 100644 --- a/mypyc/ir/pprint.py +++ b/mypyc/ir/pprint.py @@ -28,7 +28,6 @@ FloatNeg, FloatOp, GetAttr, - GetAttrNullable, GetElementPtr, Goto, IncRef, @@ -129,9 +128,6 @@ def visit_load_literal(self, op: LoadLiteral) -> str: def visit_get_attr(self, op: GetAttr) -> str: return self.format("%r = %s%r.%s", op, self.borrow_prefix(op), op.obj, op.attr) - def visit_get_attr_nullable(self, op: GetAttrNullable) -> str: - return self.format("%r = %s%r.%s?", op, self.borrow_prefix(op), op.obj, op.attr) - def borrow_prefix(self, op: Op) -> str: if op.is_borrowed: return "borrow " diff --git a/mypyc/irbuild/builder.py b/mypyc/irbuild/builder.py index 58d6d5756d89..878c5e76df3d 100644 --- a/mypyc/irbuild/builder.py +++ b/mypyc/irbuild/builder.py @@ -72,7 +72,6 @@ Branch, ComparisonOp, GetAttr, - GetAttrNullable, InitStatic, Integer, IntOp, @@ -716,7 +715,7 @@ def read_nullable_attr(self, obj: Value, attr: str, line: int = -1) -> Value: indicates the non-return path was taken. """ assert isinstance(obj.type, RInstance) and obj.type.class_ir.is_ext_class - return self.add(GetAttrNullable(obj, attr, line)) + return self.add(GetAttr(obj, attr, line, allow_null=True)) def assign(self, target: Register | AssignmentTarget, rvalue_reg: Value, line: int) -> None: if isinstance(target, Register): diff --git a/mypyc/transform/ir_transform.py b/mypyc/transform/ir_transform.py index 5c8406b2cf65..326a5baca1e7 100644 --- a/mypyc/transform/ir_transform.py +++ b/mypyc/transform/ir_transform.py @@ -20,7 +20,6 @@ FloatNeg, FloatOp, GetAttr, - GetAttrNullable, GetElementPtr, Goto, IncRef, @@ -137,9 +136,6 @@ def visit_load_literal(self, op: LoadLiteral) -> Value | None: def visit_get_attr(self, op: GetAttr) -> Value | None: return self.add(op) - def visit_get_attr_nullable(self, op: GetAttrNullable) -> Value | None: - return self.add(op) - def visit_set_attr(self, op: SetAttr) -> Value | None: return self.add(op) @@ -272,9 +268,6 @@ def visit_load_literal(self, op: LoadLiteral) -> None: def visit_get_attr(self, op: GetAttr) -> None: op.obj = self.fix_op(op.obj) - def visit_get_attr_nullable(self, op: GetAttrNullable) -> None: - op.obj = self.fix_op(op.obj) - def visit_set_attr(self, op: SetAttr) -> None: op.obj = self.fix_op(op.obj) op.src = self.fix_op(op.src)