Skip to content

feat[lang]: add ABIBuffer type #4561

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 27 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
d01d143
handle raw return
charles-cooper Apr 6, 2025
cd20250
propagate thru type system
charles-cooper Apr 6, 2025
970ce6b
fix convert
charles-cooper Apr 6, 2025
775f6d3
disallow literal ABIBuffers
charles-cooper Apr 6, 2025
2ec0b12
add a test
charles-cooper Apr 7, 2025
5848506
add invalid locations
charles-cooper Apr 7, 2025
a5fd707
ban in structs and tuples
charles-cooper Apr 7, 2025
8f7b50c
lint
charles-cooper Apr 8, 2025
066fde6
add abibuffer tests
cyberthirst Apr 9, 2025
30db056
add more abibuffer tests
cyberthirst Apr 9, 2025
82c18e9
Merge pull request #69 from cyberthirst/fork/charles-cooper/feat/lang…
charles-cooper Apr 9, 2025
f6dce65
handle extcalls for ABIBuffer
charles-cooper Apr 10, 2025
c0e2a8b
rename a test
charles-cooper Apr 11, 2025
3c140d8
rename ABIBuffer to ReturnBuffer
charles-cooper Apr 11, 2025
d2d7e46
allow returnbuffer in external call
charles-cooper Apr 11, 2025
ae6ad0e
lint
charles-cooper Apr 11, 2025
26a8dc2
ban in darrays
charles-cooper Apr 11, 2025
9e1d62a
strip from abi
charles-cooper Apr 11, 2025
9a61d9f
fix bad import
charles-cooper Apr 12, 2025
6982b8e
Merge branch 'master' into feat/lang/add-abi-buffer-type
charles-cooper Apr 12, 2025
2ad2088
update tests for new void output type
charles-cooper Apr 12, 2025
a5f203d
add returndatasize test
cyberthirst Apr 13, 2025
b63ba35
cleanup
cyberthirst Apr 13, 2025
88e0abf
add explanatory comment
cyberthirst Apr 13, 2025
6bbb1e1
Merge pull request #72 from cyberthirst/fork/charles-cooper/feat/lang…
charles-cooper Apr 13, 2025
4dcffc5
add convert rules from ReturnBufferT
charles-cooper Apr 14, 2025
a10e7e6
fix lint
charles-cooper Apr 15, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
203 changes: 203 additions & 0 deletions tests/functional/codegen/types/test_buffer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
from typing import Any

import pytest
from eth.codecs import abi

from tests.evm_backends.base_env import ExecutionReverted
from vyper.compiler import compile_code
from vyper.exceptions import InstantiationException, StructureException, TypeMismatch


# call, but don't abi decode the output
def _call_no_decode(contract_method: Any, *args, **kwargs) -> bytes:
contract = contract_method.contract
calldata = contract_method.prepare_calldata(*args, **kwargs)
output = contract.env.message_call(contract.address, data=calldata)

return output


def test_buffer(get_contract, tx_failed):
test_bytes = """
@external
def foo(x: Bytes[100]) -> ReturnBuffer[100]:
return convert(x, ReturnBuffer[100])
"""

c = get_contract(test_bytes)
return_data = b"cow"
moo_result = _call_no_decode(c.foo, return_data)
assert moo_result == return_data


def test_buffer_in_interface(get_contract, tx_failed):
caller_code = """
interface Foo:
def foo() -> ReturnBuffer[100]: view

@external
def foo(target: Foo) -> ReturnBuffer[100]:
return staticcall target.foo()
"""

return_data = abi.encode("(bytes)", (b"cow",))
target_code = f"""
@external
def foo() -> ReturnBuffer[100]:
return convert(x"{return_data.hex()}", ReturnBuffer[100])
"""
caller = get_contract(caller_code)
target = get_contract(target_code)

assert _call_no_decode(caller.foo, target.address) == return_data


