From 00a0b6fe26a6cf97526d0dbbb531efe434b23ba6 Mon Sep 17 00:00:00 2001 From: derekpierre Date: Thu, 2 Nov 2023 16:02:13 -0400 Subject: [PATCH] Converted variable functionality into classes that can hold state instead of passing around state needed to allow for simpler resolution of variables. State is now reflected in a VariableContext object that is contract specific and is stored by variable objects. Add Encode variable type to allow function calls to be provided in yml - allows proxy constructor to call logic contract function to initialize. --- deployment/constants.py | 5 - deployment/params.py | 419 +++++++++++++++++++++++++--------------- 2 files changed, 265 insertions(+), 159 deletions(-) diff --git a/deployment/constants.py b/deployment/constants.py index 675c1770..18c8e2ad 100644 --- a/deployment/constants.py +++ b/deployment/constants.py @@ -6,11 +6,6 @@ DEPLOYMENT_DIR = Path(deployment.__file__).parent CONSTRUCTOR_PARAMS_DIR = DEPLOYMENT_DIR / "constructor_params" ARTIFACTS_DIR = DEPLOYMENT_DIR / "artifacts" -VARIABLE_PREFIX = "$" -SPECIAL_VARIABLE_DELIMITER = ":" -HEX_PREFIX = "0x" -BYTES_PREFIX = "bytes" -DEPLOYER_INDICATOR = "deployer" OZ_DEPENDENCY = project.dependencies["openzeppelin"]["5.0.0"] # diff --git a/deployment/params.py b/deployment/params.py index cc6afe27..893e32f5 100644 --- a/deployment/params.py +++ b/deployment/params.py @@ -1,4 +1,5 @@ import typing +from abc import ABC, abstractmethod from collections import OrderedDict from pathlib import Path from typing import Any, List @@ -9,13 +10,7 @@ from ape.contracts.base import ContractContainer, ContractInstance, ContractTransactionHandler from ape.utils import ZERO_ADDRESS from deployment.confirm import _confirm_resolution, _continue -from deployment.constants import ( - BYTES_PREFIX, - DEPLOYER_INDICATOR, - OZ_DEPENDENCY, - SPECIAL_VARIABLE_DELIMITER, - VARIABLE_PREFIX, -) +from deployment.constants import OZ_DEPENDENCY from deployment.registry import registry_from_ape_deployments from deployment.utils import ( _load_yaml, @@ -25,37 +20,162 @@ verify_contracts, ) from eth_typing import ChecksumAddress +from hexbytes import HexBytes from web3.auto import w3 CONTRACT_CONSTRUCTOR_PARAMETER_KEY = "constructor" CONTRACT_PROXY_PARAMETER_KEY = "proxy" -def _is_variable(param: Any) -> bool: - """Returns True if the param is a variable.""" - result = isinstance(param, str) and param.startswith(VARIABLE_PREFIX) - return result +class VariableContext: + def __init__( + self, + contract_names: List[str], + contract_name: str, + constants: typing.Dict[str, Any] = None, + check_for_proxy_instances: bool = True, + ): + self.contract_names = contract_names or list() + self.contract_name = contract_name + self.constants = constants or dict() + self.check_for_proxy_instances = check_for_proxy_instances + + +# Variables + + +class Variable(ABC): + VARIABLE_PREFIX = "$" + + @abstractmethod + def resolve(self) -> Any: + raise NotImplementedError + + @classmethod + def is_variable(cls, param: Any) -> bool: + """Returns True if the param is a variable.""" + result = isinstance(param, str) and param.startswith(cls.VARIABLE_PREFIX) + return result + + +class DeployerAccount(Variable): + DEPLOYER_INDICATOR = "deployer" + + @classmethod + def is_deployer(cls, value: str) -> bool: + """Returns True if the variable is a special deployer variable.""" + return value == cls.DEPLOYER_INDICATOR + + def resolve(self) -> Any: + deployer_account = Deployer.get_account() + if deployer_account is None: + return ZERO_ADDRESS + return deployer_account.address + +class Constant(Variable): # oxymoron + def __init__(self, constant_name: str, context: VariableContext): + try: + self.constant_value = context.constants[constant_name] + except KeyError: + raise ValueError(f"Constant '{constant_name}' not found in deployment file.") -def _is_special_variable(variable: str) -> bool: - """Returns True if the variable is a special variable.""" - rules = [_is_bytes, _is_deployer, _is_constant] - return any(rule(variable) for rule in rules) + @classmethod + def is_constant(cls, value: str) -> bool: + """Returns True if the variable is a deployment constant.""" + return value.isupper() + def resolve(self) -> Any: + return self.constant_value -def _is_bytes(variable: str) -> bool: - """Returns True if the variable is a special bytes value.""" - return variable.startswith(BYTES_PREFIX + SPECIAL_VARIABLE_DELIMITER) +class Encode(Variable): + ENCODE_PREFIX = "encode:" -def _is_deployer(variable: str) -> bool: - """Returns True if the variable is a special deployer variable.""" - return variable == DEPLOYER_INDICATOR + def __init__(self, variable: str, context: VariableContext): + variable = variable[len(self.ENCODE_PREFIX) :] + self.method_name, self.input_abi_types, self.method_args = self._get_call_data( + variable, context + ) + self.contract_name = context.contract_name + + @staticmethod + def _get_call_data(variable, context) -> typing.Tuple[str, List[str], List[Any]]: + variable_elements = variable.split(",") + method_name = variable_elements[0] + method_args = list() + if len(variable_elements) > 1: + args = variable_elements[1:] + for arg in args: + processed_value = _process_raw_value(arg, context) + method_args.append(processed_value) + + contract_name = context.contract_name + contract_container = get_contract_container(contract_name) + contract_method_abis = contract_container.contract_type.methods + method_abi = None + for abi in contract_method_abis: + if abi.name == method_name: + method_abi = abi + if not method_abi: + raise ValueError(f"ABI could not be found for method {contract_name}.{method_name}") + input_abi_types = [t.type for t in method_abi.inputs] + if len(input_abi_types) != len(method_args): + raise ValueError( + f"{contract_name}.{method_name} parameters length mismatch - " + f"ABI requires {len(input_abi_types)}, Got {len(method_args)}." + ) + return method_name, input_abi_types, method_args -def _is_constant(variable: str) -> bool: - """Returns True if the variable is a deployment constant.""" - return variable.isupper() + @classmethod + def is_encode(cls, value: str) -> bool: + """Returns True if the variable is a variable that needs encoding to bytes""" + return value.startswith(cls.ENCODE_PREFIX) + + def resolve(self) -> Any: + contract_container = get_contract_container(self.contract_name) + contract_instance = _get_contract_instance(contract_container) + if contract_instance == ZERO_ADDRESS: + # logic contract not yet deployed - in eager validation check + return HexBytes(b"\xde\xad\xbe\xef").hex() # 0xdeadbeef + + method_args = list() + for method_arg in self.method_args: + value = method_arg + if isinstance(method_arg, Variable): + value = method_arg.resolve() + method_args.append(value) + + method_handler = getattr(contract_instance, self.method_name) + encoded_bytes = method_handler.encode_input(*method_args) + return encoded_bytes.hex() # return as hex - just cleaner + + +class ContractName(Variable): + def __init__(self, contract_name: str, context: VariableContext): + if contract_name not in context.contract_names: + raise ValueError(f"Contract name {contract_name} not found") + + self.contract_name = contract_name + self.check_for_proxy_instances = context.check_for_proxy_instances + + def resolve(self) -> Any: + """Resolves a contract address.""" + contract_container = get_contract_container(self.contract_name) + contract_instance = _get_contract_instance(contract_container) + if contract_instance == ZERO_ADDRESS: + # eager validation + return ZERO_ADDRESS + + if self.check_for_proxy_instances: + # check if contract is proxied - if so return proxy contract instead + local_proxies = chain.contracts._local_proxies + for proxy_address, proxy_info in local_proxies.items(): + if proxy_info.target == contract_instance.address: + return proxy_address + + return contract_instance.address def _get_contract_instance( @@ -65,7 +185,7 @@ def _get_contract_instance( if not contract_instances: return ZERO_ADDRESS if len(contract_instances) != 1: - raise ConstructorParameters.Invalid( + raise ValueError( f"Variable {contract_container.contract_type.name} is ambiguous - " f"expected exactly one contract instance, got {len(contract_instances)}" ) @@ -73,115 +193,83 @@ def _get_contract_instance( return contract_instance -def _resolve_deployer() -> str: - deployer_account = Deployer.get_account() - if deployer_account is None: - return ZERO_ADDRESS - return deployer_account.address - +def _resolve_param(value: Any) -> Any: + """Resolves a single parameter value or a list of parameter values.""" + if isinstance(value, list): + return [_resolve_param(v) for v in value] -def _validate_transaction_args( - method: ContractTransactionHandler, args: typing.Tuple[Any, ...] -) -> typing.Dict[str, Any]: - """Validates the transaction arguments against the function ABI.""" - expected_length_abis = [abi for abi in method.abis if len(abi.inputs) == len(args)] - for abi in expected_length_abis: - named_args = {} - for arg, abi_input in zip(args, abi.inputs): - if not w3.is_encodable(abi_input.type, arg): - break - named_args[abi_input.name] = arg - else: - return named_args - raise ValueError(f"Could not find ABI for {method} with {len(args)} args and given types") + if isinstance(value, Variable): + return value.resolve() + return value # literally a value -def _resolve_contract_address(variable: str, check_for_proxy_instances=True) -> str: - """Resolves a contract address.""" - contract_container = get_contract_container(variable) - contract_instance = _get_contract_instance(contract_container) - if contract_instance == ZERO_ADDRESS: - # eager validation - return ZERO_ADDRESS - if check_for_proxy_instances: - # check if contract is proxied - if so return proxy contract instead - local_proxies = chain.contracts._local_proxies - for proxy_address, proxy_info in local_proxies.items(): - if proxy_info.target == contract_instance.address: - return proxy_address +def _resolve_params(parameters: OrderedDict) -> OrderedDict: + resolved_parameters = OrderedDict() + for name, value in parameters.items(): + resolved_parameters[name] = _resolve_param(value) - return contract_instance.address + return resolved_parameters -def _resolve_special_variable(variable: str, constants) -> Any: - if _is_deployer(variable): - result = _resolve_deployer() - elif _is_constant(variable): - result = _resolve_constant(variable, constants=constants) +def _variable_from_value(variable: Any, context: VariableContext) -> Variable: + variable = variable.strip(Variable.VARIABLE_PREFIX) + if DeployerAccount.is_deployer(variable): + return DeployerAccount() + elif Encode.is_encode(variable): + return Encode(variable, context) + elif Constant.is_constant(variable): + return Constant(variable, context) else: - raise ValueError(f"Invalid special variable {variable}") - return result + return ContractName(variable, context) -def _resolve_param(value: Any, constants, resolve_contracts_checking_proxies=True) -> Any: - """Resolves a single parameter value or a list of parameter values.""" +def _process_raw_value(value: Any, variable_context: VariableContext) -> Any: if isinstance(value, list): - return [_resolve_param(v, constants, resolve_contracts_checking_proxies) for v in value] - if not _is_variable(value): - return value # literally a value - variable = value.strip(VARIABLE_PREFIX) - if _is_special_variable(variable): - result = _resolve_special_variable(variable, constants=constants) - else: - result = _resolve_contract_address(variable, resolve_contracts_checking_proxies) - return result + return [_process_raw_value(v, variable_context) for v in value] + if Variable.is_variable(value): + value = _variable_from_value(value, variable_context) -def _resolve_constant(name: str, constants: typing.Dict[str, Any]) -> Any: - try: - value = constants[name] - return value - except KeyError: - raise ValueError(f"Constant '{name}' not found in deployment file.") + return value -def _validate_constructor_param(param: Any, contracts: List[str]) -> None: - """Validates a single constructor parameter or a list of parameters.""" - if isinstance(param, list): - for p in param: - _validate_constructor_param(p, contracts) - return +def _process_raw_values(values: OrderedDict, variable_context: VariableContext) -> OrderedDict: + processed_parameters = OrderedDict() + for name, value in values.items(): + processed_parameters[name] = _process_raw_value(value, variable_context) - if not _is_variable(param): - return # literally a value - variable = param.strip(VARIABLE_PREFIX) + return processed_parameters - if _is_special_variable(variable): - return # special variables are always marked as valid - if variable in contracts: - return +def _get_contract_names(config: typing.Dict) -> List[str]: + contract_names = list() + for contract_info in config["contracts"]: + if isinstance(contract_info, str): + contract_names.append(contract_info) + elif isinstance(contract_info, dict): + contract_names.extend(list(contract_info.keys())) + else: + raise ValueError("Malformed constructor parameters YAML.") - raise ConstructorParameters.Invalid(f"Variable {param} is not resolvable") + return contract_names def _validate_constructor_abi_inputs( contract_name: str, abi_inputs: List[Any], - parameters: OrderedDict, - constants: typing.Dict[str, Any], + resolved_parameters: OrderedDict, ) -> None: """Validates the constructor parameters against the constructor ABI.""" - if len(parameters) != len(abi_inputs): + if len(resolved_parameters) != len(abi_inputs): raise ConstructorParameters.Invalid( f"Constructor parameters length mismatch - " - f"{contract_name} ABI requires {len(abi_inputs)}, Got {len(parameters)}." + f"{contract_name} ABI requires {len(abi_inputs)}, Got {len(resolved_parameters)}." ) if not abi_inputs: return # no constructor parameters - codex = enumerate(zip(abi_inputs, parameters.items()), start=0) + codex = enumerate(zip(abi_inputs, resolved_parameters.items()), start=0) for position, (abi_input, resolved_input) in codex: name, value = resolved_input # validate name @@ -192,30 +280,26 @@ def _validate_constructor_abi_inputs( ) # validate value type - value_to_validate = _resolve_param(value, constants=constants) - if not w3.is_encodable(abi_input.type, value_to_validate): + if not w3.is_encodable(abi_input.type, value): raise ConstructorParameters.Invalid( f"Constructor param name '{name}' at position {position} has a value '{value}' " f"whose type does not match expected ABI type '{abi_input.type}'" ) -def validate_constructor_parameters(contracts, constants) -> None: +def validate_constructor_parameters(contracts_parameters) -> None: """Validates the constructor parameters for all contracts in a single config.""" - available_contracts = list(contracts.keys()) - for contract, parameters in contracts.items(): + for contract, parameters in contracts_parameters.items(): if not isinstance(parameters, dict): # this can happen if the yml file is malformed raise ValueError(f"Malformed constructor parameter config for {contract}.") - for value in parameters.values(): - _validate_constructor_param(value, available_contracts) + resolved_parameters = _resolve_params(parameters=parameters) contract_container = get_contract_container(contract) _validate_constructor_abi_inputs( contract_name=contract, abi_inputs=contract_container.constructor.abi.inputs, - parameters=parameters, - constants=constants, + resolved_parameters=resolved_parameters, ) @@ -225,16 +309,17 @@ class ConstructorParameters: class Invalid(Exception): """Raised when the constructor parameters are invalid""" - def __init__(self, parameters: OrderedDict, constants: dict = None): + def __init__(self, parameters: OrderedDict): self.parameters = parameters - self.constants = constants or {} - validate_constructor_parameters(parameters, constants) + validate_constructor_parameters(parameters) @classmethod def from_config(cls, config: typing.Dict) -> "ConstructorParameters": """Loads the constructor parameters from a JSON file.""" print("Processing contract constructor parameters...") contracts_config = OrderedDict() + contract_names = _get_contract_names(config) + constants = config.get("constants") for contract_info in config["contracts"]: if isinstance(contract_info, str): contract_constructor_params = {contract_info: OrderedDict()} @@ -244,40 +329,44 @@ def from_config(cls, config: typing.Dict) -> "ConstructorParameters": contract_name = list(contract_info.keys())[0] # only one entry contract_data = contract_info[contract_name] - parameter_values = OrderedDict() - if CONTRACT_CONSTRUCTOR_PARAMETER_KEY in contract_data: - parameter_values = OrderedDict( - contract_data[CONTRACT_CONSTRUCTOR_PARAMETER_KEY] - ) + parameter_values = cls._process_parameters( + constants, contract_data, contract_name, contract_names + ) contract_constructor_params = {contract_name: parameter_values} else: raise ValueError("Malformed constructor parameters YAML.") contracts_config.update(contract_constructor_params) - return cls(parameters=contracts_config, constants=config.get("constants")) + return cls(parameters=contracts_config) + + @classmethod + def _process_parameters(cls, constants, contract_data, contract_name, contract_names): + parameter_values = OrderedDict() + if CONTRACT_CONSTRUCTOR_PARAMETER_KEY in contract_data: + parameter_values = _process_raw_values( + contract_data[CONTRACT_CONSTRUCTOR_PARAMETER_KEY], + VariableContext( + contract_names=contract_names, constants=constants, contract_name=contract_name + ), + ) + return parameter_values def resolve(self, contract_name: str) -> OrderedDict: """Resolves the constructor parameters for a single contract.""" - resolved_params = OrderedDict() - for name, value in self.parameters[contract_name].items(): - resolved_params[name] = _resolve_param(value, constants=self.constants) + resolved_params = _resolve_params(self.parameters[contract_name]) return resolved_params -def validate_proxy_info(contracts_proxy_info, constants) -> None: +def validate_proxy_info(contracts_proxy_info) -> None: """Validates the proxy information for all contracts.""" - available_contracts = contracts_proxy_info.keys() contract_container = OZ_DEPENDENCY.TransparentUpgradeableProxy for contract, proxy_info in contracts_proxy_info.items(): - constructor_params = proxy_info.constructor_params - for value in constructor_params.values(): - _validate_constructor_param(value, available_contracts) + resolved_parameters = _resolve_params(proxy_info.constructor_params) _validate_constructor_abi_inputs( contract_name=contract_container.contract_type.name, abi_inputs=contract_container.constructor.abi.inputs, - parameters=constructor_params, - constants=constants, + resolved_parameters=resolved_parameters, ) @@ -294,15 +383,17 @@ class ProxyInfo(typing.NamedTuple): contract_type_container: ContractContainer constructor_params: OrderedDict - def __init__(self, contracts_proxy_info: OrderedDict, constants: dict = None): + def __init__(self, contracts_proxy_info: OrderedDict): self.contracts_proxy_info = contracts_proxy_info - self.constants = constants or {} - validate_proxy_info(contracts_proxy_info, self.constants) + validate_proxy_info(contracts_proxy_info) @classmethod def from_config(cls, config: typing.Dict) -> "ProxyParameters": """Loads the proxy parameters from a JSON config file.""" print("Processing proxy parameters...") + contract_names = _get_contract_names(config) + constants = config.get("constants") + contracts_proxy_info = OrderedDict() for contract_info in config["contracts"]: if isinstance(contract_info, str): @@ -316,10 +407,18 @@ def from_config(cls, config: typing.Dict) -> "ProxyParameters": if CONTRACT_PROXY_PARAMETER_KEY not in contract_data: continue - proxy_info = cls._generate_proxy_info(contract_name, contract_data) + proxy_info = cls._generate_proxy_info( + contract_data, + VariableContext( + contract_names=contract_names, + constants=constants, + contract_name=contract_name, + check_for_proxy_instances=False, + ), + ) contracts_proxy_info.update({contract_name: proxy_info}) - return cls(contracts_proxy_info=contracts_proxy_info, constants=config.get("constants")) + return cls(contracts_proxy_info=contracts_proxy_info) def contract_needs_proxy(self, contract_name) -> bool: proxy_info = self.contracts_proxy_info.get(contract_name) @@ -335,36 +434,32 @@ def resolve(self, contract_name: str) -> typing.Tuple[ContractContainer, Ordered contract_container = proxy_info.contract_type_container - resolved_params = OrderedDict() - for name, value in proxy_info.constructor_params.items(): - resolved_params[name] = _resolve_param( - value=value, constants=self.constants, resolve_contracts_checking_proxies=False - ) - + resolved_params = _resolve_params(parameters=proxy_info.constructor_params) return contract_container, resolved_params @classmethod - def _generate_proxy_info(cls, contract_name, contract_data) -> ProxyInfo: + def _generate_proxy_info(cls, contract_data, variable_context: VariableContext) -> ProxyInfo: proxy_data = contract_data[CONTRACT_PROXY_PARAMETER_KEY] or dict() - contract_type = contract_name + contract_type = variable_context.contract_name if cls.CONTRACT_TYPE in proxy_data: contract_type = proxy_data[cls.CONTRACT_TYPE] contract_type_container = get_contract_container(contract_type) - constructor_data = cls._default_proxy_parameters(contract_name) + constructor_data = cls._default_proxy_parameters(variable_context.contract_name) if CONTRACT_CONSTRUCTOR_PARAMETER_KEY in proxy_data: - for name, value in proxy_data[CONTRACT_CONSTRUCTOR_PARAMETER_KEY].items(): - if name == "_logic": - raise cls.Invalid( - "'_logic' parameter cannot be specified: it is implicitly " - "the contract being proxied" - ) + proxy_constructor_params = proxy_data[CONTRACT_CONSTRUCTOR_PARAMETER_KEY] + if "_logic" in proxy_constructor_params: + raise cls.Invalid( + "'_logic' parameter cannot be specified: it is implicitly " + "the contract being proxied" + ) - constructor_data.update({name: value}) + constructor_data.update(proxy_constructor_params) + processed_values = _process_raw_values(constructor_data, variable_context) proxy_info = cls.ProxyInfo( - contract_type_container=contract_type_container, constructor_params=constructor_data + contract_type_container=contract_type_container, constructor_params=processed_values ) return proxy_info @@ -376,6 +471,22 @@ def _default_proxy_parameters(cls, contract_name: str) -> OrderedDict: return default_parameters +def _validate_transaction_args( + method: ContractTransactionHandler, args: typing.Tuple[Any, ...] +) -> typing.Dict[str, Any]: + """Validates the transaction arguments against the function ABI.""" + expected_length_abis = [abi for abi in method.abis if len(abi.inputs) == len(args)] + for abi in expected_length_abis: + named_args = {} + for arg, abi_input in zip(args, abi.inputs): + if not w3.is_encodable(abi_input.type, arg): + break + named_args[abi_input.name] = arg + else: + return named_args + raise ValueError(f"Could not find ABI for {method} with {len(args)} args and given types") + + class Transactor: """ Represents an ape account plus validated/annotated transaction execution. @@ -496,8 +607,8 @@ def _deploy_proxy( proxy_container, resolved_params=resolved_proxy_params ) print( - f"\nWrapping {target_contract_name} into " - f"{proxy_contract.contract_type.name} (as type {contract_type_container.contract_type.name}) " + f"\nWrapping {target_contract_name} into {proxy_contract.contract_type.name} " + f"(as type {contract_type_container.contract_type.name}) " f"at {proxy_contract.address}." ) return contract_type_container.at(proxy_contract.address)