Skip to content

fix[tool]: fix output bundle construction #4654

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
Open
11 changes: 11 additions & 0 deletions tests/unit/ast/test_ast_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,17 @@ def foo() -> uint256:
}


def test_import_builtin_ast():
code = """
from ethereum.ercs import IERC20
import math
"""
dict_out = compiler.compile_code(code, output_formats=["ast_dict"], source_id=0)["ast_dict"]
imports = dict_out["imports"]
import_paths = [import_dict["path"] for import_dict in imports]
assert import_paths == ["vyper/builtins/interfaces/IERC20.vyi", "vyper/builtins/stdlib/math.vy"]


def test_dict_to_ast():
code = """
@external
Expand Down
50 changes: 48 additions & 2 deletions tests/unit/cli/vyper_compile/test_compile_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from vyper.compiler.input_bundle import FilesystemInputBundle
from vyper.compiler.output_bundle import OutputBundle
from vyper.compiler.phases import CompilerData
from vyper.exceptions import TypeMismatch
from vyper.utils import sha256sum

TAMPERED_INTEGRITY_SUM = sha256sum("tampered integrity sum")
Expand Down Expand Up @@ -390,7 +391,7 @@ def test_archive_output(input_files):
assert archive_compiler_data.integrity_sum is not None

assert len(w) == 1, [s.message for s in w]
assert str(w[0].message).startswith(INTEGRITY_WARNING.format(integrity=integrity))
assert w[0].message.message.startswith(INTEGRITY_WARNING.format(integrity=integrity))


def test_archive_b64_output(input_files):
Expand Down Expand Up @@ -545,7 +546,52 @@ def test_solc_json_output(input_files):

w = warn_data[Path("contract.vy")]
assert len(w) == 1, [s.message for s in w]
assert str(w[0].message).startswith(INTEGRITY_WARNING.format(integrity=integrity))
assert w[0].message.message.startswith(INTEGRITY_WARNING.format(integrity=integrity))


# test that we can construct output bundles even when there is a semantic error
# TODO: maybe move this to tests/unit/compiler/
def test_output_bundle_semantic_error(make_file, chdir_tmp_path):
library_source = """
@internal
def foo() -> uint256:
return block.number + b"asldkjf" # semantic error
"""
contract_source = """
import lib

a: uint256
b: uint256

