diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3919289..a35de4c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -30,7 +30,7 @@ repos: rev: 0.7.17 hooks: - id: mdformat - additional_dependencies: [mdformat-gfm, mdformat-frontmatter] + additional_dependencies: [mdformat-gfm, mdformat-frontmatter, mdformat-pyproject] default_language_version: python: python3 diff --git a/README.md b/README.md index 4e26bc6..339d72d 100644 --- a/README.md +++ b/README.md @@ -77,6 +77,8 @@ Lastly, to get information about a token, including its contract address, you ca from ape_tokens import tokens bat = tokens["BAT"] +# Or, access via attribute +bat = tokens.BAT print(bat.address) ``` diff --git a/ape_tokens/managers.py b/ape_tokens/managers.py index f6de8c8..e3ac3ae 100644 --- a/ape_tokens/managers.py +++ b/ape_tokens/managers.py @@ -1,3 +1,5 @@ +from typing import Iterator, Mapping + from ape.contracts import ContractInstance from ape.exceptions import ContractNotFoundError from ape.types import ContractType @@ -106,7 +108,7 @@ ) -class TokenManager(ManagerAccessMixin, dict): +class TokenManager(ManagerAccessMixin, Mapping[str, ContractInstance]): @cached_property def _manager(self) -> TokenListManager: return TokenListManager() @@ -114,14 +116,25 @@ def _manager(self) -> TokenListManager: def __repr__(self) -> str: return f"" - def __getitem__(self, symbol: str) -> ContractInstance: - try: - token_info = self._manager.get_token_info( - symbol, chain_id=self.network_manager.network.chain_id - ) + def __len__(self) -> int: + return len(list(self._manager.get_tokens(chain_id=self.provider.chain_id))) - except ValueError as err: - raise KeyError(f"Symbol '{symbol}' is not a known token symbol") from err + def __iter__(self) -> Iterator[ContractInstance]: + for token in self._manager.get_tokens(chain_id=self.provider.chain_id): + yield self[token.symbol] + + def __getitem__(self, symbol: str) -> ContractInstance: + token_info = None + for tokenlist in self._manager.available_tokenlists(): + try: + token_info = self._manager.get_token_info( + symbol, chain_id=self.network_manager.network.chain_id, token_listname=tokenlist + ) + except ValueError: + continue + + if token_info is None: + raise KeyError(f"Symbol '{symbol}' is not a known token symbol") checksummed_address = to_checksum_address(token_info.address) try: @@ -130,3 +143,9 @@ def __getitem__(self, symbol: str) -> ContractInstance: return self.chain_manager.contracts.instance_at( checksummed_address, contract_type=ERC20 ) + + def __getattr__(self, symbol: str) -> ContractInstance: + try: + return self[symbol] + except KeyError as e: + raise AttributeError(str(e)) from None