Skip to content

feat(specs): [EXAMPLE] Static EOF Test Loader #1346

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 12 commits into from
1 change: 1 addition & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ addopts =
-p pytest_plugins.filler.pre_alloc
-p pytest_plugins.solc.solc
-p pytest_plugins.filler.filler
-p pytest_plugins.refiller.refiller
-p pytest_plugins.shared.execute_fill
-p pytest_plugins.forks.forks
-p pytest_plugins.spec_version_checker.spec_version_checker
Expand Down
4 changes: 4 additions & 0 deletions src/ethereum_test_specs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import List, Type

from .base import BaseTest, TestSpec
from .base_static import BaseStaticTest
from .blockchain import (
BlockchainTest,
BlockchainTestFiller,
Expand All @@ -16,6 +17,7 @@
EOFTestFiller,
EOFTestSpec,
)
from .eof_static import EOFStaticTest
from .state import StateTest, StateTestFiller, StateTestSpec
from .transaction import TransactionTest, TransactionTestFiller, TransactionTestSpec

Expand All @@ -30,6 +32,7 @@

__all__ = (
"SPEC_TYPES",
"BaseStaticTest",
"BaseTest",
"BlockchainTest",
"BlockchainTestEngineFiller",
Expand All @@ -39,6 +42,7 @@
"EOFStateTest",
"EOFStateTestFiller",
"EOFStateTestSpec",
"EOFStaticTest",
"EOFTest",
"EOFTestFiller",
"EOFTestSpec",
Expand Down
229 changes: 229 additions & 0 deletions src/ethereum_test_specs/base_static.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
"""Base class to parse test cases written in static formats."""

import re
from abc import abstractmethod
from typing import Any, Callable, ClassVar, Dict, List, Tuple, Type, Union

from pydantic import (
BaseModel,
ConfigDict,
TypeAdapter,
ValidatorFunctionWrapHandler,
model_validator,
)

from ethereum_test_base_types import Bytes
from ethereum_test_forks import Fork, get_forks


class BaseStaticTest(BaseModel):
"""Represents a base class that reads cases from static files."""

formats: ClassVar[List[Type["BaseStaticTest"]]] = []
formats_type_adapter: ClassVar[TypeAdapter]

format_name: ClassVar[str] = ""

@classmethod
def __pydantic_init_subclass__(cls, **kwargs):
"""
Register all subclasses of BaseStaticTest with a static test format name set
as possible static test format.
"""
if cls.format_name:
# Register the new fixture format
BaseStaticTest.formats.append(cls)
if len(BaseStaticTest.formats) > 1:
BaseStaticTest.formats_type_adapter = TypeAdapter(
Union[tuple(BaseStaticTest.formats)],
)
else:
BaseStaticTest.formats_type_adapter = TypeAdapter(cls)

@model_validator(mode="wrap")
@classmethod
def _parse_into_subclass(
cls, v: Any, handler: ValidatorFunctionWrapHandler
) -> "BaseStaticTest":
"""Parse the static test into the correct subclass."""
if cls is BaseStaticTest:
return BaseStaticTest.formats_type_adapter.validate_python(v)
return handler(v)

@abstractmethod
def fill_function(self) -> Callable:
"""
Return the test function that can be used to fill the test.

This method should be implemented by the subclasses.

The function returned can be optionally decorated with the `@pytest.mark.parametrize`
decorator to parametrize the test with the number of sub test cases.

Example:
```
@pytest.mark.parametrize("n", [1])
@pytest.mark.parametrize("m", [1, 2])
@pytest.mark.valid_from("Homestead")
def test_state_filler(
state_test: StateTestFiller,
fork: Fork,
pre: Alloc,
n: int,
m: int,
):
\"\"\"Generate a test from a static state filler.\"\"\"
assert n == 1
assert m in [1, 2]
env = Environment(**self.env.model_dump())
sender = pre.fund_eoa()
tx = Transaction(
ty=0x0,
nonce=0,
to=Address(0x1000),
gas_limit=500000,
protected=False if fork in [Frontier, Homestead] else True,
data="",
sender=sender,
)
state_test(env=env, pre=pre, post={}, tx=tx)

return test_state_filler
```

To aid the generation of the test, the function can be defined and then the decorator be
applied after defining the function:

```
def test_state_filler(
state_test: StateTestFiller,
fork: Fork,
pre: Alloc,
n: int,
m: int,
):
...
test_state_filler = pytest.mark.parametrize("n", [1])(test_state_filler)
test_state_filler = pytest.mark.parametrize("m", [1, 2])(test_state_filler)
if self.valid_from:
test_state_filler = pytest.mark.valid_from(self.valid_from)(test_state_filler)
if self.valid_until:
test_state_filler = pytest.mark.valid_until(self.valid_until)(test_state_filler)
return test_state_filler
```

The function can contain the following parameters on top of the spec type parameter
(`state_test` in the example above):
- `fork`: The fork for which the test is currently being filled.
- `pre`: The pre-state of the test.

"""
raise NotImplementedError

@staticmethod
def remove_comments(data: Dict) -> Dict:
"""Remove comments from a dictionary."""
result = {}
for k, v in data.items():
if isinstance(k, str) and k.startswith("//"):
continue
if isinstance(v, dict):
v = BaseStaticTest.remove_comments(v)
elif isinstance(v, list):
v = [BaseStaticTest.remove_comments(i) if isinstance(i, dict) else i for i in v]
result[k] = v
return result

@model_validator(mode="before")
@classmethod
def remove_comments_from_model(cls, data: Any) -> Any:
"""Remove comments from the static file loaded, if any."""
if isinstance(data, dict):
return BaseStaticTest.remove_comments(data)
return data


ALL_FORKS = get_forks()


def fork_by_name(fork_name: str) -> Fork:
"""Get a fork by name."""
for fork in ALL_FORKS:
if fork.name() == fork_name:
return fork

raise Exception(f'Fork "{fork_name}" could not be identified.')


class ForkRangeDescriptor(BaseModel):
"""Fork descriptor parsed from string normally contained in ethereum/tests fillers."""

greater_equal: Fork | None = None
less_than: Fork | None = None
model_config = ConfigDict(frozen=True)

def fork_in_range(self, fork: Fork) -> bool:
"""Return whether the given fork is within range."""
if self.greater_equal is not None and fork < self.greater_equal:
return False
if self.less_than is not None and fork >= self.less_than:
return False
return True

@model_validator(mode="wrap")
@classmethod
def validate_fork_range_descriptor(cls, v: Any, handler: ValidatorFunctionWrapHandler):
"""
Validate the fork range descriptor from a string.

Examples:
- ">=Osaka" validates to {greater_equal=Osaka, less_than=None}
- ">=Prague<Osaka" validates to {greater_equal=Prague, less_than=Osaka}

"""
if isinstance(v, str):
# Decompose the string into its parts
descriptor_string = re.sub(r"\s+", "", v.strip())
v = {}
if m := re.search(r">=(\w+)", descriptor_string):
fork = fork_by_name(m.group(1))
v["greater_equal"] = fork
descriptor_string = re.sub(r">=(\w+)", "", descriptor_string)
if m := re.search(r"<(\w+)", descriptor_string):
fork = fork_by_name(m.group(1))
v["less_than"] = fork
descriptor_string = re.sub(r"<(\w+)", "", descriptor_string)
if descriptor_string:
raise Exception(
"Unable to completely parse fork range descriptor. "
+ f'Remaining string: "{descriptor_string}"'
)
return handler(v)


def remove_comments(v: str) -> str:
"""
Split by line and then remove the comments (starting with #) at the end of each line if
any.
"""
return "\n".join([line.split("#")[0].strip() for line in v.splitlines()])


label_matcher = re.compile(r"^:label\s+(\S+)\s*", re.MULTILINE)
raw_matcher = re.compile(r":raw\s+(.*)", re.MULTILINE)


def labeled_bytes_from_string(v: str) -> Tuple[str | None, Bytes]:
"""Parse `:label` and `:raw` from a string."""
v = remove_comments(v)

label: str | None = None
if m := label_matcher.search(v):
label = m.group(1)
v = label_matcher.sub("", v)

m = raw_matcher.match(v.replace("\n", " "))
if not m:
raise Exception(f"Unable to parse container from string: {v}")
strip_string = m.group(1).strip()
return label, Bytes(strip_string)
81 changes: 81 additions & 0 deletions src/ethereum_test_specs/eof_static.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
"""Ethereum EOF static test spec parser."""

from typing import Annotated, Callable, ClassVar, Dict, List

import pytest
from pydantic import BeforeValidator, Field

from ethereum_test_base_types import CamelModel
from ethereum_test_exceptions.exceptions import EOFExceptionInstanceOrList
from ethereum_test_forks import Fork
from ethereum_test_types.eof.v1 import Container

from .base_static import BaseStaticTest, ForkRangeDescriptor, labeled_bytes_from_string
from .eof import EOFTestFiller


class Info(CamelModel):
"""Information about the test contained in the static file."""

comment: str | None = None


def container_from_string(v: str) -> Container:
"""Parse a container string."""
label, raw_bytes = labeled_bytes_from_string(v)
return Container(
name=label,
raw_bytes=raw_bytes,
)


class Vector(CamelModel):
"""Single vector contained in an EOF filler static test."""

data: Annotated[Container, BeforeValidator(container_from_string)]
expect_exception: Dict[ForkRangeDescriptor, EOFExceptionInstanceOrList] | None = None


class EOFStaticTest(BaseStaticTest):
"""EOF static filler from ethereum/tests."""

info: Info = Field(..., alias="_info")

forks: List[ForkRangeDescriptor]
vectors: List[Vector]

format_name: ClassVar[str] = "eof_test"

def fill_function(self) -> Callable:
"""Return a EOF spec from a static file."""

@pytest.mark.parametrize(
"vector",
self.vectors,
ids=lambda c: c.data.name,
)
def test_eof_vectors(
eof_test: EOFTestFiller,
fork: Fork,
vector: Vector,
):
expect_exception: EOFExceptionInstanceOrList | None = None
if vector.expect_exception is not None:
for fork_range, exception in vector.expect_exception.items():
if fork_range.fork_in_range(fork):
expect_exception = exception
break
return eof_test(
container=vector.data,
expect_exception=expect_exception,
)

assert len(self.forks) <= 1, "Multiple fork elements is not supported"
forks = self.forks[0]

if forks.greater_equal is not None:
test_eof_vectors = pytest.mark.valid_from(str(forks.greater_equal))(test_eof_vectors)
if forks.less_than is not None:
test_eof_vectors = pytest.mark.valid_until(str(forks.greater_equal))(test_eof_vectors)

return test_eof_vectors
32 changes: 32 additions & 0 deletions src/ethereum_test_specs/tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@

from ethereum_test_base_types import Address, Bloom, Bytes, Hash, HeaderNonce
from ethereum_test_fixtures.blockchain import FixtureHeader
from ethereum_test_forks import Osaka, Prague

from ..base_static import ForkRangeDescriptor
from ..blockchain import Header

fixture_header_ones = FixtureHeader(
Expand Down Expand Up @@ -129,3 +131,33 @@ def test_fixture_header_join(
):
"""Test that the join method works as expected."""
assert modifier.apply(fixture_header) == fixture_header_expected


@pytest.mark.parametrize(
"fork_range_descriptor_string,expected_fork_range_descriptor",
[
(
">=Osaka",
ForkRangeDescriptor(
greater_equal=Osaka,
less_than=None,
),
),
(
">= Prague < Osaka",
ForkRangeDescriptor(
greater_equal=Prague,
less_than=Osaka,
),
),
],
)
def test_parsing_fork_range_descriptor_from_string(
fork_range_descriptor_string: str,
expected_fork_range_descriptor: ForkRangeDescriptor,
):
"""Test multiple strings used as fork range descriptors in ethereum/tests."""
assert (
ForkRangeDescriptor.model_validate(fork_range_descriptor_string)
== expected_fork_range_descriptor
)
Loading