def test_buffer_str_convert(get_contract):
test_bytes = """
@external
def foo(x: Bytes[100]) -> ReturnBuffer[100]:
return convert(convert(x, String[100]), ReturnBuffer[100])
"""

c = get_contract(test_bytes)
moo_result = _call_no_decode(c.foo, b"cow")
assert moo_result == b"cow"


def test_buffer_returndatasize_check(get_contract):
test_bytes = """
interface Foo:
def payload() -> ReturnBuffer[127]: view

interface FooSanity:
def payload() -> ReturnBuffer[128]: view

payload: public(Bytes[33])

@external
def set_payload(b: Bytes[33]):
self.payload = b

@external
def bar() -> ReturnBuffer[127]:
return staticcall Foo(self).payload()

@external
def sanity_check() -> ReturnBuffer[128]:
b: ReturnBuffer[128] = staticcall FooSanity(self).payload()
return b
"""

c = get_contract(test_bytes)
payload = b"a" * 33
c.set_payload(payload)
assert c.payload() == payload

res = _call_no_decode(c.sanity_check)

assert len(res) == 128
assert abi.decode("(bytes)", res) == (payload,)

# revert due to returndatasize being too big
# 32B head, 32B length, 32 bytes payload, 32 right-padded bytes payload
with pytest.raises(ExecutionReverted):
_call_no_decode(c.bar)


def test_buffer_no_subscriptable(get_contract, tx_failed):
code = """
@external
def foo(x: Bytes[128]) -> bytes8:
return convert(x, ReturnBuffer[128])[0]
"""

with pytest.raises(StructureException, match="Not an indexable type"):
compile_code(code)


def test_proxy_raw_return(get_contract):
impl1 = """
@external
def foo() -> String[32]:
return "Hello"
"""

impl2 = """
@external
def foo() -> Bytes[32]:
return b"Goodbye"
"""

impl3 = """
@external
def foo() -> DynArray[uint256, 2]:
#return [1, 2]
a: DynArray[uint256, 2] = [1, 2]
return a
"""

proxy = """
target: address
@external
def set_implementation(target: address):
self.target = target

@external
def foo() -> ReturnBuffer[128]:
data: Bytes[128] = raw_call(self.target, msg.data, is_delegate_call=True, max_outsize=128)
return convert(data, ReturnBuffer[128])
"""

impl_c1 = get_contract(impl1)
impl_c2 = get_contract(impl2)
impl_c3 = get_contract(impl3)

proxy_c = get_contract(proxy)

proxy_c.set_implementation(impl_c1.address)
res = _call_no_decode(proxy_c.foo)
assert abi.decode("(bytes)", res) == (b"Hello",)

proxy_c.set_implementation(impl_c2.address)
res = _call_no_decode(proxy_c.foo)
assert abi.decode("(string)", res) == ("Goodbye",)

proxy_c.set_implementation(impl_c3.address)
res = _call_no_decode(proxy_c.foo)
assert abi.decode("(uint256[])", res) == ([1, 2],)


fail_list = [
("b: ReturnBuffer[128]", InstantiationException),
(
"""b: immutable(ReturnBuffer[128])

@deploy
def __init__():
helper: Bytes[128] = b''
b = convert(helper, ReturnBuffer[128])
""",
InstantiationException,
),
(
"b: constant(ReturnBuffer[128]) = b''",
TypeMismatch,
), # type mismatch for now until we allow buffer literals
("b: transient(ReturnBuffer[128])", InstantiationException),
("b: DynArray[ReturnBuffer[128], 2]", StructureException),
(
"""
@external
def foo(b: ReturnBuffer[128]):
pass
""",
InstantiationException,
),
]


# TODO: move these to syntax tests
@pytest.mark.parametrize("bad_code,exc", fail_list)
def test_abibuffer_fail(bad_code, exc):
with pytest.raises(exc):
compile_code(bad_code)
14 changes: 12 additions & 2 deletions vyper/builtins/_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
DecimalT,
FlagT,
IntegerT,
ReturnBufferT,
StringT,
)
from vyper.semantics.types.bytestrings import _BytestringT
Expand Down Expand Up @@ -423,9 +424,11 @@ def to_address(expr, arg, out_typ):