@external
def foo() -> uint256:
return lib.foo()
"""
_ = make_file("lib.vy", library_source)
contract_file = make_file("main.vy", contract_source)

with warnings.catch_warnings(record=True) as w:
s = compile_files([contract_file], ["archive"])

assert len(w) == 1
expected_warning = (
"Exceptions encountered during code generation (but producing archive anyway)"
)
assert expected_warning in w[0].message.message

archive_bytes = s[contract_file]["archive"]

archive_path = Path("foo.zip")
with archive_path.open("wb") as f:
f.write(archive_bytes)

assert zipfile.is_zipfile(archive_path)

# compare compiling the two input bundles
with pytest.raises(TypeMismatch, match="Cannot perform addition between dislike types") as e:
_ = compile_files([archive_path], ["integrity", "bytecode", "layout"])

assert e.value.message in w[0].message.message


# maybe this belongs in tests/unit/compiler?
Expand Down
4 changes: 4 additions & 0 deletions vyper/compiler/input_bundle.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@
def from_builtin(self):
return self.source_id == BUILTIN

# fast hash which doesn't require looking at the contents
def __hash__(self):
return hash((self.source_id, self.path, self.resolved_path))

Check warning on line 44 in vyper/compiler/input_bundle.py

View check run for this annotation

Codecov / codecov/patch

vyper/compiler/input_bundle.py#L44

Added line #L44 was not covered by tests


@dataclass(frozen=True)
class FileInput(CompilerInput):
Expand Down
43 changes: 18 additions & 25 deletions vyper/compiler/output.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import base64
from collections import deque
from pathlib import PurePath
from typing import Iterable

import vyper.ast as vy_ast
from vyper.ast.utils import ast_to_dict
Expand All @@ -11,49 +12,41 @@
from vyper.evm import opcodes
from vyper.exceptions import VyperException
from vyper.ir import compile_ir
from vyper.semantics.analysis.base import ModuleInfo
from vyper.semantics.types.function import ContractFunctionT, FunctionVisibility, StateMutability
from vyper.semantics.types.module import InterfaceT
from vyper.typing import StorageLayout
from vyper.utils import safe_relpath
from vyper.warnings import ContractSizeLimit, vyper_warn


def _get_reachable_imports(compiler_data: CompilerData) -> Iterable[vy_ast.Module]:
import_analysis = compiler_data.resolved_imports

# get all reachable imports including recursion
# (NOTE: does not include imported json interfaces.)
imported_modules = list(import_analysis.compiler_inputs.values())
imported_modules = [mod for mod in imported_modules if isinstance(mod, vy_ast.Module)]
if import_analysis.toplevel_module in imported_modules:
imported_modules.remove(import_analysis.toplevel_module)

Check warning on line 29 in vyper/compiler/output.py

View check run for this annotation

Codecov / codecov/patch

vyper/compiler/output.py#L29

Added line #L29 was not covered by tests

return imported_modules


def build_ast_dict(compiler_data: CompilerData) -> dict:
imported_modules = _get_reachable_imports(compiler_data)
ast_dict = {
"contract_name": str(compiler_data.contract_path),
"ast": ast_to_dict(compiler_data.vyper_module),
"imports": [ast_to_dict(ast) for ast in imported_modules],
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This misses JSON interfaces; maybe we should leave a comment or raise an issue.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I'll leave a comment.

}
return ast_dict


def build_annotated_ast_dict(compiler_data: CompilerData) -> dict:
module_t = compiler_data.annotated_vyper_module._metadata["type"]
# get all reachable imports including recursion
imported_module_infos = module_t.reachable_imports
unique_modules: dict[str, vy_ast.Module] = {}
for info in imported_module_infos:
if isinstance(info.typ, InterfaceT):
ast = info.typ.decl_node
if ast is None: # json abi
continue
else:
assert isinstance(info.typ, ModuleInfo)
ast = info.typ.module_t._module

assert isinstance(ast, vy_ast.Module) # help mypy
# use resolved_path for uniqueness, since Module objects can actually
# come from multiple InputBundles (particularly builtin interfaces),
# so source_id is not guaranteed to be unique.
if ast.resolved_path in unique_modules:
# sanity check -- objects must be identical
assert unique_modules[ast.resolved_path] is ast
unique_modules[ast.resolved_path] = ast

imported_modules = _get_reachable_imports(compiler_data)
annotated_ast_dict = {
"contract_name": str(compiler_data.contract_path),
"ast": ast_to_dict(compiler_data.annotated_vyper_module),
"imports": [ast_to_dict(ast) for ast in unique_modules.values()],
"imports": [ast_to_dict(ast) for ast in imported_modules],
}
return annotated_ast_dict

Expand Down
14 changes: 7 additions & 7 deletions vyper/compiler/output_bundle.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,15 @@ def __init__(self, compiler_data: CompilerData):
def compilation_target(self):
return self.compiler_data.compilation_target._metadata["type"]

@cached_property
def _imports(self):
return self.compilation_target.reachable_imports

@cached_property
def compiler_inputs(self) -> dict[str, CompilerInput]:
inputs: list[CompilerInput] = [
t.compiler_input for t in self._imports if not t.compiler_input.from_builtin
]
import_analysis = self.compiler_data.resolved_imports

inputs: list[CompilerInput] = import_analysis.compiler_inputs.copy()
inputs = [inp for inp in inputs if not inp.from_builtin]

# file input for the top level module; it's not in
# import_analysis._compiler_inputs
inputs.append(self.compiler_data.file_input)

sources = {}
Expand Down
3 changes: 0 additions & 3 deletions vyper/compiler/phases.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,6 @@ def __init__(
self.input_bundle = input_bundle or FilesystemInputBundle([Path(".")])
self.expected_integrity_sum = integrity_sum

# ast cache, hitchhike onto the input_bundle object
self.input_bundle._cache._ast_of: dict[int, vy_ast.Module] = {} # type: ignore

@cached_property
def source_code(self):
return self.file_input.source_code
Expand Down
37 changes: 24 additions & 13 deletions vyper/semantics/analysis/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
tag_exceptions,
)
from vyper.semantics.analysis.base import ImportInfo
from vyper.utils import safe_relpath, sha256sum
from vyper.utils import OrderedSet, safe_relpath, sha256sum

"""
collect import statements and validate the import graph.
Expand Down Expand Up @@ -74,21 +74,34 @@ def enter_path(self, module_ast: vy_ast.Module) -> Iterator[None]:


class ImportAnalyzer:
def __init__(self, input_bundle: InputBundle, graph: _ImportGraph):
seen: OrderedSet[vy_ast.Module]
_compiler_inputs: dict[CompilerInput, vy_ast.Module]
toplevel_module: vy_ast.Module

