Skip to content

Commit

Permalink
fix: gas profile when non-VyperContracts exist in the environment (vy…
Browse files Browse the repository at this point in the history
…perlang#278)

* fix: gas profile when non-VyperContracts exist in the environment
* fix another spot
* another small fix
* recurse into black box contracts
* fix recursion for profiling
* remove unused member `_coverage_data`
* refactor profiler: make profile data global

rather than scoped to a particular env, which breaks if the user changes
env (e.g. via `swap_env`) in tests

* fix some small bugs
* simplify LineProfile.by_line
  • Loading branch information
charles-cooper authored Aug 28, 2024
1 parent 88f9bed commit 83ef3ac
Show file tree
Hide file tree
Showing 7 changed files with 86 additions and 54 deletions.
3 changes: 3 additions & 0 deletions boa/contracts/base_evm_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ class _BaseEVMContract:
This includes ABI and Vyper contract.
"""

# flag to signal whether this contract can be line profiled
_can_line_profile = False

def __init__(
self,
env: Optional[Env] = None,
Expand Down
14 changes: 3 additions & 11 deletions boa/contracts/vyper/vyper_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
from boa.contracts.vyper.event import Event, RawEvent
from boa.contracts.vyper.ir_executor import executor_from_ir
from boa.environment import Env
from boa.profiling import LineProfile, cache_gas_used_for_computation
from boa.profiling import cache_gas_used_for_computation
from boa.util.abi import Address, abi_decode, abi_encode
from boa.util.lrudict import lrudict
from boa.vm.gas_meters import ProfilingGasMeter
Expand Down Expand Up @@ -501,6 +501,8 @@ def __repr__(self):


class VyperContract(_BaseVyperContract):
_can_line_profile = True

def __init__(
self,
compiler_data: CompilerData,
Expand Down Expand Up @@ -798,16 +800,6 @@ def stack_trace(self, computation=None):
return ret
return _handle_child_trace(computation, self.env, ret)

def line_profile(self, computation=None):
computation = computation or self._computation
ret = LineProfile.from_single(self, computation)
for child in computation.children:
child_obj = self.env.lookup_contract(child.msg.code_address)
# TODO: child obj is opaque contract that calls back into known contract
if child_obj is not None:
ret.merge(child_obj.line_profile(child))
return ret

def ensure_id(self, fn_t): # mimic vyper.codegen.module.IDGenerator api
if fn_t._function_id is None:
fn_t._function_id = self._function_id
Expand Down
5 changes: 0 additions & 5 deletions boa/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,6 @@ def __init__(self, fork_try_prefetch_state=False, fast_mode_enabled=False):
self.sha3_trace: dict = {}
self.sstore_trace: dict = {}

self._profiled_contracts = {}
self._cached_call_profiles = {}
self._cached_line_profiles = {}
self._coverage_data = {}

self._gas_tracker = 0

self.evm = PyEVM(self, fast_mode_enabled, fork_try_prefetch_state)
Expand Down
96 changes: 68 additions & 28 deletions boa/profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from rich.table import Table

from boa.contracts.vyper.ast_utils import get_fn_name_from_lineno, get_line
from boa.environment import Env


def _safe_relpath(path):
Expand Down Expand Up @@ -145,17 +144,21 @@ def by_pc(self):
def by_line(self):
ret = {}
source_map = self.contract.source_map["pc_raw_ast_map"]
current_line = None
seen = set()
for pc in self.computation.code._trace:
if (node := source_map.get(pc)) is not None:
current_line = node.lineno
if pc in seen:
# TODO: this is a kludge, it prevents lines from being
# over-represented when they appear in loops. but we should
# probably either not have this guard, or actually count
# the number of times a line is hit per- computation.
continue
if (node := source_map.get(pc)) is None:
continue

# NOTE: do we still need the `current_line is not None` guard?
if current_line is not None and pc not in seen:
ret.setdefault(current_line, Datum())
ret[current_line].merge(self.by_pc[pc])
seen.add(pc)
current_line = node.lineno
ret.setdefault(current_line, Datum()).merge(self.by_pc[pc])

seen.add(pc)

return ret

Expand Down Expand Up @@ -225,6 +228,30 @@ def get_line_data(self):
return line_gas_data


# singleton profile object which collects gas+line profiles over test runs
class GlobalProfile:
_singleton = None

def __init__(self):
self.profiled_contracts = {}
self.call_profiles = {}
self.line_profiles = {}

@classmethod
def get_singleton(cls):
if cls._singleton is None:
cls._singleton = cls()
return cls._singleton

@classmethod
def clear_singleton(cls):
cls._singleton = None


def global_profile():
return GlobalProfile.get_singleton()


# stupid class whose __str__ method doesn't escape (good for repl)
class _String(str):
def __repr__(self):
Expand All @@ -233,21 +260,41 @@ def __repr__(self):

# cache gas_used for all computation (including children)
def cache_gas_used_for_computation(contract, computation):
profile = contract.line_profile(computation)
env = contract.env

def _recurse(computation):
# recursion for child computations
for _computation in computation.children:
child_contract = env.lookup_contract(_computation.msg.code_address)

if child_contract is None:
# for black box contracts, we don't profile the contract,
# but we recurse into the subcalls
_recurse(_computation)
else:
cache_gas_used_for_computation(child_contract, _computation)

if not getattr(contract, "_can_line_profile", False):
_recurse(computation)
return

profile = LineProfile.from_single(contract, computation)
contract_path = contract.compiler_data.contract_path

# -------------------- CACHE CALL PROFILE --------------------
# get gas used. We use Datum().net_gas here instead of Datum().net_tot_gas
# because a call's profile includes children call costs.
# There will be double counting, but that is by choice.
#
# TODO: make it user configurable / present it to the user,
# similar to cProfile tottime vs cumtime.

sum_net_gas = sum([i.net_gas for i in profile.profile.values()])
sum_net_tot_gas = sum([i.net_tot_gas for i in profile.profile.values()])

fn = contract._get_fn_from_computation(computation)
if fn is None:
fn_name = "unnamed"
fn_name = "<none>"
else:
fn_name = fn.name

Expand All @@ -257,29 +304,22 @@ def cache_gas_used_for_computation(contract, computation):
fn_name=fn_name,
)

env._cached_call_profiles.setdefault(fn, CallGasStats()).merge_gas_data(
global_profile().call_profiles.setdefault(fn, CallGasStats()).merge_gas_data(
sum_net_gas, sum_net_tot_gas
)

s = env._profiled_contracts.setdefault(fn.address, [])
if fn not in env._profiled_contracts[fn.address]:
s = global_profile().profiled_contracts.setdefault(fn.address, [])
if fn not in s:
s.append(fn)

# -------------------- CACHE LINE PROFILE --------------------
line_profile = profile.get_line_data()

for line, gas_used in line_profile.items():
env._cached_line_profiles.setdefault(line, []).append(gas_used)
global_profile().line_profiles.setdefault(line, []).append(gas_used)

# ------------------------- RECURSION -------------------------

# recursion for child computations
for _computation in computation.children:
child_contract = env.lookup_contract(_computation.msg.code_address)

# ignore black box contracts
if child_contract is not None:
cache_gas_used_for_computation(child_contract, _computation)
_recurse(computation)


def _create_table(for_line_profile: bool = False) -> Table:
Expand Down Expand Up @@ -307,11 +347,11 @@ def _create_table(for_line_profile: bool = False) -> Table:
return table


def get_call_profile_table(env: Env) -> Table:
def get_call_profile_table() -> Table:
table = _create_table()

cache = env._cached_call_profiles
cached_contracts = env._profiled_contracts
cache = global_profile().call_profiles
cached_contracts = global_profile().profiled_contracts
contract_vs_median_gas = []
for profile in cache:
cache[profile].compute_stats()
Expand Down Expand Up @@ -357,9 +397,9 @@ def get_call_profile_table(env: Env) -> Table:
return table


def get_line_profile_table(env: Env) -> Table:
def get_line_profile_table() -> Table:
contracts: dict = {}
for lp, gas_data in env._cached_line_profiles.items():
for lp, gas_data in global_profile().line_profiles.items():
contract_uid = (lp.contract_path, lp.address)

# add spaces so numbers take up equal space
Expand Down
8 changes: 4 additions & 4 deletions boa/test/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest

import boa
from boa.profiling import get_call_profile_table, get_line_profile_table
from boa.profiling import get_call_profile_table, get_line_profile_table, global_profile
from boa.vm.gas_meters import ProfilingGasMeter

# monkey patch HypothesisHandle. this fixes underlying isolation for
Expand Down Expand Up @@ -108,11 +108,11 @@ def _toggle_profiling(enabled: bool = False) -> Generator:


def pytest_sessionfinish(session, exitstatus):
if boa.env._cached_call_profiles:
if global_profile().call_profiles:
import sys

from rich.console import Console

console = Console(file=sys.stdout)
console.print(get_call_profile_table(boa.env))
console.print(get_line_profile_table(boa.env))
console.print(get_call_profile_table())
console.print(get_line_profile_table())
2 changes: 1 addition & 1 deletion docs/source/testing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ Since ``titanoboa`` is framework-agnostic any other testing framework should wor
Gas Profiling
-----------------------

Titanoboa has native gas profiling tools that store and generate statistics upon calling a contract. When enabled, gas costs are stored per call in global ``boa.env._cached_call_profiles`` and ``boa.env._cached_line_profiles`` dictionaries.
Titanoboa has native gas profiling tools that store and generate statistics upon calling a contract. When enabled, gas costs are stored per call in ``global_profile().call_profiles`` and ``global_profile().line_profiles`` dictionaries.
To enable gas profiling,

1. decorate tests with ``@pytest.mark.gas_profile``, or
Expand Down
12 changes: 7 additions & 5 deletions tests/unitary/test_gas_profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from hypothesis import given, settings

import boa
from boa.profiling import global_profile
from boa.test import strategy


Expand Down Expand Up @@ -74,12 +75,13 @@ def _barfoo(a: uint256, b: uint256, c: uint256) -> uint256:
)
@pytest.mark.ignore_profiling
def test_ignore_profiling(variable_loop_contract, a, b, c):
cached_profiles = [boa.env._cached_call_profiles, boa.env._cached_line_profiles]
# TODO: not sure this is testing what it intends to
cached_profiles = [global_profile().call_profiles, global_profile().line_profiles]

variable_loop_contract.foo(a, b, c)

assert boa.env._cached_call_profiles == cached_profiles[0]
assert boa.env._cached_line_profiles == cached_profiles[1]
assert global_profile().call_profiles == cached_profiles[0]
assert global_profile().line_profiles == cached_profiles[1]


@pytest.mark.parametrize(
Expand All @@ -90,8 +92,8 @@ def test_call_variable_iter_method(variable_loop_contract, a, b, c):
variable_loop_contract.foo(a, b, c)
variable_loop_contract._barfoo(a, b, c)

assert boa.env._cached_call_profiles
assert boa.env._cached_line_profiles
assert global_profile().call_profiles
assert global_profile().line_profiles


@given(
Expand Down

0 comments on commit 83ef3ac

Please sign in to comment.