diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index c68479d59..a28b36343 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -34,10 +34,8 @@ jobs: with: starknet-foundry-version: ${{ env.FOUNDRY_VERSION }} - # Issue with cairo-coverage. Re-add to CI once issues are fixed. - # - # - name: Install cairo-coverage - # run: curl -L https://raw.githubusercontent.com/software-mansion/cairo-coverage/main/scripts/install.sh | sh + - name: Install cairo-coverage + run: curl -L https://raw.githubusercontent.com/software-mansion/cairo-coverage/main/scripts/install.sh | sh - name: Markdown lint uses: DavidAnson/markdownlint-cli2-action@eb5ca3ab411449c66620fe7f1b3c9e10547144b0 # v16 @@ -52,13 +50,11 @@ jobs: - name: Run tests run: snforge test --workspace - # Issue with cairo-coverage. Re-add to CI once issues are fixed. - # - # - name: Run tests and generate coverage report - # run: snforge test --workspace --coverage - # - # - name: Upload coverage to Codecov - # uses: codecov/codecov-action@v4 - # with: - # file: ./coverage.lcov - # token: ${{ secrets.CODECOV_TOKEN }} + - name: Run tests and generate coverage report + run: snforge test --workspace --coverage + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v4 + with: + file: ./coverage.lcov + token: ${{ secrets.CODECOV_TOKEN }} diff --git a/CHANGELOG.md b/CHANGELOG.md index b5d83bc02..da0b0c748 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- ERC4626Component (#1170) +- `Math::u256_mul_div` (#1170) - SRC9 (Outside Execution) integration to account presets (#1201) - `SNIP12HashSpanImpl` to `openzeppelin_utils::cryptography::snip12` (#1180) - GovernorComponent with the following extensions: (#1180) diff --git a/codecov.yml b/codecov.yml index dacdbb20d..75b444982 100644 --- a/codecov.yml +++ b/codecov.yml @@ -5,19 +5,22 @@ comment: coverage: # The value range where you want the value to be green # Hold ourselves to a high bar. - range: 90..100 + # TMP 80 floor until cairo-coverage becomes more stable + range: 80..100 status: project: coverage: # Use the coverage from the base commit (pull request base) coverage to compare against. # Once we have a baseline we can be more strict. + # TMP threshold until cairo-coverage becomes more stable target: auto - threshold: 2% + threshold: 4% patch: default: # Require new code to have 90%+ coverage. - target: 90% - threshold: 2% + # TMP target and threshold until cairo-coverage becomes more stable + target: 80% + threshold: 4% ignore: - "**/tests/**" @@ -27,4 +30,3 @@ ignore: github_checks: annotations: false - \ No newline at end of file diff --git a/packages/test_common/src/mocks.cairo b/packages/test_common/src/mocks.cairo index e19179acc..dfa9197af 100644 --- a/packages/test_common/src/mocks.cairo +++ b/packages/test_common/src/mocks.cairo @@ -4,6 +4,7 @@ pub mod checkpoint; pub mod erc1155; pub mod erc20; pub mod erc2981; +pub mod erc4626; pub mod erc721; pub mod governor; pub mod multisig; diff --git a/packages/test_common/src/mocks/erc20.cairo b/packages/test_common/src/mocks/erc20.cairo index de3c88e37..ef9503559 100644 --- a/packages/test_common/src/mocks/erc20.cairo +++ b/packages/test_common/src/mocks/erc20.cairo @@ -1,3 +1,5 @@ +use starknet::ContractAddress; + #[starknet::contract] pub mod DualCaseERC20Mock { use openzeppelin_token::erc20::{ERC20Component, ERC20HooksEmptyImpl}; @@ -224,3 +226,163 @@ pub mod DualCaseERC20PermitMock { self.erc20.mint(recipient, initial_supply); } } + +#[derive(Drop, Serde, PartialEq, Debug, starknet::Store)] +pub enum Type { + #[default] + No, + Before, + After, +} + +#[starknet::interface] +pub trait IERC20ReentrantHelpers { + fn schedule_reenter( + ref self: TState, + when: Type, + target: ContractAddress, + selector: felt252, + calldata: Span, + ); + fn function_call(ref self: TState); + fn unsafe_mint(ref self: TState, recipient: ContractAddress, amount: u256); + fn unsafe_burn(ref self: TState, account: ContractAddress, amount: u256); +} + +#[starknet::interface] +pub trait IERC20Reentrant { + fn schedule_reenter( + ref self: TState, + when: Type, + target: ContractAddress, + selector: felt252, + calldata: Span, + ); + fn function_call(ref self: TState); + fn unsafe_mint(ref self: TState, recipient: ContractAddress, amount: u256); + fn unsafe_burn(ref self: TState, account: ContractAddress, amount: u256); + + // IERC20 + fn total_supply(self: @TState) -> u256; + fn balance_of(self: @TState, account: ContractAddress) -> u256; + fn allowance(self: @TState, owner: ContractAddress, spender: ContractAddress) -> u256; + fn transfer(ref self: TState, recipient: ContractAddress, amount: u256) -> bool; + fn transfer_from( + ref self: TState, sender: ContractAddress, recipient: ContractAddress, amount: u256, + ) -> bool; + fn approve(ref self: TState, spender: ContractAddress, amount: u256) -> bool; +} + +#[starknet::contract] +pub mod ERC20ReentrantMock { + use openzeppelin_token::erc20::ERC20Component; + use starknet::ContractAddress; + use starknet::SyscallResultTrait; + use starknet::storage::{MutableVecTrait, Vec}; + use starknet::storage::{StoragePointerReadAccess, StoragePointerWriteAccess}; + use starknet::syscalls::call_contract_syscall; + use super::Type; + + component!(path: ERC20Component, storage: erc20, event: ERC20Event); + + #[abi(embed_v0)] + impl ERC20Impl = ERC20Component::ERC20Impl; + #[abi(embed_v0)] + impl ERC20MetadataImpl = ERC20Component::ERC20MetadataImpl; + #[abi(embed_v0)] + impl ERC20CamelOnlyImpl = ERC20Component::ERC20CamelOnlyImpl; + impl InternalImpl = ERC20Component::InternalImpl; + + #[storage] + pub struct Storage { + #[substorage(v0)] + pub erc20: ERC20Component::Storage, + reenter_type: Type, + reenter_target: ContractAddress, + reenter_selector: felt252, + reenter_calldata: Vec, + } + + #[event] + #[derive(Drop, starknet::Event)] + enum Event { + #[flat] + ERC20Event: ERC20Component::Event, + } + + // + // Hooks + // + + impl ERC20ReentrantImpl of ERC20Component::ERC20HooksTrait { + fn before_update( + ref self: ERC20Component::ComponentState, + from: ContractAddress, + recipient: ContractAddress, + amount: u256, + ) { + let mut contract_state = self.get_contract_mut(); + + if contract_state.reenter_type.read() == Type::Before { + contract_state.reenter_type.write(Type::No); + contract_state.function_call(); + } + } + + fn after_update( + ref self: ERC20Component::ComponentState, + from: ContractAddress, + recipient: ContractAddress, + amount: u256, + ) { + let mut contract_state = self.get_contract_mut(); + + if contract_state.reenter_type.read() == Type::After { + contract_state.reenter_type.write(Type::No); + contract_state.function_call(); + } + } + } + + #[abi(embed_v0)] + pub impl ERC20ReentrantHelpers of super::IERC20ReentrantHelpers { + fn schedule_reenter( + ref self: ContractState, + when: Type, + target: ContractAddress, + selector: felt252, + calldata: Span, + ) { + self.reenter_type.write(when); + self.reenter_target.write(target); + self.reenter_selector.write(selector); + for elem in calldata { + self.reenter_calldata.append().write(*elem); + } + } + + fn function_call(ref self: ContractState) { + let target = self.reenter_target.read(); + let selector = self.reenter_selector.read(); + let mut calldata = array![]; + for i in 0..self.reenter_calldata.len() { + calldata.append(self.reenter_calldata.at(i).read()); + }; + call_contract_syscall(target, selector, calldata.span()).unwrap_syscall(); + } + + fn unsafe_mint(ref self: ContractState, recipient: ContractAddress, amount: u256) { + self.erc20.mint(recipient, amount); + } + + fn unsafe_burn(ref self: ContractState, account: ContractAddress, amount: u256) { + self.erc20.burn(account, amount); + } + } + + #[constructor] + fn constructor(ref self: ContractState, name: ByteArray, symbol: ByteArray) { + self.erc20.initializer(name, symbol); + self.reenter_type.write(Type::No); + } +} diff --git a/packages/test_common/src/mocks/erc4626.cairo b/packages/test_common/src/mocks/erc4626.cairo new file mode 100644 index 000000000..148be9ba3 --- /dev/null +++ b/packages/test_common/src/mocks/erc4626.cairo @@ -0,0 +1,405 @@ +#[starknet::contract] +pub mod ERC4626Mock { + use openzeppelin_token::erc20::extensions::erc4626::DefaultConfig; + use openzeppelin_token::erc20::extensions::erc4626::ERC4626Component; + use openzeppelin_token::erc20::extensions::erc4626::ERC4626Component::InternalTrait as ERC4626InternalTrait; + use openzeppelin_token::erc20::extensions::erc4626::ERC4626DefaultLimits; + use openzeppelin_token::erc20::extensions::erc4626::ERC4626DefaultNoFees; + use openzeppelin_token::erc20::extensions::erc4626::ERC4626HooksEmptyImpl; + use openzeppelin_token::erc20::{ERC20Component, ERC20HooksEmptyImpl}; + use starknet::ContractAddress; + + component!(path: ERC4626Component, storage: erc4626, event: ERC4626Event); + component!(path: ERC20Component, storage: erc20, event: ERC20Event); + + // ERC4626 + #[abi(embed_v0)] + impl ERC4626ComponentImpl = ERC4626Component::ERC4626Impl; + // ERC4626MetadataImpl is a custom impl of IERC20Metadata + #[abi(embed_v0)] + impl ERC4626MetadataImpl = ERC4626Component::ERC4626MetadataImpl; + + // ERC20 + #[abi(embed_v0)] + impl ERC20Impl = ERC20Component::ERC20Impl; + #[abi(embed_v0)] + impl ERC20CamelOnlyImpl = ERC20Component::ERC20CamelOnlyImpl; + + impl ERC4626InternalImpl = ERC4626Component::InternalImpl; + impl ERC20InternalImpl = ERC20Component::InternalImpl; + + #[storage] + pub struct Storage { + #[substorage(v0)] + pub erc4626: ERC4626Component::Storage, + #[substorage(v0)] + pub erc20: ERC20Component::Storage, + } + + #[event] + #[derive(Drop, starknet::Event)] + enum Event { + #[flat] + ERC4626Event: ERC4626Component::Event, + #[flat] + ERC20Event: ERC20Component::Event, + } + + #[constructor] + fn constructor( + ref self: ContractState, + name: ByteArray, + symbol: ByteArray, + underlying_asset: ContractAddress, + initial_supply: u256, + recipient: ContractAddress, + ) { + self.erc20.initializer(name, symbol); + self.erc20.mint(recipient, initial_supply); + self.erc4626.initializer(underlying_asset); + } +} + +#[starknet::contract] +pub mod ERC4626OffsetMock { + use openzeppelin_token::erc20::extensions::erc4626::ERC4626Component; + use openzeppelin_token::erc20::extensions::erc4626::ERC4626Component::InternalTrait as ERC4626InternalTrait; + use openzeppelin_token::erc20::extensions::erc4626::ERC4626DefaultLimits; + use openzeppelin_token::erc20::extensions::erc4626::ERC4626DefaultNoFees; + use openzeppelin_token::erc20::extensions::erc4626::ERC4626HooksEmptyImpl; + use openzeppelin_token::erc20::{ERC20Component, ERC20HooksEmptyImpl}; + use starknet::ContractAddress; + + component!(path: ERC4626Component, storage: erc4626, event: ERC4626Event); + component!(path: ERC20Component, storage: erc20, event: ERC20Event); + + // ERC4626 + #[abi(embed_v0)] + impl ERC4626ComponentImpl = ERC4626Component::ERC4626Impl; + // ERC4626MetadataImpl is a custom impl of IERC20Metadata + #[abi(embed_v0)] + impl ERC4626MetadataImpl = ERC4626Component::ERC4626MetadataImpl; + + // ERC20 + #[abi(embed_v0)] + impl ERC20Impl = ERC20Component::ERC20Impl; + #[abi(embed_v0)] + impl ERC20CamelOnlyImpl = ERC20Component::ERC20CamelOnlyImpl; + + impl ERC4626InternalImpl = ERC4626Component::InternalImpl; + impl ERC20InternalImpl = ERC20Component::InternalImpl; + + #[storage] + pub struct Storage { + #[substorage(v0)] + pub erc4626: ERC4626Component::Storage, + #[substorage(v0)] + pub erc20: ERC20Component::Storage, + } + + #[event] + #[derive(Drop, starknet::Event)] + enum Event { + #[flat] + ERC4626Event: ERC4626Component::Event, + #[flat] + ERC20Event: ERC20Component::Event, + } + + pub impl OffsetConfig of ERC4626Component::ImmutableConfig { + const UNDERLYING_DECIMALS: u8 = ERC4626Component::DEFAULT_UNDERLYING_DECIMALS; + const DECIMALS_OFFSET: u8 = 1; + } + + #[constructor] + fn constructor( + ref self: ContractState, + name: ByteArray, + symbol: ByteArray, + underlying_asset: ContractAddress, + initial_supply: u256, + recipient: ContractAddress, + ) { + self.erc20.initializer(name, symbol); + self.erc20.mint(recipient, initial_supply); + self.erc4626.initializer(underlying_asset); + } +} + +#[starknet::contract] +pub mod ERC4626LimitsMock { + use openzeppelin_token::erc20::extensions::erc4626::ERC4626Component; + use openzeppelin_token::erc20::extensions::erc4626::ERC4626Component::InternalTrait as ERC4626InternalTrait; + use openzeppelin_token::erc20::extensions::erc4626::ERC4626DefaultNoFees; + use openzeppelin_token::erc20::extensions::erc4626::ERC4626HooksEmptyImpl; + use openzeppelin_token::erc20::{ERC20Component, ERC20HooksEmptyImpl}; + use starknet::ContractAddress; + + component!(path: ERC4626Component, storage: erc4626, event: ERC4626Event); + component!(path: ERC20Component, storage: erc20, event: ERC20Event); + + // ERC4626 + #[abi(embed_v0)] + impl ERC4626ComponentImpl = ERC4626Component::ERC4626Impl; + // ERC4626MetadataImpl is a custom impl of IERC20Metadata + #[abi(embed_v0)] + impl ERC4626MetadataImpl = ERC4626Component::ERC4626MetadataImpl; + + // ERC20 + #[abi(embed_v0)] + impl ERC20Impl = ERC20Component::ERC20Impl; + #[abi(embed_v0)] + impl ERC20CamelOnlyImpl = ERC20Component::ERC20CamelOnlyImpl; + + impl ERC4626InternalImpl = ERC4626Component::InternalImpl; + impl ERC20InternalImpl = ERC20Component::InternalImpl; + + #[storage] + pub struct Storage { + #[substorage(v0)] + pub erc4626: ERC4626Component::Storage, + #[substorage(v0)] + pub erc20: ERC20Component::Storage, + } + + #[event] + #[derive(Drop, starknet::Event)] + enum Event { + #[flat] + ERC4626Event: ERC4626Component::Event, + #[flat] + ERC20Event: ERC20Component::Event, + } + + pub impl OffsetConfig of ERC4626Component::ImmutableConfig { + const UNDERLYING_DECIMALS: u8 = ERC4626Component::DEFAULT_UNDERLYING_DECIMALS; + const DECIMALS_OFFSET: u8 = 1; + } + + const MAX_DEPOSIT: u256 = 100_000_000_000_000_000_000; + const MAX_MINT: u256 = 100_000_000_000_000_000_000; + + impl ERC4626LimitsImpl of ERC4626Component::LimitConfigTrait { + fn deposit_limit( + self: @ERC4626Component::ComponentState, receiver: ContractAddress, + ) -> Option:: { + Option::Some(MAX_DEPOSIT) + } + + fn mint_limit( + self: @ERC4626Component::ComponentState, receiver: ContractAddress, + ) -> Option:: { + Option::Some(MAX_MINT) + } + } + + #[constructor] + fn constructor( + ref self: ContractState, + name: ByteArray, + symbol: ByteArray, + underlying_asset: ContractAddress, + initial_supply: u256, + recipient: ContractAddress, + ) { + self.erc20.initializer(name, symbol); + self.erc20.mint(recipient, initial_supply); + self.erc4626.initializer(underlying_asset); + } +} + +/// The mock contract charges fees in terms of assets, not shares. +/// This means that the fees are calculated based on the amount of assets that are being deposited +/// or withdrawn, and not based on the amount of shares that are being minted or redeemed. +/// This is an opinionated design decision for the purpose of testing. +/// DO NOT USE IN PRODUCTION +#[starknet::contract] +pub mod ERC4626FeesMock { + use openzeppelin_token::erc20::extensions::erc4626::ERC4626Component; + use openzeppelin_token::erc20::extensions::erc4626::ERC4626Component::FeeConfigTrait; + use openzeppelin_token::erc20::extensions::erc4626::ERC4626Component::InternalTrait as ERC4626InternalTrait; + use openzeppelin_token::erc20::extensions::erc4626::ERC4626DefaultLimits; + use openzeppelin_token::erc20::interface::{IERC20Dispatcher, IERC20DispatcherTrait}; + use openzeppelin_token::erc20::{ERC20Component, ERC20HooksEmptyImpl}; + use openzeppelin_utils::math; + use openzeppelin_utils::math::Rounding; + use starknet::ContractAddress; + use starknet::storage::{StoragePointerReadAccess, StoragePointerWriteAccess}; + + component!(path: ERC4626Component, storage: erc4626, event: ERC4626Event); + component!(path: ERC20Component, storage: erc20, event: ERC20Event); + + // ERC4626 + #[abi(embed_v0)] + impl ERC4626ComponentImpl = ERC4626Component::ERC4626Impl; + // ERC4626MetadataImpl is a custom impl of IERC20Metadata + #[abi(embed_v0)] + impl ERC4626MetadataImpl = ERC4626Component::ERC4626MetadataImpl; + + // ERC20 + #[abi(embed_v0)] + impl ERC20Impl = ERC20Component::ERC20Impl; + #[abi(embed_v0)] + impl ERC20CamelOnlyImpl = ERC20Component::ERC20CamelOnlyImpl; + + impl ERC4626InternalImpl = ERC4626Component::InternalImpl; + impl ERC20InternalImpl = ERC20Component::InternalImpl; + + #[storage] + pub struct Storage { + #[substorage(v0)] + pub erc4626: ERC4626Component::Storage, + #[substorage(v0)] + pub erc20: ERC20Component::Storage, + pub entry_fee_basis_point_value: u256, + pub entry_fee_recipient: ContractAddress, + pub exit_fee_basis_point_value: u256, + pub exit_fee_recipient: ContractAddress, + } + + #[event] + #[derive(Drop, starknet::Event)] + enum Event { + #[flat] + ERC4626Event: ERC4626Component::Event, + #[flat] + ERC20Event: ERC20Component::Event, + } + + const _BASIS_POINT_SCALE: u256 = 10_000; + + /// Immutable config + impl OffsetConfig of ERC4626Component::ImmutableConfig { + const UNDERLYING_DECIMALS: u8 = ERC4626Component::DEFAULT_UNDERLYING_DECIMALS; + const DECIMALS_OFFSET: u8 = 0; + } + + /// Hooks + impl ERC4626HooksEmptyImpl of ERC4626Component::ERC4626HooksTrait { + fn after_deposit( + ref self: ERC4626Component::ComponentState, assets: u256, shares: u256, + ) { + let mut contract_state = self.get_contract_mut(); + let entry_basis_points = contract_state.entry_fee_basis_point_value.read(); + let fee = contract_state.fee_on_total(assets, entry_basis_points); + let recipient = contract_state.entry_fee_recipient.read(); + + if (fee > 0 && recipient != starknet::get_contract_address()) { + contract_state.transfer_fees(recipient, fee); + } + } + + fn before_withdraw( + ref self: ERC4626Component::ComponentState, assets: u256, shares: u256, + ) { + let mut contract_state = self.get_contract_mut(); + let exit_basis_points = contract_state.exit_fee_basis_point_value.read(); + let fee = contract_state.fee_on_raw(assets, exit_basis_points); + let recipient = contract_state.exit_fee_recipient.read(); + + if (fee > 0 && recipient != starknet::get_contract_address()) { + contract_state.transfer_fees(recipient, fee); + } + } + } + + /// Adjust fees + impl AdjustFeesImpl of FeeConfigTrait { + fn adjust_deposit( + self: @ERC4626Component::ComponentState, assets: u256, + ) -> u256 { + let contract_state = self.get_contract(); + contract_state.remove_fee_from_deposit(assets) + } + + fn adjust_mint( + self: @ERC4626Component::ComponentState, assets: u256, + ) -> u256 { + let contract_state = ERC4626Component::HasComponent::get_contract(self); + contract_state.add_fee_to_mint(assets) + } + + fn adjust_withdraw( + self: @ERC4626Component::ComponentState, assets: u256, + ) -> u256 { + let contract_state = ERC4626Component::HasComponent::get_contract(self); + contract_state.add_fee_to_withdraw(assets) + } + + fn adjust_redeem( + self: @ERC4626Component::ComponentState, assets: u256, + ) -> u256 { + let contract_state = ERC4626Component::HasComponent::get_contract(self); + contract_state.remove_fee_from_redeem(assets) + } + } + + #[constructor] + fn constructor( + ref self: ContractState, + name: ByteArray, + symbol: ByteArray, + underlying_asset: ContractAddress, + initial_supply: u256, + recipient: ContractAddress, + entry_fee: u256, + entry_treasury: ContractAddress, + exit_fee: u256, + exit_treasury: ContractAddress, + ) { + self.erc20.initializer(name, symbol); + self.erc20.mint(recipient, initial_supply); + self.erc4626.initializer(underlying_asset); + + self.entry_fee_basis_point_value.write(entry_fee); + self.entry_fee_recipient.write(entry_treasury); + self.exit_fee_basis_point_value.write(exit_fee); + self.exit_fee_recipient.write(exit_treasury); + } + + #[generate_trait] + pub impl InternalImpl of InternalTrait { + fn transfer_fees(ref self: ContractState, recipient: ContractAddress, fee: u256) { + let asset_address = self.asset(); + let asset_dispatcher = IERC20Dispatcher { contract_address: asset_address }; + assert(asset_dispatcher.transfer(recipient, fee), 'Fee transfer failed'); + } + + fn remove_fee_from_deposit(self: @ContractState, assets: u256) -> u256 { + let fee = self.fee_on_total(assets, self.entry_fee_basis_point_value.read()); + assets - fee + } + + fn add_fee_to_mint(self: @ContractState, assets: u256) -> u256 { + assets + self.fee_on_raw(assets, self.entry_fee_basis_point_value.read()) + } + + fn add_fee_to_withdraw(self: @ContractState, assets: u256) -> u256 { + let fee = self.fee_on_raw(assets, self.exit_fee_basis_point_value.read()); + assets + fee + } + + fn remove_fee_from_redeem(self: @ContractState, assets: u256) -> u256 { + assets - self.fee_on_total(assets, self.exit_fee_basis_point_value.read()) + } + + /// + /// Fee operations + /// + + /// Calculates the fees that should be added to an amount `assets` that does not already + /// include fees. + /// Used in IERC4626::mint and IERC4626::withdraw operations. + fn fee_on_raw(self: @ContractState, assets: u256, fee_basis_points: u256) -> u256 { + math::u256_mul_div(assets, fee_basis_points, _BASIS_POINT_SCALE, Rounding::Ceil) + } + + /// Calculates the fee part of an amount `assets` that already includes fees. + /// Used in IERC4626::deposit and IERC4626::redeem operations. + fn fee_on_total(self: @ContractState, assets: u256, fee_basis_points: u256) -> u256 { + math::u256_mul_div( + assets, fee_basis_points, fee_basis_points + _BASIS_POINT_SCALE, Rounding::Ceil, + ) + } + } +} diff --git a/packages/token/Scarb.toml b/packages/token/Scarb.toml index 6999f6322..4c89ad646 100644 --- a/packages/token/Scarb.toml +++ b/packages/token/Scarb.toml @@ -46,6 +46,11 @@ casm = false name = "openzeppelin_token_unittest" build-external-contracts = [ "openzeppelin_test_common::mocks::account::DualCaseAccountMock", + "openzeppelin_test_common::mocks::erc20::ERC20ReentrantMock", + "openzeppelin_test_common::mocks::erc4626::ERC4626Mock", + "openzeppelin_test_common::mocks::erc4626::ERC4626OffsetMock", + "openzeppelin_test_common::mocks::erc4626::ERC4626FeesMock", + "openzeppelin_test_common::mocks::erc4626::ERC4626LimitsMock", "openzeppelin_test_common::mocks::erc721::DualCaseERC721ReceiverMock", "openzeppelin_test_common::mocks::erc1155::DualCaseERC1155ReceiverMock", "openzeppelin_test_common::mocks::non_implementing::NonImplementingMock", diff --git a/packages/token/src/erc20.cairo b/packages/token/src/erc20.cairo index e3a8368f7..2f9f94b58 100644 --- a/packages/token/src/erc20.cairo +++ b/packages/token/src/erc20.cairo @@ -1,4 +1,5 @@ pub mod erc20; +pub mod extensions; pub mod interface; pub mod snip12_utils; diff --git a/packages/token/src/erc20/extensions.cairo b/packages/token/src/erc20/extensions.cairo new file mode 100644 index 000000000..52a7e03eb --- /dev/null +++ b/packages/token/src/erc20/extensions.cairo @@ -0,0 +1 @@ +pub mod erc4626; diff --git a/packages/token/src/erc20/extensions/erc4626.cairo b/packages/token/src/erc20/extensions/erc4626.cairo new file mode 100644 index 000000000..f9e34c1aa --- /dev/null +++ b/packages/token/src/erc20/extensions/erc4626.cairo @@ -0,0 +1,9 @@ +pub mod erc4626; +pub mod interface; + +pub use erc4626::DefaultConfig; +pub use erc4626::ERC4626Component; +pub use erc4626::ERC4626DefaultLimits; +pub use erc4626::ERC4626DefaultNoFees; +pub use erc4626::ERC4626HooksEmptyImpl; +pub use interface::IERC4626; diff --git a/packages/token/src/erc20/extensions/erc4626/erc4626.cairo b/packages/token/src/erc20/extensions/erc4626/erc4626.cairo new file mode 100644 index 000000000..97d6addf0 --- /dev/null +++ b/packages/token/src/erc20/extensions/erc4626/erc4626.cairo @@ -0,0 +1,681 @@ +// SPDX-License-Identifier: MIT +// OpenZeppelin Contracts for Cairo v0.20.0-rc.0 (token/erc20/extensions/erc4626/erc4626.cairo) + +/// # ERC4626 Component +/// +/// The ERC4626 component is an extension of ERC20 and provides an implementation of the IERC4626 +/// interface which allows the minting and burning of "shares" in exchange for an underlying +/// "asset". The component leverages traits to configure fees, limits, and decimals. +/// +/// CAUTION: In empty (or nearly empty) ERC-4626 vaults, deposits are at high risk of being stolen +/// through frontrunning with a "donation" to the vault that inflates the price of a share. This is +/// variously known as a donation or inflation attack and is essentially a problem of slippage. +/// Vault deployers can protect against this attack by making an initial deposit of a non-trivial +/// amount of the asset, such that price manipulation becomes infeasible. Withdrawals may similarly +/// be affected by slippage. Users can protect against this attack as well as unexpected slippage in +/// general by verifying the amount received is as expected, using a wrapper that performs these +/// checks. +/// +/// This implementation offers configurable virtual assets and shares to help developers mitigate +/// that risk. `ImmutableConfig::DECIMALS_OFFSET` corresponds to an offset in the decimal +/// representation between the underlying asset's decimals and vault decimals. This offset also +/// determines the rate of virtual shares to virtual assets in the vault, which itself determines +/// the initial exchange rate. While not fully preventing the attack, analysis shows that the +/// default offset (0) makes it non-profitable even if an attacker is able to capture value from +/// multiple user deposits, as a result of the value being captured by the virtual shares (out of +/// the attacker's donation) matching the attacker's expected gains. With a larger offset, the +/// attack becomes orders of magnitude more expensive than it is profitable. +/// +/// The drawback of this approach is that the virtual shares do capture (a very small) part of the +/// value being accrued to the vault. Also, if the vault experiences losses and users try to exit +/// the vault, the virtual shares and assets will cause the first exiting user to experience reduced +/// losses to the detriment to the last users who will experience bigger losses. +#[starknet::component] +pub mod ERC4626Component { + use core::num::traits::{Bounded, Pow, Zero}; + use crate::erc20::ERC20Component; + use crate::erc20::ERC20Component::InternalImpl as ERC20InternalImpl; + use crate::erc20::extensions::erc4626::interface::IERC4626; + use crate::erc20::interface::{IERC20, IERC20Metadata}; + use crate::erc20::interface::{IERC20Dispatcher, IERC20DispatcherTrait}; + use openzeppelin_utils::math; + use openzeppelin_utils::math::Rounding; + use starknet::ContractAddress; + use starknet::storage::{StoragePointerReadAccess, StoragePointerWriteAccess}; + + // The default values are only used when the DefaultConfig + // is in scope in the implementing contract. + pub const DEFAULT_UNDERLYING_DECIMALS: u8 = 18; + pub const DEFAULT_DECIMALS_OFFSET: u8 = 0; + + #[storage] + pub struct Storage { + ERC4626_asset: ContractAddress, + } + + #[event] + #[derive(Drop, PartialEq, starknet::Event)] + pub enum Event { + Deposit: Deposit, + Withdraw: Withdraw, + } + + /// Emitted when `sender` exchanges `assets` for `shares` and transfers those + /// `shares` to `owner`. + #[derive(Drop, PartialEq, starknet::Event)] + pub struct Deposit { + #[key] + pub sender: ContractAddress, + #[key] + pub owner: ContractAddress, + pub assets: u256, + pub shares: u256, + } + + /// Emitted when `sender` exchanges `shares`, owned by `owner`, for `assets` and transfers + /// those `assets` to `receiver`. + #[derive(Drop, PartialEq, starknet::Event)] + pub struct Withdraw { + #[key] + pub sender: ContractAddress, + #[key] + pub receiver: ContractAddress, + #[key] + pub owner: ContractAddress, + pub assets: u256, + pub shares: u256, + } + + pub mod Errors { + pub const EXCEEDED_MAX_DEPOSIT: felt252 = 'ERC4626: exceeds max deposit'; + pub const EXCEEDED_MAX_MINT: felt252 = 'ERC4626: exceeds max mint'; + pub const EXCEEDED_MAX_WITHDRAW: felt252 = 'ERC4626: exceeds max withdraw'; + pub const EXCEEDED_MAX_REDEEM: felt252 = 'ERC4626: exceeds max redeem'; + pub const TOKEN_TRANSFER_FAILED: felt252 = 'ERC4626: token transfer failed'; + pub const INVALID_ASSET_ADDRESS: felt252 = 'ERC4626: asset address set to 0'; + pub const DECIMALS_OVERFLOW: felt252 = 'ERC4626: decimals overflow'; + } + + /// Constants expected to be defined at the contract level which configure virtual + /// assets and shares. + /// + /// `UNDERLYING_DECIMALS` should match the underlying asset's decimals. The default + /// value is `18`. + /// + /// `DECIMALS_OFFSET` corresponds to the representational offset between `UNDERLYING_DECIMALS` + /// and the vault decimals. The greater the offset, the more expensive it is for attackers to + /// execute an inflation attack. + /// + /// Requirements: + /// + /// - `UNDERLYING_DECIMALS` + `DECIMALS_OFFSET` cannot exceed 255 (max u8). + pub trait ImmutableConfig { + const UNDERLYING_DECIMALS: u8; + const DECIMALS_OFFSET: u8; + + fn validate() { + assert( + Bounded::MAX - Self::UNDERLYING_DECIMALS >= Self::DECIMALS_OFFSET, + Errors::DECIMALS_OVERFLOW, + ) + } + } + + /// Adjustments for fees expected to be defined at the contract level. + /// Defaults to no entry or exit fees. + /// + /// NOTE: The FeeConfigTrait hooks directly into the preview methods of the ERC4626 component. + /// The preview methods must return as close to the exact amount of shares or assets as possible + /// if the actual (previewed) operation occurred in the same transaction (according to IERC4626 + /// spec). + /// All operations use their corresponding preview method as the value of assets or shares being + /// moved. + /// Therefore, adjusting an operation's assets in FeeConfigTrait consequently adjusts the assets + /// (or assets to be converted into shares) in both the preview operation and the actual + /// operation. + /// + /// NOTE: To transfer fees, this trait needs to be coordinated with + /// `ERC4626Component::ERC4626Hooks`. + /// See the ERC4626FeesMock example: + /// https://github.com/OpenZeppelin/cairo-contracts/tree/main/packages/test_common/src/mocks/erc4626.cairo + pub trait FeeConfigTrait> { + /// Adjusts deposits within `preview_deposit` to account for entry fees. + /// Entry fees should be transferred in the `after_deposit` hook. + fn adjust_deposit(self: @ComponentState, assets: u256) -> u256 { + assets + } + + /// Adjusts mints within `preview_mint` to account for entry fees. + /// Entry fees should be transferred in the `after_deposit` hook. + fn adjust_mint(self: @ComponentState, assets: u256) -> u256 { + assets + } + + /// Adjusts withdraws within `preview_withdraw` to account for exit fees. + /// Exit fees should be transferred in the `before_withdraw` hook. + fn adjust_withdraw(self: @ComponentState, assets: u256) -> u256 { + assets + } + + /// Adjusts redeems within `preview_redeem` to account for exit fees. + /// Exit fees should be transferred in the `before_withdraw` hook. + fn adjust_redeem(self: @ComponentState, assets: u256) -> u256 { + assets + } + } + + /// Sets limits to the target exchange type and is expected to be defined at the contract + /// level. + pub trait LimitConfigTrait> { + /// The max deposit allowed. + /// Defaults (`Option::None`) to 2 ** 256 - 1. + fn deposit_limit( + self: @ComponentState, receiver: ContractAddress, + ) -> Option { + Option::None + } + + /// The max deposit allowed. + /// Defaults (`Option::None`) to 2 ** 256 - 1. + fn mint_limit( + self: @ComponentState, receiver: ContractAddress, + ) -> Option { + Option::None + } + + /// The max withdraw allowed. + /// Defaults (`Option::None`) to the full asset balance of `owner` converted from shares. + fn withdraw_limit( + self: @ComponentState, owner: ContractAddress, + ) -> Option { + Option::None + } + + /// The max deposit allowed. + /// Defaults (`Option::None`) to the full asset balance of `owner`. + fn redeem_limit( + self: @ComponentState, owner: ContractAddress, + ) -> Option { + Option::None + } + } + + /// Allows contracts to hook logic into deposit and withdraw transactions. + /// This is where contracts can transfer fees. + /// + /// NOTE: ERC4626 preview methods must be inclusive of any entry or exit fees. + /// The `AdjustFeesTrait` will adjust these values accordingly; therefore, + /// fees must be set in the `AdjustFeesTrait` if the using contract enforces + /// entry or exit fees. + /// + /// See the example: + /// https://github.com/OpenZeppelin/cairo-contracts/tree/main/packages/test_common/src/mocks/erc4626.cairo + pub trait ERC4626HooksTrait> { + /// Hooks into `InternalImpl::_withdraw`. + /// Executes logic before burning shares and transferring assets. + fn before_withdraw(ref self: ComponentState, assets: u256, shares: u256) {} + /// Hooks into `InternalImpl::_deposit`. + /// Executes logic after transferring assets and minting shares. + fn after_deposit(ref self: ComponentState, assets: u256, shares: u256) {} + } + + // + // External + // + + #[embeddable_as(ERC4626Impl)] + impl ERC4626< + TContractState, + +HasComponent, + impl Fee: FeeConfigTrait, + impl Limit: LimitConfigTrait, + impl Hooks: ERC4626HooksTrait, + impl Immutable: ImmutableConfig, + impl ERC20: ERC20Component::HasComponent, + +ERC20Component::ERC20HooksTrait, + +Drop, + > of IERC4626> { + /// Returns the address of the underlying token used for the Vault for accounting, + /// depositing, and withdrawing. + fn asset(self: @ComponentState) -> ContractAddress { + self.ERC4626_asset.read() + } + + /// Returns the total amount of the underlying asset that is “managed” by Vault. + fn total_assets(self: @ComponentState) -> u256 { + let this = starknet::get_contract_address(); + let asset_dispatcher = IERC20Dispatcher { contract_address: self.ERC4626_asset.read() }; + asset_dispatcher.balance_of(this) + } + + /// Returns the amount of shares that the Vault would exchange for the amount of assets + /// provided irrespective of slippage or fees. + fn convert_to_shares(self: @ComponentState, assets: u256) -> u256 { + self._convert_to_shares(assets, Rounding::Floor) + } + + /// Returns the amount of assets that the Vault would exchange for the amount of shares + /// provided irrespective of slippage or fees. + fn convert_to_assets(self: @ComponentState, shares: u256) -> u256 { + self._convert_to_assets(shares, Rounding::Floor) + } + + /// Returns the maximum amount of the underlying asset that can be deposited into the Vault + /// for the receiver, through a deposit call. + /// + /// The default max deposit value is 2 ** 256 - 1. + /// This can be changed in the implementing contract by defining custom logic in + /// `LimitConfigTrait::deposit_limit`. + fn max_deposit(self: @ComponentState, receiver: ContractAddress) -> u256 { + match Limit::deposit_limit(self, receiver) { + Option::Some(limit) => limit, + Option::None => Bounded::MAX, + } + } + + /// Allows an on-chain or off-chain user to simulate the effects of their deposit at the + /// current block, given current on-chain conditions. + /// + /// The default deposit preview value is the full amount of shares. + /// This can be changed to account for fees, for example, in the implementing contract by + /// defining custom logic in `FeeConfigTrait::adjust_deposit`. + /// + /// NOTE: `preview_deposit` must be inclusive of entry fees to be compliant with the + /// IERC4626 spec. + fn preview_deposit(self: @ComponentState, assets: u256) -> u256 { + let adjusted_assets = Fee::adjust_deposit(self, assets); + self._convert_to_shares(adjusted_assets, Rounding::Floor) + } + + /// Mints Vault shares to `receiver` by depositing exactly `assets` of underlying tokens. + /// Returns the amount of newly-minted shares. + /// + /// Requirements: + /// + /// - `assets` is less than or equal to the max deposit amount for `receiver`. + /// + /// Emits a `Deposit` event. + fn deposit( + ref self: ComponentState, assets: u256, receiver: ContractAddress, + ) -> u256 { + let max_assets = self.max_deposit(receiver); + assert(assets <= max_assets, Errors::EXCEEDED_MAX_DEPOSIT); + + let shares = self.preview_deposit(assets); + let caller = starknet::get_caller_address(); + self._deposit(caller, receiver, assets, shares); + + shares + } + + /// Returns the maximum amount of the Vault shares that can be minted for `receiver` through + /// a `mint` call. + /// + /// The default max mint value is 2 ** 256 - 1. + /// This can be changed in the implementing contract by defining custom logic in + /// `LimitConfigTrait::mint_limit`. + fn max_mint(self: @ComponentState, receiver: ContractAddress) -> u256 { + match Limit::mint_limit(self, receiver) { + Option::Some(limit) => limit, + Option::None => Bounded::MAX, + } + } + + /// Allows an on-chain or off-chain user to simulate the effects of their mint at the + /// current block, given current on-chain conditions. + /// + /// The default mint preview value is the full amount of assets. + /// This can be changed to account for fees, for example, in the implementing contract by + /// defining custom logic in `FeeConfigTrait::adjust_mint`. + /// + /// NOTE: `preview_mint` must be inclusive of entry fees to be compliant with the IERC4626 + /// spec. + fn preview_mint(self: @ComponentState, shares: u256) -> u256 { + let full_assets = self._convert_to_assets(shares, Rounding::Ceil); + Fee::adjust_mint(self, full_assets) + } + + /// Mints exactly Vault `shares` to `receiver` by depositing amount of underlying tokens. + /// Returns the amount deposited assets. + /// + /// Requirements: + /// + /// - `shares` is less than or equal to the max shares amount for `receiver`. + /// + /// Emits a `Deposit` event. + fn mint( + ref self: ComponentState, shares: u256, receiver: ContractAddress, + ) -> u256 { + let max_shares = self.max_mint(receiver); + assert(shares <= max_shares, Errors::EXCEEDED_MAX_MINT); + + let assets = self.preview_mint(shares); + let caller = starknet::get_caller_address(); + self._deposit(caller, receiver, assets, shares); + + assets + } + + /// Returns the maximum amount of the underlying asset that can be withdrawn from the owner + /// balance in the Vault, through a `withdraw` call. + /// + /// The default max withdraw value is the full balance of assets for `owner` (converted from + /// shares). + /// This can be changed in the implementing contract by defining custom logic in + /// `LimitConfigTrait::withdraw_limit`. + fn max_withdraw(self: @ComponentState, owner: ContractAddress) -> u256 { + match Limit::withdraw_limit(self, owner) { + Option::Some(limit) => limit, + Option::None => { + let erc20_component = get_dep_component!(self, ERC20); + let owner_shares = erc20_component.balance_of(owner); + self._convert_to_assets(owner_shares, Rounding::Floor) + }, + } + } + + /// Allows an on-chain or off-chain user to simulate the effects of their withdrawal at the + /// current block, given current on-chain conditions. + /// + /// The default withdraw preview value is the full amount of shares. + /// This can be changed to account for fees, for example, in the implementing contract by + /// defining custom logic in `FeeConfigTrait::adjust_withdraw`. + /// + /// NOTE: `preview_withdraw` must be inclusive of exit fees to be compliant with the + /// IERC4626 spec. + fn preview_withdraw(self: @ComponentState, assets: u256) -> u256 { + let adjusted_assets = Fee::adjust_withdraw(self, assets); + self._convert_to_shares(adjusted_assets, Rounding::Ceil) + } + + /// Burns shares from `owner` and sends exactly `assets` of underlying tokens to `receiver`. + /// + /// Requirements: + /// + /// - `assets` is less than or equal to the max withdraw amount of `owner`. + /// + /// Emits a `Withdraw` event. + fn withdraw( + ref self: ComponentState, + assets: u256, + receiver: ContractAddress, + owner: ContractAddress, + ) -> u256 { + let max_assets = self.max_withdraw(owner); + assert(assets <= max_assets, Errors::EXCEEDED_MAX_WITHDRAW); + + let shares = self.preview_withdraw(assets); + let caller = starknet::get_caller_address(); + self._withdraw(caller, receiver, owner, assets, shares); + + shares + } + + /// Returns the maximum amount of Vault shares that can be redeemed from the owner balance + /// in the Vault, through a `redeem` call. + /// + /// The default max redeem value is the full balance of assets for `owner`. + /// This can be changed in the implementing contract by defining custom logic in + /// `LimitConfigTrait::redeem_limit`. + fn max_redeem(self: @ComponentState, owner: ContractAddress) -> u256 { + match Limit::redeem_limit(self, owner) { + Option::Some(limit) => limit, + Option::None => { + let erc20_component = get_dep_component!(self, ERC20); + erc20_component.balance_of(owner) + }, + } + } + + /// Allows an on-chain or off-chain user to simulate the effects of their redeemption at the + /// current block, given current on-chain conditions. + /// + /// The default redeem preview value is the full amount of assets. + /// This can be changed to account for fees, for example, in the implementing contract by + /// defining custom logic in `FeeConfigTrait::adjust_redeem`. + /// + /// NOTE: `preview_redeem` must be inclusive of exit fees to be compliant with the IERC4626 + /// spec. + fn preview_redeem(self: @ComponentState, shares: u256) -> u256 { + let full_assets = self._convert_to_assets(shares, Rounding::Floor); + Fee::adjust_redeem(self, full_assets) + } + + /// Burns exactly `shares` from `owner` and sends assets of underlying tokens to `receiver`. + /// + /// Requirements: + /// + /// - `shares` is less than or equal to the max redeem amount of `owner`. + /// + /// Emits a `Withdraw` event. + fn redeem( + ref self: ComponentState, + shares: u256, + receiver: ContractAddress, + owner: ContractAddress, + ) -> u256 { + let max_shares = self.max_redeem(owner); + assert(shares <= max_shares, Errors::EXCEEDED_MAX_REDEEM); + + let assets = self.preview_redeem(shares); + let caller = starknet::get_caller_address(); + self._withdraw(caller, receiver, owner, assets, shares); + + assets + } + } + + #[embeddable_as(ERC4626MetadataImpl)] + impl ERC4626Metadata< + TContractState, + +HasComponent, + impl Immutable: ImmutableConfig, + impl ERC20: ERC20Component::HasComponent, + > of IERC20Metadata> { + /// Returns the name of the token. + fn name(self: @ComponentState) -> ByteArray { + let erc20_component = get_dep_component!(self, ERC20); + erc20_component.ERC20_name.read() + } + + /// Returns the ticker symbol of the token, usually a shorter version of the name. + fn symbol(self: @ComponentState) -> ByteArray { + let erc20_component = get_dep_component!(self, ERC20); + erc20_component.ERC20_symbol.read() + } + + /// Returns the cumulative number of decimals which includes both the underlying and offset + /// decimals. + /// Both of which must be defined in the `ImmutableConfig` inside the implementing contract. + fn decimals(self: @ComponentState) -> u8 { + Immutable::UNDERLYING_DECIMALS + Immutable::DECIMALS_OFFSET + } + } + + // + // Internal + // + + #[generate_trait] + pub impl InternalImpl< + TContractState, + +HasComponent, + impl Hooks: ERC4626HooksTrait, + impl Immutable: ImmutableConfig, + impl ERC20: ERC20Component::HasComponent, + +FeeConfigTrait, + +LimitConfigTrait, + +ERC20Component::ERC20HooksTrait, + +Drop, + > of InternalTrait { + /// Validates the `ImmutableConfig` constants and sets the `asset_address` to the vault. + /// This should be set in the contract's constructor. + /// + /// Requirements: + /// + /// - `asset_address` cannot be the zero address. + fn initializer(ref self: ComponentState, asset_address: ContractAddress) { + Immutable::validate(); + assert(asset_address.is_non_zero(), Errors::INVALID_ASSET_ADDRESS); + self.ERC4626_asset.write(asset_address); + } + + /// Internal logic for `deposit` and `mint`. + /// Transfers `assets` from `caller` to the Vault contract then mints `shares` to + /// `receiver`. + /// Fees can be transferred in the `ERC4626Hooks::after_deposit` hook which is executed + /// after assets are transferred and shares are minted. + /// + /// Requirements: + /// + /// - `ERC20::transfer_from` must return true. + /// + /// Emits two `ERC20::Transfer` events (`ERC20::mint` and `ERC20::transfer_from`). + /// Emits a `Deposit` event. + fn _deposit( + ref self: ComponentState, + caller: ContractAddress, + receiver: ContractAddress, + assets: u256, + shares: u256, + ) { + // Transfer assets first + let this = starknet::get_contract_address(); + let asset_dispatcher = IERC20Dispatcher { contract_address: self.ERC4626_asset.read() }; + assert( + asset_dispatcher.transfer_from(caller, this, assets), Errors::TOKEN_TRANSFER_FAILED, + ); + + // Mint shares after transferring assets + let mut erc20_component = get_dep_component_mut!(ref self, ERC20); + erc20_component.mint(receiver, shares); + self.emit(Deposit { sender: caller, owner: receiver, assets, shares }); + + // After deposit hook + Hooks::after_deposit(ref self, assets, shares); + } + + /// Internal logic for `withdraw` and `redeem`. + /// Burns `shares` from `owner` and then transfers `assets` to `receiver`. + /// Fees can be transferred in the `ERC4626Hooks::before_withdraw` hook which is executed + /// before shares are burned and assets are transferred. + /// + /// Requirements: + /// + /// - `ERC20::transfer` must return true. + /// + /// Emits two `ERC20::Transfer` events (`ERC20::burn` and `ERC20::transfer`). + /// + /// Emits a `Withdraw` event. + fn _withdraw( + ref self: ComponentState, + caller: ContractAddress, + receiver: ContractAddress, + owner: ContractAddress, + assets: u256, + shares: u256, + ) { + // Before withdraw hook + Hooks::before_withdraw(ref self, assets, shares); + + // Burn shares first + let mut erc20_component = get_dep_component_mut!(ref self, ERC20); + if caller != owner { + erc20_component._spend_allowance(owner, caller, shares); + } + erc20_component.burn(owner, shares); + + // Transfer assets after burn + let asset_dispatcher = IERC20Dispatcher { contract_address: self.ERC4626_asset.read() }; + assert(asset_dispatcher.transfer(receiver, assets), Errors::TOKEN_TRANSFER_FAILED); + + self.emit(Withdraw { sender: caller, receiver, owner, assets, shares }); + } + + /// Internal conversion function (from assets to shares) with support for `rounding` + /// direction. + fn _convert_to_shares( + self: @ComponentState, assets: u256, rounding: Rounding, + ) -> u256 { + let erc20_component = get_dep_component!(self, ERC20); + let total_supply = erc20_component.total_supply(); + + math::u256_mul_div( + assets, + total_supply + 10_u256.pow(Immutable::DECIMALS_OFFSET.into()), + self.total_assets() + 1, + rounding, + ) + } + + /// Internal conversion function (from shares to assets) with support for `rounding` + /// direction. + fn _convert_to_assets( + self: @ComponentState, shares: u256, rounding: Rounding, + ) -> u256 { + let erc20_component = get_dep_component!(self, ERC20); + let total_supply = erc20_component.total_supply(); + + math::u256_mul_div( + shares, + self.total_assets() + 1, + total_supply + 10_u256.pow(Immutable::DECIMALS_OFFSET.into()), + rounding, + ) + } + } +} + +// +// Default (empty) traits +// + +pub impl ERC4626HooksEmptyImpl< + TContractState, +ERC4626Component::HasComponent, +> of ERC4626Component::ERC4626HooksTrait {} +pub impl ERC4626DefaultNoFees< + TContractState, +ERC4626Component::HasComponent, +> of ERC4626Component::FeeConfigTrait {} +pub impl ERC4626DefaultLimits< + TContractState, +ERC4626Component::HasComponent, +> of ERC4626Component::LimitConfigTrait {} + +/// Implementation of the default `ERC4626Component::ImmutableConfig`. +/// +/// See +/// https://github.com/starknet-io/SNIPs/blob/963848f0752bde75c7087c2446d83b7da8118b25/SNIPS/snip-107.md#defaultconfig-implementation +/// +/// The default `UNDERLYING_DECIMALS` is set to `18`. +/// The default `DECIMALS_OFFSET` is set to `0`. +pub impl DefaultConfig of ERC4626Component::ImmutableConfig { + const UNDERLYING_DECIMALS: u8 = ERC4626Component::DEFAULT_UNDERLYING_DECIMALS; + const DECIMALS_OFFSET: u8 = ERC4626Component::DEFAULT_DECIMALS_OFFSET; +} + +#[cfg(test)] +mod Test { + use openzeppelin_test_common::mocks::erc4626::ERC4626Mock; + use super::ERC4626Component::InternalImpl; + use super::{ERC4626Component, ERC4626DefaultLimits, ERC4626DefaultNoFees}; + + type ComponentState = ERC4626Component::ComponentState; + + fn COMPONENT_STATE() -> ComponentState { + ERC4626Component::component_state_for_testing() + } + + // Invalid decimals + impl InvalidImmutableConfig of ERC4626Component::ImmutableConfig { + const UNDERLYING_DECIMALS: u8 = 255; + const DECIMALS_OFFSET: u8 = 1; + } + + #[test] + #[should_panic(expected: 'ERC4626: decimals overflow')] + fn test_initializer_invalid_config_panics() { + let mut state = COMPONENT_STATE(); + let asset = starknet::contract_address_const::<'ASSET'>(); + + state.initializer(asset); + } +} diff --git a/packages/token/src/erc20/extensions/erc4626/interface.cairo b/packages/token/src/erc20/extensions/erc4626/interface.cairo new file mode 100644 index 000000000..286977769 --- /dev/null +++ b/packages/token/src/erc20/extensions/erc4626/interface.cairo @@ -0,0 +1,254 @@ +// SPDX-License-Identifier: MIT +// OpenZeppelin Contracts for Cairo v0.20.0-rc.0 (token/erc20/extensions/erc4626/interface.cairo) + +use starknet::ContractAddress; + +#[starknet::interface] +pub trait IERC4626 { + /// Returns the address of the underlying token used for the Vault for accounting, depositing, + /// and withdrawing. + /// + /// MUST be an ERC20 token contract. + /// MUST NOT panic. + fn asset(self: @TState) -> ContractAddress; + + /// Returns the total amount of the underlying asset that is “managed” by Vault. + /// + /// SHOULD include any compounding that occurs from yield. + /// MUST be inclusive of any fees that are charged against assets in the Vault. + /// MUST NOT panic. + fn total_assets(self: @TState) -> u256; + + /// Returns the amount of shares that the Vault would exchange for the amount of assets + /// provided irrespective of slippage or fees. + /// + /// MUST NOT be inclusive of any fees that are charged against assets in the Vault. + /// MUST NOT show any variations depending on the caller. + /// MUST NOT reflect slippage or other on-chain conditions, when performing the actual exchange. + /// MUST NOT panic. + /// + /// NOTE: This calculation MAY NOT reflect the "per-user" price-per-share, and instead should + /// reflect the "average-user's" price-per-share, meaning what the average user should expect to + /// see when exchanging to and from. + fn convert_to_shares(self: @TState, assets: u256) -> u256; + + /// Returns the amount of assets that the Vault would exchange for the amount of shares + /// provided irrespective of slippage or fees. + /// + /// MUST NOT be inclusive of any fees that are charged against assets in the Vault. + /// MUST NOT show any variations depending on the caller. + /// MUST NOT reflect slippage or other on-chain conditions, when performing the actual exchange. + /// MUST NOT panic. + /// + /// NOTE: This calculation MAY NOT reflect the “per-user” price-per-share, and instead + /// should reflect the “average-user’s” price-per-share, meaning what the average user + /// should expect to see when exchanging to and from. + fn convert_to_assets(self: @TState, shares: u256) -> u256; + + /// Returns the maximum amount of the underlying asset that can be deposited into the Vault for + /// `receiver`, through a deposit call. + /// + /// MUST return a limited value if receiver is subject to some deposit limit. + /// MUST return 2 ** 256 - 1 if there is no limit on the maximum amount of assets that may be + /// deposited. + /// MUST NOT panic. + fn max_deposit(self: @TState, receiver: ContractAddress) -> u256; + + /// Allows an on-chain or off-chain user to simulate the effects of their deposit at the current + /// block, given current on-chain conditions. + /// + /// MUST return as close to and no more than the exact amount of Vault shares that would be + /// minted in a deposit call in the same transaction i.e. deposit should return the same or more + /// shares as `preview_deposit` if called in the same transaction. + /// MUST NOT account for deposit limits like those returned from `max_deposit` and should always + /// act as though the deposit would be accepted, regardless if the user has enough tokens + /// approved, etc. + /// MUST be inclusive of deposit fees. Integrators should be aware of the existence of deposit + /// fees. + /// MUST NOT panic. + /// + /// NOTE: Any unfavorable discrepancy between `convert_to_shares` and `preview_deposit` + /// SHOULD be considered slippage in share price or some other type of condition, meaning the + /// depositor will lose assets by depositing. + fn preview_deposit(self: @TState, assets: u256) -> u256; + + /// Mints Vault shares to `receiver` by depositing exactly amount of `assets`. + /// + /// MUST emit the Deposit event. + /// MAY support an additional flow in which the underlying tokens are owned by the Vault + /// contract before the deposit execution, and are accounted for during deposit. + /// MUST panic if all of assets cannot be deposited (due to deposit limit being reached, + /// slippage, the user not approving enough underlying tokens to the Vault contract, etc). + /// + /// NOTE: Most implementations will require pre-approval of the Vault with the Vault’s + /// underlying asset token. + fn deposit(ref self: TState, assets: u256, receiver: ContractAddress) -> u256; + + /// Returns the maximum amount of the Vault shares that can be minted for the receiver, through + /// a mint call. + /// + /// MUST return a limited value if receiver is subject to some mint limit. + /// MUST return 2 ** 256 - 1 if there is no limit on the maximum amount of shares that may be + /// minted. + /// MUST NOT panic. + fn max_mint(self: @TState, receiver: ContractAddress) -> u256; + + /// Allows an on-chain or off-chain user to simulate the effects of their mint at the current + /// block, given current on-chain conditions. + /// + /// MUST return as close to and no fewer than the exact amount of assets that would be deposited + /// in a `mint` call in the same transaction. I.e. `mint` should return the same or fewer assets + /// as `preview_mint` if called in the same transaction. + /// MUST NOT account for mint limits like those returned from `max_mint` and should always act + /// as though the mint would be accepted, regardless if the user has enough tokens approved, + /// etc. + /// MUST be inclusive of deposit fees. Integrators should be aware of the existence of deposit + /// fees. + /// MUST NOT panic. + /// + /// NOTE: Any unfavorable discrepancy between convertToAssets and previewMint SHOULD be + /// considered slippage in share price or some other type of condition, meaning the depositor + /// will lose assets by minting. + fn preview_mint(self: @TState, shares: u256) -> u256; + + /// Mints exactly shares Vault shares to receiver by depositing amount of underlying tokens. + /// + /// MUST emit the `Deposit` event. + /// MAY support an additional flow in which the underlying tokens are owned by the Vault + /// contract before the mint execution, and are accounted for during mint. + /// MUST panic if all of shares cannot be minted (due to deposit limit being reached, slippage, + /// the user not approving enough underlying tokens to the Vault contract, etc). + /// + /// NOTE: Most implementations will require pre-approval of the Vault with the Vault’s + /// underlying asset token. + fn mint(ref self: TState, shares: u256, receiver: ContractAddress) -> u256; + + /// Returns the maximum amount of the underlying asset that can be withdrawn from the owner + /// balance in the Vault, through a withdraw call. + /// + /// MUST return a limited value if owner is subject to some withdrawal limit or timelock. + /// MUST NOT panic. + fn max_withdraw(self: @TState, owner: ContractAddress) -> u256; + + /// Allows an on-chain or off-chain user to simulate the effects of their withdrawal at the + /// current block, given current on-chain conditions. + /// + /// MUST return as close to and no fewer than the exact amount of Vault shares that would be + /// burned in a withdraw call in the same transaction i.e. withdraw should return the same or + /// fewer shares as preview_withdraw if called in the same transaction. + /// MUST NOT account for withdrawal limits like those returned from max_withdraw and should + /// always act as though the withdrawal would be accepted, regardless if the user has enough + /// shares, etc. + /// MUST be inclusive of withdrawal fees. Integrators should be aware of the existence of + /// withdrawal fees. + /// MUST not panic. + /// + /// NOTE: Any unfavorable discrepancy between `convert_to_shares` and `preview_withdraw` + /// SHOULD be considered slippage in share price or some other type of condition, meaning the + /// depositor will lose assets by depositing. + fn preview_withdraw(self: @TState, assets: u256) -> u256; + + /// Burns shares from owner and sends exactly assets of underlying tokens to receiver. + /// + /// MUST emit the `Withdraw` event. + /// MAY support an additional flow in which the underlying tokens are owned by the Vault + /// contract before the withdraw execution, and are accounted for during withdraw. + /// MUST revert if all of assets cannot be withdrawn (due to withdrawal limit being reached, + /// slippage, the owner not having enough shares, etc). + /// + /// NOTE: Some implementations will require pre-requesting to the Vault before a withdrawal + /// may be performed. + /// Those methods should be performed separately. + fn withdraw( + ref self: TState, assets: u256, receiver: ContractAddress, owner: ContractAddress, + ) -> u256; + + /// Returns the maximum amount of Vault shares that can be redeemed from the owner balance in + /// the Vault, through a redeem call. + /// + /// MUST return a limited value if owner is subject to some withdrawal limit or timelock. + /// MUST return `ERC20::balance_of(owner)` if `owner` is not subject to any withdrawal limit or + /// timelock. + /// MUST NOT panic. + fn max_redeem(self: @TState, owner: ContractAddress) -> u256; + + /// Allows an on-chain or off-chain user to simulate the effects of their redeemption at the + /// current block, given current on-chain conditions. + /// + /// MUST return as close to and no more than the exact amount of assets that would be withdrawn + /// in a redeem call in the same transaction i.e. redeem should return the same or more assets + /// as preview_redeem if called in the same transaction. + /// MUST NOT account for redemption limits like those returned from max_redeem and should always + /// act as though the redemption would be accepted, regardless if the user has enough shares, + /// etc. + /// MUST be inclusive of withdrawal fees. Integrators should be aware of the existence of + /// withdrawal fees. + /// MUST NOT panic. + /// + /// NOTE: Any unfavorable discrepancy between `convert_to_assets` and `preview_redeem` SHOULD be + /// considered slippage in share price or some other type of condition, meaning the depositor + /// will lose assets by redeeming. + fn preview_redeem(self: @TState, shares: u256) -> u256; + + /// Burns exactly shares from owner and sends assets of underlying tokens to receiver. + /// + /// MUST emit the `Withdraw` event. + /// MAY support an additional flow in which the underlying tokens are owned by the Vault + /// contract before the redeem execution, and are accounted for during redeem. + /// MUST revert if all of shares cannot be redeemed (due to withdrawal limit being reached, + /// slippage, the owner not having enough shares, etc). + /// + /// NOTE: Some implementations will require pre-requesting to the Vault before a withdrawal may + /// be performed. + /// Those methods should be performed separately. + fn redeem( + ref self: TState, shares: u256, receiver: ContractAddress, owner: ContractAddress, + ) -> u256; +} + +#[starknet::interface] +pub trait ERC4626ABI { + // IERC4626 + fn asset(self: @TState) -> ContractAddress; + fn total_assets(self: @TState) -> u256; + fn convert_to_shares(self: @TState, assets: u256) -> u256; + fn convert_to_assets(self: @TState, shares: u256) -> u256; + fn max_deposit(self: @TState, receiver: ContractAddress) -> u256; + fn preview_deposit(self: @TState, assets: u256) -> u256; + fn deposit(ref self: TState, assets: u256, receiver: ContractAddress) -> u256; + fn max_mint(self: @TState, receiver: ContractAddress) -> u256; + fn preview_mint(self: @TState, shares: u256) -> u256; + fn mint(ref self: TState, shares: u256, receiver: ContractAddress) -> u256; + fn max_withdraw(self: @TState, owner: ContractAddress) -> u256; + fn preview_withdraw(self: @TState, assets: u256) -> u256; + fn withdraw( + ref self: TState, assets: u256, receiver: ContractAddress, owner: ContractAddress, + ) -> u256; + fn max_redeem(self: @TState, owner: ContractAddress) -> u256; + fn preview_redeem(self: @TState, shares: u256) -> u256; + fn redeem( + ref self: TState, shares: u256, receiver: ContractAddress, owner: ContractAddress, + ) -> u256; + + // IERC20 + fn total_supply(self: @TState) -> u256; + fn balance_of(self: @TState, account: ContractAddress) -> u256; + fn allowance(self: @TState, owner: ContractAddress, spender: ContractAddress) -> u256; + fn transfer(ref self: TState, recipient: ContractAddress, amount: u256) -> bool; + fn transfer_from( + ref self: TState, sender: ContractAddress, recipient: ContractAddress, amount: u256, + ) -> bool; + fn approve(ref self: TState, spender: ContractAddress, amount: u256) -> bool; + + // IERC20Metadata + fn name(self: @TState) -> ByteArray; + fn symbol(self: @TState) -> ByteArray; + fn decimals(self: @TState) -> u8; + + // IERC20CamelOnly + fn totalSupply(self: @TState) -> u256; + fn balanceOf(self: @TState, account: ContractAddress) -> u256; + fn transferFrom( + ref self: TState, sender: ContractAddress, recipient: ContractAddress, amount: u256, + ) -> bool; +} diff --git a/packages/token/src/tests.cairo b/packages/token/src/tests.cairo index fdb72f5ef..a9654141f 100644 --- a/packages/token/src/tests.cairo +++ b/packages/token/src/tests.cairo @@ -1,4 +1,5 @@ pub mod erc1155; pub mod erc20; pub mod erc2981; +pub mod erc4626; pub mod erc721; diff --git a/packages/token/src/tests/erc4626.cairo b/packages/token/src/tests/erc4626.cairo new file mode 100644 index 000000000..d7bf4e38b --- /dev/null +++ b/packages/token/src/tests/erc4626.cairo @@ -0,0 +1 @@ +mod test_erc4626; diff --git a/packages/token/src/tests/erc4626/test_erc4626.cairo b/packages/token/src/tests/erc4626/test_erc4626.cairo new file mode 100644 index 000000000..ab39fdbaa --- /dev/null +++ b/packages/token/src/tests/erc4626/test_erc4626.cairo @@ -0,0 +1,1620 @@ +use core::num::traits::{Bounded, Pow}; +use crate::erc20::ERC20Component::InternalImpl as ERC20InternalImpl; +use crate::erc20::extensions::erc4626::DefaultConfig; +use crate::erc20::extensions::erc4626::ERC4626Component; +use crate::erc20::extensions::erc4626::ERC4626Component::{Deposit, Withdraw}; +use crate::erc20::extensions::erc4626::ERC4626Component::{ + ERC4626Impl, ERC4626MetadataImpl, InternalImpl, +}; +use crate::erc20::extensions::erc4626::interface::{ERC4626ABIDispatcher, ERC4626ABIDispatcherTrait}; +use openzeppelin_test_common::erc20::ERC20SpyHelpers; +use openzeppelin_test_common::mocks::erc20::Type; +use openzeppelin_test_common::mocks::erc20::{ + IERC20ReentrantDispatcher, IERC20ReentrantDispatcherTrait, +}; +use openzeppelin_test_common::mocks::erc4626::ERC4626Mock; +use openzeppelin_testing as utils; +use openzeppelin_testing::constants::{NAME, OTHER, RECIPIENT, SPENDER, SYMBOL, ZERO}; +use openzeppelin_testing::events::EventSpyExt; +use openzeppelin_utils::serde::SerializedAppend; +use snforge_std::{CheatSpan, EventSpy, cheat_caller_address, spy_events}; +use starknet::{ContractAddress, contract_address_const}; + +fn ASSET() -> ContractAddress { + contract_address_const::<'ASSET'>() +} + +fn HOLDER() -> ContractAddress { + contract_address_const::<'HOLDER'>() +} + +fn TREASURY() -> ContractAddress { + contract_address_const::<'TREASURY'>() +} + +fn VAULT_NAME() -> ByteArray { + "VAULT" +} + +fn VAULT_SYMBOL() -> ByteArray { + "V" +} + +const DEFAULT_DECIMALS: u8 = 18; +const NO_OFFSET_DECIMALS: u8 = 0; +const OFFSET_DECIMALS: u8 = 1; + +fn parse_token(token: u256) -> u256 { + token * 10_u256.pow(DEFAULT_DECIMALS.into()) +} + +fn parse_share_offset(shares: u256) -> u256 { + shares * 10_u256.pow(DEFAULT_DECIMALS.into() + OFFSET_DECIMALS.into()) +} + +// +// Setup +// + +type ComponentState = ERC4626Component::ComponentState; + +fn COMPONENT_STATE() -> ComponentState { + ERC4626Component::component_state_for_testing() +} + +// +// Dispatchers +// + +fn deploy_asset() -> IERC20ReentrantDispatcher { + let mut asset_calldata: Array = array![]; + asset_calldata.append_serde(NAME()); + asset_calldata.append_serde(SYMBOL()); + + let contract_address = utils::declare_and_deploy("ERC20ReentrantMock", asset_calldata); + IERC20ReentrantDispatcher { contract_address } +} + +fn deploy_vault(asset_address: ContractAddress) -> ERC4626ABIDispatcher { + let no_shares = 0_u256; + + let mut vault_calldata: Array = array![]; + vault_calldata.append_serde(VAULT_NAME()); + vault_calldata.append_serde(VAULT_SYMBOL()); + vault_calldata.append_serde(asset_address); + vault_calldata.append_serde(no_shares); + vault_calldata.append_serde(HOLDER()); + + let contract_address = utils::declare_and_deploy("ERC4626Mock", vault_calldata); + ERC4626ABIDispatcher { contract_address } +} + +fn deploy_vault_offset_minted_shares( + asset_address: ContractAddress, shares: u256, recipient: ContractAddress, +) -> ERC4626ABIDispatcher { + let mut vault_calldata: Array = array![]; + vault_calldata.append_serde(VAULT_NAME()); + vault_calldata.append_serde(VAULT_SYMBOL()); + vault_calldata.append_serde(asset_address); + vault_calldata.append_serde(shares); + vault_calldata.append_serde(recipient); + + let contract_address = utils::declare_and_deploy("ERC4626OffsetMock", vault_calldata); + ERC4626ABIDispatcher { contract_address } +} + +fn deploy_vault_offset(asset_address: ContractAddress) -> ERC4626ABIDispatcher { + deploy_vault_offset_minted_shares(asset_address, 0, HOLDER()) +} + +fn deploy_vault_fees(asset_address: ContractAddress) -> ERC4626ABIDispatcher { + let no_shares = 0_u256; + deploy_vault_fees_with_shares(asset_address, no_shares, HOLDER()) +} + +fn deploy_vault_fees_with_shares( + asset_address: ContractAddress, shares: u256, recipient: ContractAddress, +) -> ERC4626ABIDispatcher { + let fee_basis_points = 500_u256; // 5% + + let mut vault_calldata: Array = array![]; + vault_calldata.append_serde(VAULT_NAME()); + vault_calldata.append_serde(VAULT_SYMBOL()); + vault_calldata.append_serde(asset_address); + vault_calldata.append_serde(shares); + vault_calldata.append_serde(recipient); + + // Enter fees + vault_calldata.append_serde(fee_basis_points); + vault_calldata.append_serde(TREASURY()); + // No exit fees + vault_calldata.append_serde(0_u256); + vault_calldata.append_serde(ZERO()); + + let contract_address = utils::declare_and_deploy("ERC4626FeesMock", vault_calldata); + ERC4626ABIDispatcher { contract_address } +} + +fn deploy_vault_exit_fees_with_shares( + asset_address: ContractAddress, shares: u256, recipient: ContractAddress, +) -> ERC4626ABIDispatcher { + let fee_basis_points = 500_u256; // 5% + + let mut vault_calldata: Array = array![]; + vault_calldata.append_serde(VAULT_NAME()); + vault_calldata.append_serde(VAULT_SYMBOL()); + vault_calldata.append_serde(asset_address); + vault_calldata.append_serde(shares); + vault_calldata.append_serde(recipient); + + // No enter fees + vault_calldata.append_serde(0_u256); + vault_calldata.append_serde(ZERO()); + // Exit fees + vault_calldata.append_serde(fee_basis_points); + vault_calldata.append_serde(TREASURY()); + + let contract_address = utils::declare_and_deploy("ERC4626FeesMock", vault_calldata); + ERC4626ABIDispatcher { contract_address } +} + +fn deploy_vault_limits(asset_address: ContractAddress) -> ERC4626ABIDispatcher { + let no_shares = 0_u256; + + let mut vault_calldata: Array = array![]; + vault_calldata.append_serde(VAULT_NAME()); + vault_calldata.append_serde(VAULT_SYMBOL()); + vault_calldata.append_serde(asset_address); + vault_calldata.append_serde(no_shares); + vault_calldata.append_serde(HOLDER()); + + let contract_address = utils::declare_and_deploy("ERC4626LimitsMock", vault_calldata); + ERC4626ABIDispatcher { contract_address } +} + +// +// initializer +// + +#[test] +#[should_panic(expected: 'ERC4626: asset address set to 0')] +fn test_initializer_zero_address_asset() { + let mut state = COMPONENT_STATE(); + + state.initializer(ZERO()); +} + +// +// asset +// + +#[test] +fn test_asset() { + let mut state = COMPONENT_STATE(); + + let asset_address = state.asset(); + assert_eq!(asset_address, ZERO()); + + state.initializer(ASSET()); + + let asset_address = state.asset(); + assert_eq!(asset_address, ASSET()); +} + +// +// Metadata +// + +#[test] +fn test_metadata() { + let asset = deploy_asset(); + let vault = deploy_vault(asset.contract_address); + + let name = vault.name(); + assert_eq!(name, VAULT_NAME()); + + let symbol = vault.symbol(); + assert_eq!(symbol, VAULT_SYMBOL()); + + let decimals = vault.decimals(); + assert_eq!(decimals, DEFAULT_DECIMALS + NO_OFFSET_DECIMALS); + + let asset_address = vault.asset(); + assert_eq!(asset_address, asset.contract_address); +} + +#[test] +fn test_decimals_offset() { + let asset = deploy_asset(); + let vault = deploy_vault_offset(asset.contract_address); + + let decimals = vault.decimals(); + assert_eq!(decimals, DEFAULT_DECIMALS + OFFSET_DECIMALS); +} + +// +// Empty vault: no assets, no shares +// + +fn setup_empty() -> (IERC20ReentrantDispatcher, ERC4626ABIDispatcher) { + let mut asset = deploy_asset(); + let mut vault = deploy_vault_offset(asset.contract_address); + + // Mint assets to HOLDER and approve vault + asset.unsafe_mint(HOLDER(), Bounded::MAX / 2); // 50% of max + cheat_caller_address(asset.contract_address, HOLDER(), CheatSpan::TargetCalls(1)); + asset.approve(vault.contract_address, Bounded::MAX); + + (asset, vault) +} + +#[test] +fn test_init_vault_status() { + let (_, vault) = setup_empty(); + let total_assets = vault.total_assets(); + + assert_eq!(total_assets, 0); +} + +#[test] +fn test_deposit() { + let (asset, vault) = setup_empty(); + let amount = parse_token(1); + + // Check max deposit + let max_deposit = vault.max_deposit(HOLDER()); + assert_eq!(max_deposit, Bounded::MAX); + + // Check preview == expected shares + let preview_deposit = vault.preview_deposit(amount); + let exp_shares = parse_share_offset(1); + assert_eq!(preview_deposit, exp_shares); + + let holder_balance_before = asset.balance_of(HOLDER()); + let mut spy = spy_events(); + + // Deposit + cheat_caller_address(vault.contract_address, HOLDER(), CheatSpan::TargetCalls(1)); + let shares = vault.deposit(amount, RECIPIENT()); + + // Check balances + let holder_balance_after = asset.balance_of(HOLDER()); + assert_eq!(holder_balance_after, holder_balance_before - amount); + + let recipient_shares = vault.balance_of(RECIPIENT()); + assert_eq!(recipient_shares, exp_shares); + + // Check events + spy.assert_event_transfer(asset.contract_address, HOLDER(), vault.contract_address, amount); + spy.assert_event_transfer(vault.contract_address, ZERO(), RECIPIENT(), shares); + spy.assert_only_event_deposit(vault.contract_address, HOLDER(), RECIPIENT(), amount, shares); +} + +#[test] +fn test_mint() { + let (asset, vault) = setup_empty(); + + // Check max mint + let max_mint = vault.max_mint(HOLDER()); + assert_eq!(max_mint, Bounded::MAX); + + // Check preview mint + let preview_mint = vault.preview_mint(parse_share_offset(1)); + let exp_assets = parse_token(1); + assert_eq!(preview_mint, exp_assets); + + let mut spy = spy_events(); + let holder_balance_before = asset.balance_of(HOLDER()); + + // Mint + cheat_caller_address(vault.contract_address, HOLDER(), CheatSpan::TargetCalls(1)); + vault.mint(parse_share_offset(1), RECIPIENT()); + + // Check balances + let holder_balance_after = asset.balance_of(HOLDER()); + assert_eq!(holder_balance_after, holder_balance_before - parse_token(1)); + + let recipient_shares = vault.balance_of(RECIPIENT()); + assert_eq!(recipient_shares, parse_share_offset(1)); + + // Check events + spy + .assert_event_transfer( + asset.contract_address, HOLDER(), vault.contract_address, parse_token(1), + ); + spy.assert_event_transfer(vault.contract_address, ZERO(), RECIPIENT(), parse_share_offset(1)); + spy + .assert_only_event_deposit( + vault.contract_address, HOLDER(), RECIPIENT(), parse_token(1), parse_share_offset(1), + ); +} + +#[test] +fn test_withdraw() { + let (asset, vault) = setup_empty(); + + // Check max mint + let max_withdraw = vault.max_withdraw(HOLDER()); + assert_eq!(max_withdraw, 0); + + // Check preview mint + let preview_withdraw = vault.preview_withdraw(0); + assert_eq!(preview_withdraw, 0); + + let mut spy = spy_events(); + + // Withdraw + cheat_caller_address(vault.contract_address, HOLDER(), CheatSpan::TargetCalls(1)); + vault.withdraw(0, RECIPIENT(), HOLDER()); + + // Check events + spy.assert_event_transfer(vault.contract_address, HOLDER(), ZERO(), 0); + spy.assert_event_transfer(asset.contract_address, vault.contract_address, RECIPIENT(), 0); + spy.assert_only_event_withdraw(vault.contract_address, HOLDER(), RECIPIENT(), HOLDER(), 0, 0); +} + +#[test] +fn test_redeem() { + let (asset, vault) = setup_empty(); + + // Check max redeem + let max_redeem = vault.max_redeem(HOLDER()); + assert_eq!(max_redeem, 0); + + // Check preview redeem + let preview_redeem = vault.preview_redeem(0); + assert_eq!(preview_redeem, 0); + + let mut spy = spy_events(); + + // Redeem + cheat_caller_address(vault.contract_address, HOLDER(), CheatSpan::TargetCalls(1)); + vault.redeem(0, RECIPIENT(), HOLDER()); + + // Check events + spy.assert_event_transfer(vault.contract_address, HOLDER(), ZERO(), 0); + spy.assert_event_transfer(asset.contract_address, vault.contract_address, RECIPIENT(), 0); + spy.assert_only_event_withdraw(vault.contract_address, HOLDER(), RECIPIENT(), HOLDER(), 0, 0); +} + +// +// Inflation attack: Offset price by direct deposit of assets +// + +fn setup_inflation_attack() -> (IERC20ReentrantDispatcher, ERC4626ABIDispatcher) { + let mut asset = deploy_asset(); + let mut vault = deploy_vault_offset(asset.contract_address); + + // Mint assets to HOLDER and approve vault + asset.unsafe_mint(HOLDER(), Bounded::MAX / 2); // 50% of max + cheat_caller_address(asset.contract_address, HOLDER(), CheatSpan::TargetCalls(1)); + asset.approve(vault.contract_address, Bounded::MAX); + + // Donate 1 token to the vault to offset the price + asset.unsafe_mint(vault.contract_address, parse_token(1)); + + (asset, vault) +} + +#[test] +fn test_inflation_attack_status() { + let (_, vault) = setup_inflation_attack(); + + let total_supply = vault.total_supply(); + assert_eq!(total_supply, 0); + + let total_assets = vault.total_assets(); + assert_eq!(total_assets, parse_token(1)); +} + +#[test] +fn test_inflation_attack_deposit() { + let (asset, vault) = setup_inflation_attack(); + let virtual_assets = 1; + let offset = 1; + let virtual_shares = 10_u256.pow(offset); + + let effective_assets = vault.total_assets() + virtual_assets; + let effective_shares = vault.total_supply() + virtual_shares; + + let deposit_assets = parse_token(1); + let expected_shares = (deposit_assets * effective_shares) / effective_assets; + + // Check max deposit + let max_deposit = vault.max_deposit(HOLDER()); + assert_eq!(max_deposit, Bounded::MAX); + + // Check preview deposit + let preview_deposit = vault.preview_deposit(deposit_assets); + assert_eq!(preview_deposit, expected_shares); + + // Before deposit + let holder_balance_before = asset.balance_of(HOLDER()); + let mut spy = spy_events(); + + // Deposit + cheat_caller_address(vault.contract_address, HOLDER(), CheatSpan::TargetCalls(1)); + let shares = vault.deposit(deposit_assets, RECIPIENT()); + + // After deposit + let holder_balance_after = asset.balance_of(HOLDER()); + assert_eq!(holder_balance_after, holder_balance_before - deposit_assets); + + // Check recipient shares + let recipient_balance = vault.balance_of(RECIPIENT()); + assert_eq!(recipient_balance, expected_shares); + + // Check events + spy + .assert_event_transfer( + asset.contract_address, HOLDER(), vault.contract_address, deposit_assets, + ); + spy.assert_event_transfer(vault.contract_address, ZERO(), RECIPIENT(), shares); + spy + .assert_only_event_deposit( + vault.contract_address, HOLDER(), RECIPIENT(), deposit_assets, expected_shares, + ); +} + +#[test] +fn test_inflation_attack_mint() { + let (asset, vault) = setup_inflation_attack(); + let virtual_assets = 1; + let offset = 1; + let virtual_shares = 10_u256.pow(offset); + + let effective_assets = vault.total_assets() + virtual_assets; + let effective_shares = vault.total_supply() + virtual_shares; + + let mint_shares = parse_share_offset(1); + let expected_assets = (mint_shares * effective_assets) / effective_shares; + + // Check max mint + let max_mint = vault.max_mint(HOLDER()); + assert_eq!(max_mint, Bounded::MAX); + + // Check preview mint + let preview_mint = vault.preview_mint(mint_shares); + assert_eq!(preview_mint, expected_assets); + + // Capture initial balances + let holder_balance_before = asset.balance_of(HOLDER()); + let vault_balance_before = asset.balance_of(vault.contract_address); + + // Mint + let mut spy = spy_events(); + cheat_caller_address(vault.contract_address, HOLDER(), CheatSpan::TargetCalls(1)); + vault.mint(mint_shares, RECIPIENT()); + + // Check balances + assert_expected_assets(asset, HOLDER(), holder_balance_before - expected_assets); + assert_expected_assets(asset, vault.contract_address, vault_balance_before + expected_assets); + assert_expected_shares(vault, RECIPIENT(), parse_share_offset(1)); + + // Check events + spy + .assert_event_transfer( + asset.contract_address, HOLDER(), vault.contract_address, expected_assets, + ); + spy.assert_event_transfer(vault.contract_address, ZERO(), RECIPIENT(), mint_shares); + spy + .assert_only_event_deposit( + vault.contract_address, HOLDER(), RECIPIENT(), expected_assets, mint_shares, + ); +} + +#[test] +fn test_inflation_attack_withdraw() { + let (asset, vault) = setup_inflation_attack(); + + // Check max withdraw + let max_withdraw = vault.max_withdraw(HOLDER()); + assert_eq!(max_withdraw, 0); + + // Check preview withdraw + let preview_withdraw = vault.preview_withdraw(0); + assert_eq!(preview_withdraw, 0); + + // Capture initial balances + let holder_balance_before = asset.balance_of(HOLDER()); + let vault_balance_before = asset.balance_of(vault.contract_address); + + // Withdraw + let mut spy = spy_events(); + cheat_caller_address(vault.contract_address, HOLDER(), CheatSpan::TargetCalls(1)); + vault.withdraw(0, RECIPIENT(), HOLDER()); + + // Check balances and events + assert_expected_assets(asset, HOLDER(), holder_balance_before); + assert_expected_assets(asset, vault.contract_address, vault_balance_before); + + spy.assert_event_transfer(vault.contract_address, HOLDER(), ZERO(), 0); + spy.assert_event_transfer(asset.contract_address, vault.contract_address, RECIPIENT(), 0); + spy.assert_only_event_withdraw(vault.contract_address, HOLDER(), RECIPIENT(), HOLDER(), 0, 0); +} + +#[test] +fn test_inflation_attack_redeem() { + let (asset, vault) = setup_inflation_attack(); + + // Check max redeem + let max_redeem = vault.max_redeem(HOLDER()); + assert_eq!(max_redeem, 0); + + // Check preview redeem + let preview_redeem = vault.preview_redeem(0); + assert_eq!(preview_redeem, 0); + + // Redeem + let mut spy = spy_events(); + cheat_caller_address(vault.contract_address, HOLDER(), CheatSpan::TargetCalls(1)); + vault.redeem(0, RECIPIENT(), HOLDER()); + + // Check events + spy.assert_event_transfer(vault.contract_address, HOLDER(), ZERO(), 0); + spy.assert_event_transfer(asset.contract_address, vault.contract_address, RECIPIENT(), 0); + spy.assert_only_event_withdraw(vault.contract_address, HOLDER(), RECIPIENT(), HOLDER(), 0, 0); +} + +// +// Full vault: Assets and shares +// + +fn setup_full_vault() -> (IERC20ReentrantDispatcher, ERC4626ABIDispatcher) { + let mut asset = deploy_asset(); + + let shares = parse_share_offset(100); + let recipient = HOLDER(); + + // Add 1 token of underlying asset and 100 shares to the vault + let mut vault = deploy_vault_offset_minted_shares(asset.contract_address, shares, recipient); + asset.unsafe_mint(vault.contract_address, parse_token(1)); + + // Approve SPENDER + cheat_caller_address(vault.contract_address, HOLDER(), CheatSpan::TargetCalls(1)); + vault.approve(SPENDER(), Bounded::MAX); + + // Mint assets to HOLDER, approve vault + asset.unsafe_mint(HOLDER(), Bounded::MAX / 2); // 50% of max + cheat_caller_address(asset.contract_address, HOLDER(), CheatSpan::TargetCalls(1)); + asset.approve(vault.contract_address, Bounded::MAX); + + (asset, vault) +} + +#[test] +fn test_full_vault_status() { + let (_, vault) = setup_full_vault(); + + let total_supply = vault.total_supply(); + assert_eq!(total_supply, parse_share_offset(100)); + + let total_assets = vault.total_assets(); + assert_eq!(total_assets, parse_token(1)); +} + +#[test] +fn test_full_vault_deposit() { + let (asset, vault) = setup_full_vault(); + + let virtual_assets = 1; + let offset = 1; + let virtual_shares = 10_u256.pow(offset); + + let effective_assets = vault.total_assets() + virtual_assets; + let effective_shares = vault.total_supply() + virtual_shares; + + let deposit_assets = parse_token(1); + let expected_shares = (deposit_assets * effective_shares) / effective_assets; + + // Check max deposit + let max_deposit = vault.max_deposit(HOLDER()); + assert_eq!(max_deposit, Bounded::MAX); + + // Check preview deposit + let preview_deposit = vault.preview_deposit(deposit_assets); + assert_eq!(preview_deposit, expected_shares); + + // Before deposit + let holder_balance_before = asset.balance_of(HOLDER()); + + // Deposit + let mut spy = spy_events(); + cheat_caller_address(vault.contract_address, HOLDER(), CheatSpan::TargetCalls(1)); + let shares = vault.deposit(deposit_assets, RECIPIENT()); + + // After deposit + let holder_balance_after = asset.balance_of(HOLDER()); + assert_eq!(holder_balance_after, holder_balance_before - deposit_assets); + + // Check recipient shares + let recipient_balance = vault.balance_of(RECIPIENT()); + assert_eq!(recipient_balance, expected_shares); + + // Check events + spy + .assert_event_transfer( + asset.contract_address, HOLDER(), vault.contract_address, deposit_assets, + ); + spy.assert_event_transfer(vault.contract_address, ZERO(), RECIPIENT(), shares); + spy + .assert_only_event_deposit( + vault.contract_address, HOLDER(), RECIPIENT(), deposit_assets, expected_shares, + ); +} + +#[test] +fn test_full_vault_mint() { + let (asset, vault) = setup_full_vault(); + + let virtual_assets = 1; + let offset = 1; + let virtual_shares = 10_u256.pow(offset); + + let effective_assets = vault.total_assets() + virtual_assets; + let effective_shares = vault.total_supply() + virtual_shares; + + let mint_shares = parse_share_offset(1); + let expected_assets = (mint_shares * effective_assets) / effective_shares + + 1; // add `1` for the rounding + + // Check max mint + let max_mint = vault.max_mint(HOLDER()); + assert_eq!(max_mint, Bounded::MAX); + + // Check preview mint + let preview_mint = vault.preview_mint(mint_shares); + assert_eq!(preview_mint, expected_assets); + + // Capture initial balances + let holder_balance_before = asset.balance_of(HOLDER()); + let vault_balance_before = asset.balance_of(vault.contract_address); + + // Mint + let mut spy = spy_events(); + cheat_caller_address(vault.contract_address, HOLDER(), CheatSpan::TargetCalls(1)); + vault.mint(mint_shares, RECIPIENT()); + + // Check balances + assert_expected_assets(asset, HOLDER(), holder_balance_before - expected_assets); + assert_expected_assets(asset, vault.contract_address, vault_balance_before + expected_assets); + assert_expected_shares(vault, RECIPIENT(), parse_share_offset(1)); + + // Check events + spy + .assert_event_transfer( + asset.contract_address, HOLDER(), vault.contract_address, expected_assets, + ); + spy.assert_event_transfer(vault.contract_address, ZERO(), RECIPIENT(), mint_shares); + spy + .assert_only_event_deposit( + vault.contract_address, HOLDER(), RECIPIENT(), expected_assets, mint_shares, + ); +} + +#[test] +fn test_full_vault_withdraw() { + let (asset, vault) = setup_full_vault(); + + let virtual_assets = 1; + let offset = 1; + let virtual_shares = 10_u256.pow(offset); + + let effective_assets = vault.total_assets() + virtual_assets; + let effective_shares = vault.total_supply() + virtual_shares; + + let withdraw_assets = parse_token(1); + let expected_shares = (withdraw_assets * effective_shares) / effective_assets + + 1; // add `1` for the rounding + + // Check max withdraw + let max_withdraw = vault.max_withdraw(HOLDER()); + assert_eq!(max_withdraw, withdraw_assets); + + // Check preview withdraw + let preview_withdraw = vault.preview_withdraw(withdraw_assets); + assert_eq!(preview_withdraw, expected_shares); + + // Capture initial balances + let holder_balance_before = asset.balance_of(HOLDER()); + let vault_balance_before = asset.balance_of(vault.contract_address); + + // Withdraw + let mut spy = spy_events(); + cheat_caller_address(vault.contract_address, HOLDER(), CheatSpan::TargetCalls(1)); + vault.withdraw(withdraw_assets, RECIPIENT(), HOLDER()); + + // Check balances and events + assert_expected_assets(asset, HOLDER(), holder_balance_before); + assert_expected_assets(asset, RECIPIENT(), withdraw_assets); + assert_expected_assets(asset, vault.contract_address, vault_balance_before - withdraw_assets); + + spy.assert_event_transfer(vault.contract_address, HOLDER(), ZERO(), expected_shares); + spy + .assert_event_transfer( + asset.contract_address, vault.contract_address, RECIPIENT(), withdraw_assets, + ); + spy + .assert_only_event_withdraw( + vault.contract_address, + HOLDER(), + RECIPIENT(), + HOLDER(), + withdraw_assets, + expected_shares, + ); +} + +#[test] +fn test_full_vault_withdraw_with_approval() { + let (asset, vault) = setup_full_vault(); + + let virtual_assets = 1; + let offset = 1; + let virtual_shares = 10_u256.pow(offset); + + let effective_assets = vault.total_assets() + virtual_assets; + let effective_shares = vault.total_supply() + virtual_shares; + + let withdraw_assets = parse_token(1); + let expected_shares = (withdraw_assets * effective_shares) / effective_assets + + 1; // add `1` for the rounding + + // Withdraw + let mut spy = spy_events(); + cheat_caller_address(vault.contract_address, SPENDER(), CheatSpan::TargetCalls(1)); + vault.withdraw(withdraw_assets, RECIPIENT(), HOLDER()); + + // Check events + spy.assert_event_transfer(vault.contract_address, HOLDER(), ZERO(), expected_shares); + spy + .assert_event_transfer( + asset.contract_address, vault.contract_address, RECIPIENT(), withdraw_assets, + ); + spy + .assert_only_event_withdraw( + vault.contract_address, + SPENDER(), + RECIPIENT(), + HOLDER(), + withdraw_assets, + expected_shares, + ); +} + +#[test] +#[should_panic(expected: 'ERC20: insufficient allowance')] +fn test_full_vault_withdraw_unauthorized() { + let (_, vault) = setup_full_vault(); + let withdraw_assets = parse_token(1); + + cheat_caller_address(vault.contract_address, OTHER(), CheatSpan::TargetCalls(1)); + vault.withdraw(withdraw_assets, RECIPIENT(), HOLDER()); +} + +#[test] +fn test_full_vault_redeem() { + let (asset, vault) = setup_full_vault(); + + let virtual_assets = 1; + let offset = 1; + let virtual_shares = 10_u256.pow(offset); + + let effective_assets = vault.total_assets() + virtual_assets; + let effective_shares = vault.total_supply() + virtual_shares; + + let redeem_shares = parse_share_offset(100); + let expected_assets = (redeem_shares * effective_assets) / effective_shares; + + // Check max redeem + let max_redeem = vault.max_redeem(HOLDER()); + assert_eq!(max_redeem, redeem_shares); + + // Check preview redeem + let preview_redeem = vault.preview_redeem(redeem_shares); + assert_eq!(preview_redeem, expected_assets); + + // Capture initial balances + let vault_balance_before = asset.balance_of(vault.contract_address); + let holder_shares_before = vault.balance_of(HOLDER()); + + // Redeem + let mut spy = spy_events(); + cheat_caller_address(vault.contract_address, HOLDER(), CheatSpan::TargetCalls(1)); + vault.redeem(redeem_shares, RECIPIENT(), HOLDER()); + + // Check balances and events + assert_expected_assets(asset, RECIPIENT(), expected_assets); + assert_expected_assets(asset, vault.contract_address, vault_balance_before - expected_assets); + assert_expected_shares(vault, HOLDER(), holder_shares_before - redeem_shares); + + spy.assert_event_transfer(vault.contract_address, HOLDER(), ZERO(), redeem_shares); + spy + .assert_event_transfer( + asset.contract_address, vault.contract_address, RECIPIENT(), expected_assets, + ); + spy + .assert_only_event_withdraw( + vault.contract_address, HOLDER(), RECIPIENT(), HOLDER(), expected_assets, redeem_shares, + ); +} + +#[test] +fn test_full_vault_redeem_with_approval() { + let (asset, vault) = setup_full_vault(); + + let virtual_assets = 1; + let offset = 1; + let virtual_shares = 10_u256.pow(offset); + + let effective_assets = vault.total_assets() + virtual_assets; + let effective_shares = vault.total_supply() + virtual_shares; + + let redeem_shares = parse_share_offset(100); + let expected_assets = (redeem_shares * effective_assets) / effective_shares; + + // Check max redeem + let max_redeem = vault.max_redeem(HOLDER()); + assert_eq!(max_redeem, redeem_shares); + + // Check preview redeem + let preview_redeem = vault.preview_redeem(redeem_shares); + assert_eq!(preview_redeem, expected_assets); + + // Capture initial balances + let vault_balance_before = asset.balance_of(vault.contract_address); + let holder_shares_before = vault.balance_of(HOLDER()); + + // Redeem from SPENDER + let mut spy = spy_events(); + cheat_caller_address(vault.contract_address, SPENDER(), CheatSpan::TargetCalls(1)); + vault.redeem(redeem_shares, RECIPIENT(), HOLDER()); + + // Check balances and events + assert_expected_assets(asset, RECIPIENT(), expected_assets); + assert_expected_assets(asset, vault.contract_address, vault_balance_before - expected_assets); + assert_expected_shares(vault, HOLDER(), holder_shares_before - redeem_shares); + + spy.assert_event_transfer(vault.contract_address, HOLDER(), ZERO(), redeem_shares); + spy + .assert_event_transfer( + asset.contract_address, vault.contract_address, RECIPIENT(), expected_assets, + ); + spy + .assert_only_event_withdraw( + vault.contract_address, + SPENDER(), + RECIPIENT(), + HOLDER(), + expected_assets, + redeem_shares, + ); +} + +#[test] +#[should_panic(expected: 'ERC20: insufficient allowance')] +fn test_full_vault_redeem_unauthorized() { + let (_, vault) = setup_full_vault(); + let redeem_shares = parse_share_offset(100); + + // Unauthorized redeem + cheat_caller_address(vault.contract_address, OTHER(), CheatSpan::TargetCalls(1)); + vault.redeem(redeem_shares, RECIPIENT(), HOLDER()); +} + +// +// Reentrancy +// + +fn setup_reentrancy() -> (IERC20ReentrantDispatcher, ERC4626ABIDispatcher) { + let mut asset = deploy_asset(); + let mut vault = deploy_vault_offset(asset.contract_address); + + let value: u256 = 1_000_000_000_000_000_000; + asset.unsafe_mint(HOLDER(), value); + asset.unsafe_mint(OTHER(), value); + + // Set infinite approvals from HOLDER, OTHER, and asset to vault + let approvers: Span = array![HOLDER(), OTHER(), asset.contract_address].span(); + for addr in approvers { + cheat_caller_address(asset.contract_address, *addr, CheatSpan::TargetCalls(1)); + asset.approve(vault.contract_address, Bounded::MAX); + }; + + (asset, vault) +} + +#[test] +fn test_share_price_with_reentrancy_before_deposit() { + let (asset, vault) = setup_reentrancy(); + + let value = 1_000_000_000_000_000_000; + let reenter_value = 1_000_000_000; + + asset.unsafe_mint(asset.contract_address, reenter_value); + + // Schedule reentrancy + let mut calldata: Array = array![]; + calldata.append_serde(reenter_value); + calldata.append_serde(HOLDER()); + asset + .schedule_reenter( + Type::Before, vault.contract_address, selector!("deposit"), calldata.span(), + ); + + let shares_for_deposit = vault.preview_deposit(value); + let shares_for_reenter = vault.preview_deposit(reenter_value); + + // Deposit + let mut spy = spy_events(); + cheat_caller_address(vault.contract_address, HOLDER(), CheatSpan::TargetCalls(1)); + vault.deposit(value, HOLDER()); + + // Check price is kept + let after_deposit = vault.preview_deposit(value); + assert_eq!(shares_for_deposit, after_deposit); + + // Check events + // Reentered events come first because they're called in mock ERC20 `before_update` hook + spy + .assert_event_transfer( + asset.contract_address, asset.contract_address, vault.contract_address, reenter_value, + ); + spy.assert_event_transfer(vault.contract_address, ZERO(), HOLDER(), shares_for_reenter); + spy + .assert_event_deposit( + vault.contract_address, + asset.contract_address, + HOLDER(), + reenter_value, + shares_for_reenter, + ); + + spy.assert_event_transfer(asset.contract_address, HOLDER(), vault.contract_address, value); + spy.assert_event_transfer(vault.contract_address, ZERO(), HOLDER(), shares_for_deposit); + spy + .assert_only_event_deposit( + vault.contract_address, HOLDER(), HOLDER(), value, shares_for_deposit, + ); +} + +#[test] +fn test_share_price_with_reentrancy_after_withdraw() { + let (asset, vault) = setup_reentrancy(); + + let value = 1_000_000_000_000_000_000; + let reenter_value = 1_000_000_000; + + // Deposit from HOLDER and OTHER + cheat_caller_address(vault.contract_address, HOLDER(), CheatSpan::TargetCalls(1)); + vault.deposit(value, HOLDER()); + + cheat_caller_address(vault.contract_address, OTHER(), CheatSpan::TargetCalls(1)); + vault.deposit(reenter_value, asset.contract_address); + + // Schedule reentrancy + let mut calldata: Array = array![]; + calldata.append_serde(reenter_value); + calldata.append_serde(HOLDER()); + calldata.append_serde(asset.contract_address); + asset + .schedule_reenter( + Type::After, vault.contract_address, selector!("withdraw"), calldata.span(), + ); + + let shares_for_withdraw = vault.preview_withdraw(value); + let shares_for_reenter = vault.preview_withdraw(reenter_value); + + // Withdraw + let mut spy = spy_events(); + cheat_caller_address(vault.contract_address, HOLDER(), CheatSpan::TargetCalls(1)); + vault.withdraw(value, HOLDER(), HOLDER()); + + // Check price is kept + let after_withdraw = vault.preview_withdraw(value); + assert_eq!(shares_for_withdraw, after_withdraw); + + // Main withdraw event + spy + .assert_event_withdraw( + vault.contract_address, HOLDER(), HOLDER(), HOLDER(), value, shares_for_withdraw, + ); + // Reentrant withdraw event → uses same price + spy + .assert_event_withdraw( + vault.contract_address, + asset.contract_address, + HOLDER(), + asset.contract_address, + reenter_value, + shares_for_reenter, + ); +} + +#[test] +fn test_price_change_during_reentrancy_doesnt_affect_deposit() { + let (asset, vault) = setup_reentrancy(); + + let value: u256 = 1_000_000_000_000_000_000; + let reenter_value: u256 = 1_000_000_000; + + // Schedules a reentrancy from the token contract that messes up the share price + let mut calldata: Array = array![]; + calldata.append_serde(vault.contract_address); + calldata.append_serde(reenter_value); + asset + .schedule_reenter( + Type::Before, asset.contract_address, selector!("unsafe_mint"), calldata.span(), + ); + + let shares_before = vault.preview_deposit(value); + + // Deposit + let mut spy = spy_events(); + cheat_caller_address(vault.contract_address, HOLDER(), CheatSpan::TargetCalls(1)); + vault.deposit(value, HOLDER()); + + // Check main event to ensure price is as previewed + spy.assert_event_deposit(vault.contract_address, HOLDER(), HOLDER(), value, shares_before); + + // Check that price is modified after reentrant tx + let shares_after = vault.preview_deposit(value); + assert(shares_after < shares_before, 'Mint should change share price'); +} + +#[test] +fn test_price_change_during_reentrancy_doesnt_affect_withdraw() { + let (asset, vault) = setup_reentrancy(); + + let value: u256 = 1_000_000_000_000_000_000; + let reenter_value: u256 = 1_000_000_000; + + cheat_caller_address(vault.contract_address, HOLDER(), CheatSpan::TargetCalls(1)); + vault.deposit(value, HOLDER()); + cheat_caller_address(vault.contract_address, OTHER(), CheatSpan::TargetCalls(1)); + vault.deposit(value, OTHER()); + + // Schedules a reentrancy from the token contract that messes up the share price + let mut calldata: Array = array![]; + calldata.append_serde(vault.contract_address); + calldata.append_serde(reenter_value); + asset + .schedule_reenter( + Type::After, asset.contract_address, selector!("unsafe_burn"), calldata.span(), + ); + + let shares_before = vault.preview_withdraw(value); + + // Withdraw, triggering ERC20 `after_update` hook + let mut spy = spy_events(); + cheat_caller_address(vault.contract_address, HOLDER(), CheatSpan::TargetCalls(1)); + vault.withdraw(value, HOLDER(), HOLDER()); + + // Check main event to ensure price is as previewed + spy + .assert_event_withdraw( + vault.contract_address, HOLDER(), HOLDER(), HOLDER(), value, shares_before, + ); + + // Check that price is modified after reentrant tx + let shares_after = vault.preview_withdraw(value); + assert(shares_after > shares_before, 'Burn should change share price'); +} + +// +// Limits +// + +fn setup_limits() -> (IERC20ReentrantDispatcher, ERC4626ABIDispatcher) { + let mut asset = deploy_asset(); + let mut vault = deploy_vault_limits(asset.contract_address); + + (asset, vault) +} + +#[test] +#[should_panic(expected: 'ERC4626: exceeds max deposit')] +fn test_max_limit_deposit() { + let (_, vault) = setup_limits(); + + let max_deposit = vault.max_deposit(HOLDER()); + cheat_caller_address(vault.contract_address, HOLDER(), CheatSpan::TargetCalls(1)); + vault.deposit(max_deposit + 1, HOLDER()); +} + +#[test] +#[should_panic(expected: 'ERC4626: exceeds max mint')] +fn test_max_limit_mint() { + let (_, vault) = setup_limits(); + + let max_mint = vault.max_mint(HOLDER()); + cheat_caller_address(vault.contract_address, HOLDER(), CheatSpan::TargetCalls(1)); + vault.mint(max_mint + 1, HOLDER()); +} + +#[test] +#[should_panic(expected: 'ERC4626: exceeds max withdraw')] +fn test_max_limit_withdraw() { + let (_, vault) = setup_limits(); + + let max_withdraw = vault.max_redeem(HOLDER()); + cheat_caller_address(vault.contract_address, HOLDER(), CheatSpan::TargetCalls(1)); + vault.withdraw(max_withdraw + 1, HOLDER(), HOLDER()); +} + +#[test] +#[should_panic(expected: 'ERC4626: exceeds max redeem')] +fn test_max_limit_redeem() { + let (_, vault) = setup_limits(); + + let max_redeem = vault.max_redeem(HOLDER()); + cheat_caller_address(vault.contract_address, HOLDER(), CheatSpan::TargetCalls(1)); + vault.redeem(max_redeem + 1, HOLDER(), HOLDER()); +} + +// +// Fees +// + +fn setup_input_fees() -> (IERC20ReentrantDispatcher, ERC4626ABIDispatcher) { + let mut asset = deploy_asset(); + let mut vault = deploy_vault_fees(asset.contract_address); + + let half_max: u256 = Bounded::MAX / 2; + asset.unsafe_mint(HOLDER(), half_max); + + cheat_caller_address(asset.contract_address, HOLDER(), CheatSpan::TargetCalls(1)); + asset.approve(vault.contract_address, half_max); + + (asset, vault) +} + +fn setup_output_fees() -> (IERC20ReentrantDispatcher, ERC4626ABIDispatcher) { + let mut asset = deploy_asset(); + let half_max: u256 = Bounded::MAX / 2; + + // Mint shares to HOLDER + let mut vault = deploy_vault_exit_fees_with_shares(asset.contract_address, half_max, HOLDER()); + + // Mint assets to vault + asset.unsafe_mint(vault.contract_address, half_max); + + (asset, vault) +} + +#[test] +fn test_input_fees_deposit() { + let (asset, vault) = setup_input_fees(); + + let FEE_BASIS_POINTS: u256 = 500; // 5% + let VALUE_WITHOUT_FEES: u256 = 10_000; + let FEES = (VALUE_WITHOUT_FEES * FEE_BASIS_POINTS) / 10_000; + let VALUE_WITH_FEES = VALUE_WITHOUT_FEES + FEES; + + let actual_value = vault.preview_deposit(VALUE_WITH_FEES); + assert_eq!(actual_value, VALUE_WITHOUT_FEES); + + let holder_asset_bal = asset.balance_of(HOLDER()); + let vault_asset_bal = asset.balance_of(vault.contract_address); + + let mut spy = spy_events(); + cheat_caller_address(vault.contract_address, HOLDER(), CheatSpan::TargetCalls(1)); + vault.deposit(VALUE_WITH_FEES, RECIPIENT()); + + // Check asset balances + assert_expected_assets(asset, HOLDER(), holder_asset_bal - VALUE_WITH_FEES); + assert_expected_assets(asset, vault.contract_address, vault_asset_bal + VALUE_WITHOUT_FEES); + assert_expected_assets(asset, TREASURY(), FEES); + + // Check shares + assert_expected_shares(vault, RECIPIENT(), VALUE_WITHOUT_FEES); + + // Check events + spy + .assert_event_transfer( + asset.contract_address, HOLDER(), vault.contract_address, VALUE_WITH_FEES, + ); + spy.assert_event_transfer(vault.contract_address, ZERO(), RECIPIENT(), VALUE_WITHOUT_FEES); + spy + .assert_event_deposit( + vault.contract_address, HOLDER(), RECIPIENT(), VALUE_WITH_FEES, VALUE_WITHOUT_FEES, + ); + spy.assert_event_transfer(asset.contract_address, vault.contract_address, TREASURY(), FEES); +} + +#[test] +fn test_input_fees_mint() { + let (asset, vault) = setup_input_fees(); + + let FEE_BASIS_POINTS: u256 = 500; // 5% + let VALUE_WITHOUT_FEES: u256 = 10_000; + let FEES = (VALUE_WITHOUT_FEES * FEE_BASIS_POINTS) / 10_000; + let VALUE_WITH_FEES = VALUE_WITHOUT_FEES + FEES; + + let actual_value = vault.preview_mint(VALUE_WITHOUT_FEES); + assert_eq!(actual_value, VALUE_WITH_FEES); + + let holder_asset_bal = asset.balance_of(HOLDER()); + let vault_asset_bal = asset.balance_of(vault.contract_address); + + let mut spy = spy_events(); + cheat_caller_address(vault.contract_address, HOLDER(), CheatSpan::TargetCalls(1)); + vault.mint(VALUE_WITHOUT_FEES, RECIPIENT()); + + // Check asset balances + assert_expected_assets(asset, HOLDER(), holder_asset_bal - VALUE_WITH_FEES); + assert_expected_assets(asset, vault.contract_address, vault_asset_bal + VALUE_WITHOUT_FEES); + assert_expected_assets(asset, TREASURY(), FEES); + + // Check shares + assert_expected_shares(vault, RECIPIENT(), VALUE_WITHOUT_FEES); + + // Check events + spy + .assert_event_transfer( + asset.contract_address, HOLDER(), vault.contract_address, VALUE_WITH_FEES, + ); + spy.assert_event_transfer(vault.contract_address, ZERO(), RECIPIENT(), VALUE_WITHOUT_FEES); + spy + .assert_event_deposit( + vault.contract_address, HOLDER(), RECIPIENT(), VALUE_WITH_FEES, VALUE_WITHOUT_FEES, + ); + spy.assert_event_transfer(asset.contract_address, vault.contract_address, TREASURY(), FEES); +} + +#[test] +fn test_output_fees_redeem() { + let (asset, vault) = setup_output_fees(); + + let FEE_BASIS_POINTS: u256 = 500; // 5% + let VALUE_WITHOUT_FEES: u256 = 10_000; + let FEES = (VALUE_WITHOUT_FEES * FEE_BASIS_POINTS) / 10_000; + let VALUE_WITH_FEES = VALUE_WITHOUT_FEES + FEES; + + let preview_redeem = vault.preview_redeem(VALUE_WITH_FEES); + assert_eq!(preview_redeem, VALUE_WITHOUT_FEES); + + let vault_asset_bal = asset.balance_of(vault.contract_address); + let recipient_asset_bal = asset.balance_of(RECIPIENT()); + let treasury_asset_bal = asset.balance_of(TREASURY()); + let holder_shares = vault.balance_of(HOLDER()); + + let mut spy = spy_events(); + cheat_caller_address(vault.contract_address, HOLDER(), CheatSpan::TargetCalls(1)); + vault.redeem(VALUE_WITH_FEES, RECIPIENT(), HOLDER()); + + // Check asset balances + assert_expected_assets(asset, vault.contract_address, vault_asset_bal - VALUE_WITH_FEES); + assert_expected_assets(asset, RECIPIENT(), recipient_asset_bal + VALUE_WITHOUT_FEES); + assert_expected_assets(asset, TREASURY(), treasury_asset_bal + FEES); + + // Check shares + assert_expected_shares(vault, HOLDER(), holder_shares - VALUE_WITH_FEES); + + // Check events + spy.assert_event_transfer(asset.contract_address, vault.contract_address, TREASURY(), FEES); + spy.assert_event_transfer(vault.contract_address, HOLDER(), ZERO(), VALUE_WITH_FEES); + spy + .assert_event_transfer( + asset.contract_address, vault.contract_address, RECIPIENT(), VALUE_WITHOUT_FEES, + ); + spy + .assert_only_event_withdraw( + vault.contract_address, + HOLDER(), + RECIPIENT(), + HOLDER(), + VALUE_WITHOUT_FEES, + VALUE_WITH_FEES, + ); +} + +#[test] +fn test_output_fees_withdraw() { + let (asset, vault) = setup_output_fees(); + + let FEE_BASIS_POINTS: u256 = 500; // 5% + let VALUE_WITHOUT_FEES: u256 = 10_000; + let FEES = (VALUE_WITHOUT_FEES * FEE_BASIS_POINTS) / 10_000; + let VALUE_WITH_FEES = VALUE_WITHOUT_FEES + FEES; + + let preview_withdraw = vault.preview_withdraw(VALUE_WITHOUT_FEES); + assert_eq!(preview_withdraw, VALUE_WITH_FEES); + + let vault_asset_bal = asset.balance_of(vault.contract_address); + let recipient_asset_bal = asset.balance_of(RECIPIENT()); + let treasury_asset_bal = asset.balance_of(TREASURY()); + let holder_shares = vault.balance_of(HOLDER()); + + let mut spy = spy_events(); + cheat_caller_address(vault.contract_address, HOLDER(), CheatSpan::TargetCalls(1)); + vault.withdraw(VALUE_WITHOUT_FEES, RECIPIENT(), HOLDER()); + + // Check asset balances + assert_expected_assets(asset, vault.contract_address, vault_asset_bal - VALUE_WITH_FEES); + assert_expected_assets(asset, RECIPIENT(), recipient_asset_bal + VALUE_WITHOUT_FEES); + assert_expected_assets(asset, TREASURY(), treasury_asset_bal + FEES); + + // Check shares + assert_expected_shares(vault, HOLDER(), holder_shares - VALUE_WITH_FEES); + + // Check events + spy.assert_event_transfer(asset.contract_address, vault.contract_address, TREASURY(), FEES); + spy.assert_event_transfer(vault.contract_address, HOLDER(), ZERO(), VALUE_WITH_FEES); + spy + .assert_event_transfer( + asset.contract_address, vault.contract_address, RECIPIENT(), VALUE_WITHOUT_FEES, + ); + spy + .assert_only_event_withdraw( + vault.contract_address, + HOLDER(), + RECIPIENT(), + HOLDER(), + VALUE_WITHOUT_FEES, + VALUE_WITH_FEES, + ); +} + +// +// Scenario inspired by solmate ERC4626 tests +// + +#[test] +fn test_multiple_txs_part_1() { + let mut asset = deploy_asset(); + let mut vault = deploy_vault(asset.contract_address); + + let alice = contract_address_const::<'alice'>(); + let bob = contract_address_const::<'bob'>(); + + asset.unsafe_mint(alice, 4_000); + asset.unsafe_mint(bob, 7_001); + + cheat_caller_address(asset.contract_address, alice, CheatSpan::TargetCalls(1)); + asset.approve(vault.contract_address, 4_000); + cheat_caller_address(asset.contract_address, bob, CheatSpan::TargetCalls(1)); + asset.approve(vault.contract_address, 7_001); + + // 1. Alice mints 2_000 shares (costs 2_000 tokens) + cheat_caller_address(vault.contract_address, alice, CheatSpan::TargetCalls(1)); + vault.mint(2_000, alice); + + assert_eq!(vault.preview_deposit(2_000), 2_000); + assert_eq!(vault.balance_of(alice), 2_000); + assert_eq!(vault.balance_of(bob), 0); + assert_eq!(vault.convert_to_assets(vault.balance_of(alice)), 2_000); + assert_eq!(vault.convert_to_assets(vault.balance_of(bob)), 0); + assert_eq!(vault.convert_to_shares(asset.balance_of(vault.contract_address)), 2_000); + assert_eq!(vault.total_supply(), 2_000); + assert_eq!(vault.total_assets(), 2_000); + + // 2. Bob deposits 4_000 tokens (mints 4_000 shares) + cheat_caller_address(vault.contract_address, bob, CheatSpan::TargetCalls(1)); + vault.mint(4_000, bob); + + assert_eq!(vault.preview_deposit(4_000), 4_000); + assert_eq!(vault.balance_of(alice), 2_000); + assert_eq!(vault.balance_of(bob), 4_000); + assert_eq!(vault.convert_to_assets(vault.balance_of(alice)), 2_000); + assert_eq!(vault.convert_to_assets(vault.balance_of(bob)), 4_000); + assert_eq!(vault.convert_to_shares(asset.balance_of(vault.contract_address)), 6_000); + assert_eq!(vault.total_supply(), 6_000); + assert_eq!(vault.total_assets(), 6_000); + + // 3. Vault mutates by +3_000 tokens (simulated yield returned from strategy) + asset.unsafe_mint(vault.contract_address, 3_000); + + assert_eq!(vault.balance_of(alice), 2_000); + assert_eq!(vault.balance_of(bob), 4_000); + // Was 3_000, but virtual assets/shares captures part of the yield + assert_eq!(vault.convert_to_assets(vault.balance_of(alice)), 2_999); + // Was 6_000, but virtual assets/shares captures part of the yield + assert_eq!(vault.convert_to_assets(vault.balance_of(bob)), 5_999); + assert_eq!(vault.convert_to_shares(asset.balance_of(vault.contract_address)), 6_000); + assert_eq!(vault.total_supply(), 6_000); + assert_eq!(vault.total_assets(), 9_000); + + // 4. Alice deposits 2_000 tokens (mints 1_333 shares) + cheat_caller_address(vault.contract_address, alice, CheatSpan::TargetCalls(1)); + vault.deposit(2_000, alice); + + assert_eq!(vault.balance_of(alice), 3_333); + assert_eq!(vault.balance_of(bob), 4_000); + assert_eq!(vault.convert_to_assets(vault.balance_of(alice)), 4_999); + assert_eq!(vault.convert_to_assets(vault.balance_of(bob)), 6_000); + assert_eq!(vault.convert_to_shares(asset.balance_of(vault.contract_address)), 7_333); + assert_eq!(vault.total_supply(), 7_333); + assert_eq!(vault.total_assets(), 11_000); + + // 5. Bob mints 2_000 shares (costs 3_001 assets) + // NOTE: Bob's assets spent rounds toward infinity + // NOTE: Alice's vault assets rounds toward infinity + cheat_caller_address(vault.contract_address, bob, CheatSpan::TargetCalls(1)); + vault.mint(2_000, bob); + + assert_eq!(vault.balance_of(alice), 3_333); + assert_eq!(vault.balance_of(bob), 6_000); + assert_eq!(vault.convert_to_assets(vault.balance_of(alice)), 4_999); + assert_eq!(vault.convert_to_assets(vault.balance_of(bob)), 9_000); + assert_eq!(vault.convert_to_shares(asset.balance_of(vault.contract_address)), 9_333); + assert_eq!(vault.total_supply(), 9_333); + assert_eq!(vault.total_assets(), 14_000); + + // 6. Vault mutates by +3_000 tokens + // NOTE: Vault holds 17_001 tokens, but `assets_of` returns 17000. + asset.unsafe_mint(vault.contract_address, 3_000); + + assert_eq!(vault.balance_of(alice), 3_333); + assert_eq!(vault.balance_of(bob), 6_000); + // Was 6_071 + assert_eq!(vault.convert_to_assets(vault.balance_of(alice)), 6_070); + // Was 10_929 + assert_eq!(vault.convert_to_assets(vault.balance_of(bob)), 10_928); + assert_eq!(vault.convert_to_shares(asset.balance_of(vault.contract_address)), 9_333); + assert_eq!(vault.total_supply(), 9_333); + // Was 17_001 + assert_eq!(vault.total_assets(), 17_000); + + // 7. Alice redeems 1_333 shares (2_428 assets) + cheat_caller_address(vault.contract_address, alice, CheatSpan::TargetCalls(1)); + vault.redeem(1_333, alice, alice); + + assert_eq!(vault.balance_of(alice), 2_000); + assert_eq!(vault.balance_of(bob), 6_000); + assert_eq!(vault.convert_to_assets(vault.balance_of(alice)), 3_643); + assert_eq!(vault.convert_to_assets(vault.balance_of(bob)), 10_929); + assert_eq!(vault.convert_to_shares(asset.balance_of(vault.contract_address)), 8_000); + assert_eq!(vault.total_supply(), 8_000); + assert_eq!(vault.total_assets(), 14_573); +} + +#[test] +fn test_multiple_txs_part_2() { + // SNForge hangs, so the test is split in two. + let mut asset = deploy_asset(); + let mut vault = deploy_vault(asset.contract_address); + + let alice = contract_address_const::<'alice'>(); + let bob = contract_address_const::<'bob'>(); + + asset.unsafe_mint(alice, 4_000); + asset.unsafe_mint(bob, 7_001); + + cheat_caller_address(asset.contract_address, alice, CheatSpan::TargetCalls(1)); + asset.approve(vault.contract_address, 4_000); + cheat_caller_address(asset.contract_address, bob, CheatSpan::TargetCalls(1)); + asset.approve(vault.contract_address, 7_001); + + // Recreate state to where it left off from `test_multiple_txs_part_1`. + + // 1. Alice mints 2_000 shares (costs 2_000 tokens) + cheat_caller_address(vault.contract_address, alice, CheatSpan::TargetCalls(1)); + vault.mint(2_000, alice); + // 2. Bob deposits 4_000 tokens (mints 4_000 shares) + cheat_caller_address(vault.contract_address, bob, CheatSpan::TargetCalls(1)); + vault.mint(4_000, bob); + // 3. Vault mutates by +3_000 tokens (simulated yield returned from strategy) + asset.unsafe_mint(vault.contract_address, 3_000); + // 4. Alice deposits 2_000 tokens (mints 1_333 shares) + cheat_caller_address(vault.contract_address, alice, CheatSpan::TargetCalls(1)); + vault.deposit(2_000, alice); + // 5. Bob mints 2_000 shares (costs 3_001 assets) + cheat_caller_address(vault.contract_address, bob, CheatSpan::TargetCalls(1)); + vault.mint(2_000, bob); + // 6. Vault mutates by +3_000 tokens + asset.unsafe_mint(vault.contract_address, 3_000); + // 7. Alice redeems 1_333 shares (2_428 assets) + cheat_caller_address(vault.contract_address, alice, CheatSpan::TargetCalls(1)); + vault.redeem(1_333, alice, alice); + + // 8. Bob withdraws 2_929 assets (1_608 shares) + cheat_caller_address(vault.contract_address, bob, CheatSpan::TargetCalls(1)); + vault.withdraw(2_929, bob, bob); + + assert_eq!(vault.balance_of(alice), 2_000); + assert_eq!(vault.balance_of(bob), 4_392); + assert_eq!(vault.convert_to_assets(vault.balance_of(alice)), 3_643); + assert_eq!(vault.convert_to_assets(vault.balance_of(bob)), 8_000); + assert_eq!(vault.convert_to_shares(asset.balance_of(vault.contract_address)), 6_392); + assert_eq!(vault.total_supply(), 6_392); + assert_eq!(vault.total_assets(), 11_644); + + // 9. Alice withdraws 3_643 assets (2_000 shares) + // NOTE: Bob's assets have been rounded back towards infinity + cheat_caller_address(vault.contract_address, alice, CheatSpan::TargetCalls(1)); + vault.withdraw(3_643, alice, alice); + + assert_eq!(vault.balance_of(alice), 0); + assert_eq!(vault.balance_of(bob), 4_392); + assert_eq!(vault.convert_to_assets(vault.balance_of(alice)), 0); + assert_eq!(vault.convert_to_assets(vault.balance_of(bob)), 8_000); + assert_eq!(vault.convert_to_shares(asset.balance_of(vault.contract_address)), 4_392); + assert_eq!(vault.total_supply(), 4_392); + assert_eq!(vault.total_assets(), 8_001); + + // 10. Bob redeems 4_392 shares (8_001) + cheat_caller_address(vault.contract_address, bob, CheatSpan::TargetCalls(1)); + vault.redeem(4_392, bob, bob); + + assert_eq!(vault.balance_of(alice), 0); + assert_eq!(vault.balance_of(bob), 0); + assert_eq!(vault.convert_to_assets(vault.balance_of(alice)), 0); + assert_eq!(vault.convert_to_assets(vault.balance_of(bob)), 0); + assert_eq!(vault.convert_to_shares(asset.balance_of(vault.contract_address)), 0); + assert_eq!(vault.total_supply(), 0); + assert_eq!(vault.total_assets(), 1); +} + +// +// Assertions/Helpers +// + +fn assert_expected_shares( + vault: ERC4626ABIDispatcher, account: ContractAddress, expected_shares: u256, +) { + let actual_shares = vault.balance_of(account); + assert_eq!(actual_shares, expected_shares); +} + +fn assert_expected_assets( + asset: IERC20ReentrantDispatcher, account: ContractAddress, expected_assets: u256, +) { + let actual_assets = asset.balance_of(account); + assert_eq!(actual_assets, expected_assets); +} + +#[generate_trait] +pub impl ERC4626SpyHelpersImpl of ERC4626SpyHelpers { + fn assert_event_deposit( + ref self: EventSpy, + contract: ContractAddress, + sender: ContractAddress, + owner: ContractAddress, + assets: u256, + shares: u256, + ) { + let expected = ERC4626Component::Event::Deposit(Deposit { sender, owner, assets, shares }); + self.assert_emitted_single(contract, expected); + } + + fn assert_only_event_deposit( + ref self: EventSpy, + contract: ContractAddress, + sender: ContractAddress, + owner: ContractAddress, + assets: u256, + shares: u256, + ) { + self.assert_event_deposit(contract, sender, owner, assets, shares); + self.assert_no_events_left_from(contract); + } + + fn assert_event_withdraw( + ref self: EventSpy, + contract: ContractAddress, + sender: ContractAddress, + receiver: ContractAddress, + owner: ContractAddress, + assets: u256, + shares: u256, + ) { + let expected = ERC4626Component::Event::Withdraw( + Withdraw { sender, receiver, owner, assets, shares }, + ); + self.assert_emitted_single(contract, expected); + } + + fn assert_only_event_withdraw( + ref self: EventSpy, + contract: ContractAddress, + sender: ContractAddress, + receiver: ContractAddress, + owner: ContractAddress, + assets: u256, + shares: u256, + ) { + self.assert_event_withdraw(contract, sender, receiver, owner, assets, shares); + self.assert_no_events_left_from(contract); + } +} diff --git a/packages/utils/src/math.cairo b/packages/utils/src/math.cairo index 1b8a19884..2d50310e2 100644 --- a/packages/utils/src/math.cairo +++ b/packages/utils/src/math.cairo @@ -1,6 +1,8 @@ // SPDX-License-Identifier: MIT // OpenZeppelin Contracts for Cairo v0.20.0 (utils/math.cairo) +use core::integer::u512_safe_div_rem_by_u256; +use core::num::traits::WideMul; use core::traits::{BitAnd, BitXor, Into}; /// Returns the average of two numbers. The result is rounded down. @@ -19,3 +21,45 @@ pub fn average< // (a + b) / 2 can overflow. (a & b) + (a ^ b) / 2_u8.into() } + +#[derive(Drop, Copy, Debug)] +pub enum Rounding { + Floor, // Toward negative infinity + Ceil, // Toward positive infinity + Trunc, // Toward zero + Expand // Away from zero +} + +/// Returns the quotient of x * y / denominator and rounds up or down depending on `rounding`. +/// Uses `u512_safe_div_rem_by_u256` for precision. +/// +/// Requirements: +/// +/// - `denominator` cannot be zero. +/// - The quotient cannot be greater than u256. +pub fn u256_mul_div(x: u256, y: u256, denominator: u256, rounding: Rounding) -> u256 { + let (q, r) = _raw_u256_mul_div(x, y, denominator); + + let is_rounded_up = match rounding { + Rounding::Ceil => 1, + Rounding::Expand => 1, + Rounding::Trunc => 0, + Rounding::Floor => 0, + }; + + let has_remainder = if r > 0 { + 1 + } else { + 0 + }; + + q + (is_rounded_up & has_remainder) +} + +fn _raw_u256_mul_div(x: u256, y: u256, denominator: u256) -> (u256, u256) { + let denominator = denominator.try_into().expect('mul_div division by zero'); + let p = x.wide_mul(y); + let (q, r) = u512_safe_div_rem_by_u256(p, denominator); + let q = q.try_into().expect('mul_div quotient > u256'); + (q, r) +} diff --git a/packages/utils/src/tests/test_math.cairo b/packages/utils/src/tests/test_math.cairo index b098c9447..9cc6b5bd7 100644 --- a/packages/utils/src/tests/test_math.cairo +++ b/packages/utils/src/tests/test_math.cairo @@ -1,10 +1,16 @@ use core::integer::{u512, u512_safe_div_rem_by_u256}; +use core::num::traits::Bounded; use core::num::traits::OverflowingAdd; -use crate::math::average; +use crate::math; +use crate::math::Rounding; + +// +// average +// #[test] fn test_average_u8(a: u8, b: u8) { - let actual = average(a, b); + let actual = math::average(a, b); let a: u256 = a.into(); let b: u256 = b.into(); @@ -15,7 +21,7 @@ fn test_average_u8(a: u8, b: u8) { #[test] fn test_average_u16(a: u16, b: u16) { - let actual = average(a, b); + let actual = math::average(a, b); let a: u256 = a.into(); let b: u256 = b.into(); @@ -26,7 +32,7 @@ fn test_average_u16(a: u16, b: u16) { #[test] fn test_average_u32(a: u32, b: u32) { - let actual = average(a, b); + let actual = math::average(a, b); let a: u256 = a.into(); let b: u256 = b.into(); @@ -37,7 +43,7 @@ fn test_average_u32(a: u32, b: u32) { #[test] fn test_average_u64(a: u64, b: u64) { - let actual = average(a, b); + let actual = math::average(a, b); let a: u256 = a.into(); let b: u256 = b.into(); @@ -48,7 +54,7 @@ fn test_average_u64(a: u64, b: u64) { #[test] fn test_average_u128(a: u128, b: u128) { - let actual = average(a, b); + let actual = math::average(a, b); let a: u256 = a.into(); let b: u256 = b.into(); @@ -59,7 +65,7 @@ fn test_average_u128(a: u128, b: u128) { #[test] fn test_average_u256(a: u256, b: u256) { - let actual = average(a, b); + let actual = math::average(a, b); let mut expected = 0; let (sum, overflow) = a.overflowing_add(b); @@ -73,3 +79,101 @@ fn test_average_u256(a: u256, b: u256) { assert_eq!(actual, expected); } + +// +// mul_div +// + +#[test] +#[should_panic(expected: 'mul_div division by zero')] +fn test_mul_div_divide_by_zero() { + let x = 1; + let y = 1; + let denominator = 0; + + math::u256_mul_div(x, y, denominator, Rounding::Floor); +} + +#[test] +#[should_panic(expected: 'mul_div quotient > u256')] +fn test_mul_div_result_gt_u256() { + let x = 5; + let y = Bounded::MAX; + let denominator = 2; + + math::u256_mul_div(x, y, denominator, Rounding::Floor); +} + +#[test] +fn test_mul_div_round_down_small_values() { + let round_down = array![Rounding::Floor, Rounding::Trunc]; + let args_list = array![ // (x, y, denominator, expected result) + (3, 4, 5, 2), (3, 5, 5, 3)] + .span(); + + for rounding in round_down { + for args in args_list { + let (x, y, denominator, expected) = args; + assert_eq!(math::u256_mul_div(*x, *y, *denominator, rounding), *expected); + } + } +} + +#[test] +fn test_mul_div_round_down_large_values() { + let round_down = array![Rounding::Floor, Rounding::Trunc]; + let u256_max: u256 = Bounded::MAX; + let args_list = array![ + // (x, y, denominator, expected result) + (42, u256_max - 1, u256_max, 41), + (17, u256_max, u256_max, 17), + (u256_max - 1, u256_max - 1, u256_max, u256_max - 2), + (u256_max, u256_max - 1, u256_max, u256_max - 1), + (u256_max, u256_max, u256_max, u256_max), + ] + .span(); + + for rounding in round_down { + for args in args_list { + let (x, y, denominator, expected) = args; + assert_eq!(math::u256_mul_div(*x, *y, *denominator, rounding), *expected); + }; + }; +} + +#[test] +fn test_mul_div_round_up_small_values() { + let round_up = array![Rounding::Ceil, Rounding::Expand]; + let args_list = array![ // (x, y, denominator, expected result) + (3, 4, 5, 3), (3, 5, 5, 3)] + .span(); + + for rounding in round_up { + for args in args_list { + let (x, y, denominator, expected) = args; + assert_eq!(math::u256_mul_div(*x, *y, *denominator, rounding), *expected); + } + } +} + +#[test] +fn test_mul_div_round_up_large_values() { + let round_up = array![Rounding::Ceil, Rounding::Expand]; + let u256_max: u256 = Bounded::MAX; + let args_list = array![ + // (x, y, denominator, expected result) + (42, u256_max - 1, u256_max, 42), + (17, u256_max, u256_max, 17), + (u256_max - 1, u256_max - 1, u256_max, u256_max - 1), + (u256_max, u256_max - 1, u256_max, u256_max - 1), + (u256_max, u256_max, u256_max, u256_max), + ] + .span(); + + for rounding in round_up { + for args in args_list { + let (x, y, denominator, expected) = args; + assert_eq!(math::u256_mul_div(*x, *y, *denominator, rounding), *expected); + }; + }; +}