def __init__(self, input_bundle: InputBundle, graph: _ImportGraph, module_ast: vy_ast.Module):
self.input_bundle = input_bundle
self.graph = graph
self.toplevel_module = module_ast
self._ast_of: dict[int, vy_ast.Module] = {}

self.seen: set[vy_ast.Module] = set()
self.seen = OrderedSet()

# keep around compiler inputs so when we construct the output
# bundle, we have access to the compiler input for each module
self._compiler_inputs = {}

self._integrity_sum = None

# should be all system paths + topmost module path
self.absolute_search_paths = input_bundle.search_paths.copy()

def resolve_imports(self, module_ast: vy_ast.Module):
self._resolve_imports_r(module_ast)
self._integrity_sum = self._calculate_integrity_sum_r(module_ast)
def resolve_imports(self):
self._resolve_imports_r(self.toplevel_module)
self._integrity_sum = self._calculate_integrity_sum_r(self.toplevel_module)

@property
def compiler_inputs(self) -> dict[CompilerInput, vy_ast.Module]:
return self._compiler_inputs

def _calculate_integrity_sum_r(self, module_ast: vy_ast.Module):
acc = [sha256sum(module_ast.full_source_code)]
Expand Down Expand Up @@ -152,6 +165,7 @@ def _add_import(
self, node: vy_ast.VyperNode, level: int, qualified_module_name: str, alias: str
) -> None:
compiler_input, ast = self._load_import(node, level, qualified_module_name, alias)
self._compiler_inputs[compiler_input] = ast
node._metadata["import_info"] = ImportInfo(
alias, qualified_module_name, compiler_input, ast
)
Expand Down Expand Up @@ -180,7 +194,7 @@ def _load_import(
assert isinstance(file, FileInput) # mypy hint

module_ast = self._ast_from_file(file)
self.resolve_imports(module_ast)
self._resolve_imports_r(module_ast)

return file, module_ast

Expand All @@ -193,10 +207,7 @@ def _load_import(
file = self._load_file(path.with_suffix(".vyi"), level)
assert isinstance(file, FileInput) # mypy hint
module_ast = self._ast_from_file(file)
self.resolve_imports(module_ast)

# language does not yet allow recursion for vyi files
# self.resolve_imports(module_ast)
self._resolve_imports_r(module_ast)

return file, module_ast

Expand Down Expand Up @@ -351,7 +362,7 @@ def _load_builtin_import(level: int, module_str: str) -> tuple[CompilerInput, vy

def resolve_imports(module_ast: vy_ast.Module, input_bundle: InputBundle):
graph = _ImportGraph()
analyzer = ImportAnalyzer(input_bundle, graph)
analyzer.resolve_imports(module_ast)
analyzer = ImportAnalyzer(input_bundle, graph, module_ast)
analyzer.resolve_imports()

return analyzer
20 changes: 1 addition & 19 deletions vyper/semantics/types/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from vyper.utils import OrderedSet

if TYPE_CHECKING:
from vyper.semantics.analysis.base import ImportInfo, ModuleInfo
from vyper.semantics.analysis.base import ModuleInfo

Check warning on line 28 in vyper/semantics/types/module.py

View check run for this annotation

Codecov / codecov/patch

vyper/semantics/types/module.py#L28

Added line #L28 was not covered by tests

Check failure

Code scanning / CodeQL

Module-level cyclic import Error

'ModuleInfo' may not be defined if module
vyper.semantics.analysis.base
is imported before module
vyper.semantics.types.module
, as the
definition
of ModuleInfo occurs after the cyclic
import
of vyper.semantics.types.module.


class InterfaceT(_UserType):
Expand Down Expand Up @@ -445,24 +445,6 @@
ret[info.alias] = module_info
return ret

@cached_property
def reachable_imports(self) -> list["ImportInfo"]:
"""
Return (recursively) reachable imports from this module as a list in
depth-first (descendants-first) order.
"""
ret = []
for s in self.import_stmts:
info = s._metadata["import_info"]

# NOTE: this needs to be redone if interfaces can import other interfaces
if not isinstance(info.typ, InterfaceT):
ret.extend(info.typ.typ.reachable_imports)

ret.append(info)

return ret

def find_module_info(self, needle: "ModuleT") -> Optional["ModuleInfo"]:
for s in self.imported_modules.values():
if s.module_t == needle:
Expand Down
Loading