def _cast_bytestring(expr, arg, out_typ):
# ban converting Bytes[20] to Bytes[21]
# ban converting Bytes[20] to Bytes[21], since that can be done
# by simple assignment.
if isinstance(arg.typ, out_typ.__class__) and arg.typ.maxlen <= out_typ.maxlen:
_FAIL(arg.typ, out_typ, expr)

# can't downcast literals with known length (e.g. b"abc" to Bytes[2])
if isinstance(expr, vy_ast.Constant) and arg.typ.maxlen > out_typ.maxlen:
_FAIL(arg.typ, out_typ, expr)
Expand All @@ -444,11 +447,16 @@ def to_string(expr, arg, out_typ):
return _cast_bytestring(expr, arg, out_typ)


@_input_types(StringT, BytesT)
@_input_types(StringT, BytesT, ReturnBufferT)
def to_bytes(expr, arg, out_typ):
return _cast_bytestring(expr, arg, out_typ)


@_input_types(StringT, BytesT, ReturnBufferT)
def to_return_buffer(expr, arg, out_typ):
return _cast_bytestring(expr, arg, out_typ)


@_input_types(IntegerT)
def to_flag(expr, arg, out_typ):
if arg.typ != UINT256_T:
Expand Down Expand Up @@ -488,6 +496,8 @@ def convert(expr, context):
ret = to_bytes(arg_ast, arg, out_typ)
elif isinstance(out_typ, StringT):
ret = to_string(arg_ast, arg, out_typ)
elif isinstance(out_typ, ReturnBufferT):
ret = to_return_buffer(arg_ast, arg, out_typ)
else:
raise StructureException(f"Conversion to {out_typ} is invalid.", arg_ast)

Expand Down
1 change: 1 addition & 0 deletions vyper/builtins/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2395,6 +2395,7 @@ def build_IR(self, expr, args, kwargs, context):

class ABIDecode(BuiltinFunctionT):
_id = "abi_decode"
# TODO: allow ReturnBuffer here
_inputs = [("data", BytesT.any()), ("output_type", TYPE_T.any())]
_kwargs = {"unwrap_tuple": KwargSettings(BoolT(), True, require_literal=True)}

Expand Down
5 changes: 5 additions & 0 deletions vyper/codegen/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
HashMapT,
IntegerT,
InterfaceT,
ReturnBufferT,
StructT,
TupleT,
_BytestringT,
Expand Down Expand Up @@ -720,6 +721,10 @@ def needs_external_call_wrap(typ):
# including structs. MyStruct is returned as abi-encoded (MyStruct,).
# (Sorry this is so confusing. I didn't make these rules.)

# special case for ReturnBuffer, which lives outside the abi:
if isinstance(typ, ReturnBufferT):
return False

return not (isinstance(typ, TupleT) and typ.length > 1)


