Skip to content

feat[lang]!: ban nonreentrant on internal functions #4573

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 0 additions & 42 deletions tests/functional/codegen/modules/test_nonreentrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,45 +34,3 @@ def __default__():
assert c.foo() == 5
with tx_failed():
c.re_enter()


def test_internal_nonreentrant(make_input_bundle, get_contract, tx_failed):
lib1 = """
interface Foo:
def foo() -> uint256: nonpayable

implements: Foo

@external
def foo() -> uint256:
return self._safe_fn()

@internal
@nonreentrant
def _safe_fn() -> uint256:
return 10
"""
main = """
import lib1

initializes: lib1

exports: lib1.foo

@external
@nonreentrant
def re_enter():
extcall lib1.Foo(self).foo() # should always throw

@external
def __default__():
# sanity: make sure we don't revert due to bad selector
pass
"""

input_bundle = make_input_bundle({"lib1.vy": lib1})

c = get_contract(main, input_bundle=input_bundle)
assert c.foo() == 10
with tx_failed():
c.re_enter()
53 changes: 4 additions & 49 deletions tests/functional/syntax/modules/test_initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1306,50 +1306,23 @@ def nonreentrant_library_bundle(make_input_bundle):
# test simple case
lib1 = """
# lib1.vy
@internal
@nonreentrant
def bar():
pass

# lib1.vy
@external
@nonreentrant
def ext_bar():
pass
"""
# test case with recursion
# test case with exports
lib2 = """
@internal
def bar():
self.baz()

@external
def ext_bar():
self.baz()

@nonreentrant
@internal
def baz():
return
"""
# test case with nested recursion
lib3 = """
import lib1
uses: lib1

@internal
def bar():
lib1.bar()

@external
def ext_bar():
lib1.bar()
exports: lib1.ext_bar
"""

return make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2, "lib3.vy": lib3})
return make_input_bundle({"lib1.vy": lib1, "lib2.vy": lib2})


@pytest.mark.parametrize("lib", ("lib1", "lib2", "lib3"))
@pytest.mark.parametrize("lib", ("lib1", "lib2"))
def test_nonreentrant_exports(nonreentrant_library_bundle, lib):
main = f"""
import {lib}
Expand All @@ -1368,24 +1341,6 @@ def foo():
assert e.value.annotations[0].lineno == 4


@pytest.mark.parametrize("lib", ("lib1", "lib2", "lib3"))
def test_internal_nonreentrant_import(nonreentrant_library_bundle, lib):
main = f"""
import {lib}

@external
def foo():
{lib}.bar() # line 6
"""
with pytest.raises(ImmutableViolation) as e:
compile_code(main, input_bundle=nonreentrant_library_bundle)
assert e.value._message == f"Cannot access `{lib}` state!" + NONREENTRANT_NOTE

hint = f"add `uses: {lib}` or `initializes: {lib}` as a top-level statement to your contract"
assert e.value._hint == hint
assert e.value.annotations[0].lineno == 6


def test_global_initialize_missed_import_hint(make_input_bundle, chdir_tmp_path):
lib1 = """
import lib2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,25 @@ def __init__():
""",
FunctionDeclarationException,
),
(
# nonreentrant on an internal function
"""
@nonreentrant
def foo():
pass
""",
FunctionDeclarationException,
),
(
# nonreentrant on an internal function
"""
@internal
@nonreentrant
def foo():
pass
""",
FunctionDeclarationException,
),
]


Expand Down
3 changes: 0 additions & 3 deletions tests/unit/cli/storage_layout/test_storage_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@ def public_foo1():
def public_foo2():
pass


@internal
@nonreentrant
def _bar():
pass

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ def public_foo2():


@internal
@nonreentrant
def _bar():
pass

Expand Down
2 changes: 2 additions & 0 deletions vyper/codegen/function_definitions/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,8 @@ def get_nonreentrant_lock(func_t):
if not func_t.nonreentrant:
return ["pass"], ["pass"]

assert func_t.is_external

nkey = func_t.reentrancy_key_position.position

LOAD, STORE = "sload", "sstore"
Expand Down
7 changes: 4 additions & 3 deletions vyper/semantics/types/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,9 +435,10 @@ def from_FunctionDef(cls, funcdef: vy_ast.FunctionDef) -> "ContractFunctionT":
raise FunctionDeclarationException(
"Constructor may not use default arguments", funcdef.args.defaults[0]
)
if decorators.nonreentrant:
msg = "`@nonreentrant` decorator disallowed on `__init__`"
raise FunctionDeclarationException(msg, decorators.nonreentrant_node)

if function_visibility != FunctionVisibility.EXTERNAL and decorators.nonreentrant:
msg = f"`@nonreentrant` decorator disallowed on {function_visibility} functions!"
raise FunctionDeclarationException(msg, decorators.nonreentrant_node)

return cls(
funcdef.name,
Expand Down