Expand Down
12 changes: 11 additions & 1 deletion vyper/codegen/external_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
import vyper.utils as util
from vyper.codegen.abi_encoder import abi_encode
from vyper.codegen.core import (
STORE,
_freshname,
add_ofst,
bytes_data_ptr,
calculate_type_for_external_return,
check_assign,
check_external_call,
Expand All @@ -20,7 +22,7 @@
from vyper.codegen.ir_node import Encoding, IRnode
from vyper.evm.address_space import MEMORY
from vyper.exceptions import TypeCheckFailure
from vyper.semantics.types import InterfaceT, TupleT
from vyper.semantics.types import InterfaceT, ReturnBufferT, TupleT
from vyper.semantics.types.function import StateMutability


Expand Down Expand Up @@ -84,6 +86,14 @@ def _unpack_returndata(buf, fn_type, call_kwargs, contract_address, context, exp
if return_t is None:
return ["pass"], 0, 0

if isinstance(return_t, ReturnBufferT):
as_abibuf = copy.copy(buf)
as_abibuf.typ = return_t
check = ["assert", ["le", "returndatasize", return_t.length]]
copy_op = ["returndatacopy", bytes_data_ptr(as_abibuf), 0, "returndatasize"]
set_length = STORE(as_abibuf, "returndatasize")
return ["seq", check, copy_op, set_length, as_abibuf], None, return_t.length

wrapped_return_t = calculate_type_for_external_return(return_t)

abi_return_t = wrapped_return_t.abi_type
Expand Down
13 changes: 13 additions & 0 deletions vyper/codegen/return_.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
from vyper.codegen.abi_encoder import abi_encode, abi_encoding_matches_vyper
from vyper.codegen.context import Context
from vyper.codegen.core import (
bytes_data_ptr,
calculate_type_for_external_return,
check_assign,
dummy_node_for_type,
get_bytearray_length,
get_type_for_exact_size,
make_setter,
needs_clamp,
Expand All @@ -14,6 +16,7 @@
from vyper.codegen.ir_node import IRnode
from vyper.evm.address_space import MEMORY
from vyper.exceptions import TypeCheckFailure
from vyper.semantics.types import ReturnBufferT

Stmt = Any # mypy kludge

Expand Down Expand Up @@ -57,6 +60,16 @@ def finalize(fill_return_buffer):
return finalize(fill_return_buffer)

else: # return from external function
# raw return
if isinstance(context.return_type, ReturnBufferT):
# copy to memory
buf = context.new_internal_variable(context.return_type)
return_len = get_bytearray_length(buf)
return_offset = bytes_data_ptr(buf)
jump_to_exit += [return_offset, return_len] # type: ignore
fill_return_buffer = make_setter(buf, ir_val)
return finalize(fill_return_buffer)

external_return_type = calculate_type_for_external_return(context.return_type)
maxlen = external_return_type.abi_type.size_bound()

Expand Down
4 changes: 2 additions & 2 deletions vyper/semantics/analysis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from vyper.semantics.analysis.levenshtein_utils import get_levenshtein_error_suggestions
from vyper.semantics.namespace import get_namespace
from vyper.semantics.types.base import TYPE_T, VyperType
from vyper.semantics.types.bytestrings import BytesT, StringT
from vyper.semantics.types.bytestrings import _BytestringT

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
vyper.semantics.types.bytestrings
begins an import cycle.
from vyper.semantics.types.primitives import AddressT, BoolT, BytesM_T, IntegerT
from vyper.semantics.types.subscriptable import DArrayT, SArrayT, TupleT
from vyper.utils import OrderedSet, checksum_encode, int_to_fourbytes
Expand Down Expand Up @@ -309,7 +309,7 @@
# special handling for bytestrings since their
# class objects are in the type map, not the type itself
# (worth rethinking this design at some point.)
if t in (BytesT, StringT):
if isinstance(t, type) and issubclass(t, _BytestringT):
t = t.from_literal(node)

# any more validation which needs to occur
Expand Down
4 changes: 2 additions & 2 deletions vyper/semantics/types/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from . import primitives, subscriptable, user
from .base import TYPE_T, VOID_TYPE, KwargSettings, VyperType, is_type_t, map_void
from .bytestrings import BytesT, StringT, _BytestringT
from .bytestrings import BytesT, ReturnBufferT, StringT, _BytestringT
from .function import ContractFunctionT, MemberFunctionT
from .module import InterfaceT, ModuleT
from .primitives import AddressT, BoolT, BytesM_T, DecimalT, IntegerT, SelfT
Expand All @@ -20,7 +20,7 @@ def _get_primitive_types():

# note: since bytestrings are parametrizable, the *class* objects
# are in the namespace instead of concrete type objects.
res.extend([BytesT, StringT])
res.extend([BytesT, StringT, ReturnBufferT])

ret = {t._id: t for t in res}
ret.update(_get_sequence_types())
Expand Down
Loading