Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Save gas and clean the upper bits of computed pool address properly #291

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
5 changes: 4 additions & 1 deletion contracts/modules/Permit2Payments.sol
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,11 @@ abstract contract Permit2Payments is Payments {
internal
{
uint256 batchLength = batchDetails.length;
for (uint256 i = 0; i < batchLength; ++i) {
for (uint256 i = 0; i < batchLength;) {
if (batchDetails[i].from != owner) revert FromAddressIsNotOwner();
unchecked {
++i;
}
}
PERMIT2.transferFrom(batchDetails);
}
Expand Down
50 changes: 50 additions & 0 deletions contracts/modules/uniswap/TernaryLib.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// SPDX-License-Identifier: GPL-3.0-or-later
pragma solidity >=0.5.0;

/// @title Library for replacing ternary operator with efficient bitwise operations
library TernaryLib {
/// @notice Equivalent to the ternary operator: `condition ? a : b`
function ternary(bool condition, uint256 a, uint256 b) internal pure returns (uint256 res) {
assembly {
res := xor(b, mul(xor(a, b), condition))
}
}

/// @notice Equivalent to the ternary operator: `condition ? a : b`
function ternary(bool condition, int256 a, int256 b) internal pure returns (int256 res) {
assembly {
res := xor(b, mul(xor(a, b), condition))
}
}

/// @notice Equivalent to the ternary operator: `condition ? a : b`
function ternary(bool condition, address a, address b) internal pure returns (address res) {
assembly {
res := xor(b, mul(xor(a, b), condition))
}
}

/// @notice Sorts two tokens to return token0 and token1
/// @param tokenA The first token to sort
/// @param tokenB The other token to sort
/// @return token0 The smaller token by address value
/// @return token1 The larger token by address value
function sortTokens(address tokenA, address tokenB) internal pure returns (address token0, address token1) {
assembly {
let diff := mul(xor(tokenA, tokenB), lt(tokenB, tokenA))
token0 := xor(tokenA, diff)
token1 := xor(tokenB, diff)
}
}

/// @notice Switches two uint256 if `condition` is true
/// @dev Equivalent to: `condition ? (b, a) : (a, b)`
function switchIf(bool condition, uint256 a, uint256 b) internal pure returns (uint256, uint256) {
assembly {
let diff := mul(xor(a, b), condition)
a := xor(a, diff)
b := xor(b, diff)
}
return (a, b);
}
}
56 changes: 30 additions & 26 deletions contracts/modules/uniswap/v2/UniswapV2Library.sol
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
pragma solidity >=0.8.0;

import {IUniswapV2Pair} from '@uniswap/v2-core/contracts/interfaces/IUniswapV2Pair.sol';
import {TernaryLib} from '../TernaryLib.sol';

/// @title Uniswap v2 Helper Library
/// @notice Calculates the recipient address for a command
Expand All @@ -20,7 +21,7 @@ library UniswapV2Library {
pure
returns (address pair)
{
(address token0, address token1) = sortTokens(tokenA, tokenB);
(address token0, address token1) = TernaryLib.sortTokens(tokenA, tokenB);
pair = pairForPreSorted(factory, initCodeHash, token0, token1);
}

Expand All @@ -37,7 +38,7 @@ library UniswapV2Library {
returns (address pair, address token0)
{
address token1;
(token0, token1) = sortTokens(tokenA, tokenB);
(token0, token1) = TernaryLib.sortTokens(tokenA, tokenB);
pair = pairForPreSorted(factory, initCodeHash, token0, token1);
}

Expand All @@ -52,15 +53,25 @@ library UniswapV2Library {
pure
returns (address pair)
{
pair = address(
uint160(
uint256(
keccak256(
abi.encodePacked(hex'ff', factory, keccak256(abi.encodePacked(token0, token1)), initCodeHash)
)
)
)
);
// accomplishes the following:
// address(keccak256(abi.encodePacked(hex'ff', factory, keccak256(abi.encodePacked(token0, token1)), initCodeHash)))
assembly ("memory-safe") {
Copy link
Member

@ewilz ewilz May 8, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I like the gas savings here. Wonder if it's worth a comment:

// accomplishes the following:
// address(keccak256(abi.encodePacked(hex'ff', factory, keccak256(abi.encodePacked(token0, token1)), initCodeHash)))

to get a feel for what this is doing at a glance

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks, I think that's pretty helpful for readability, one for the V3Lib would probably make sense too. (sorry, missed that)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added.

// Cache the free memory pointer.
let fmp := mload(0x40)
// pairHash = keccak256(abi.encodePacked(token0, token1))
mstore(0x14, token1)
mstore(0, token0)
let pairHash := keccak256(0x0c, 0x28)
// abi.encodePacked(hex'ff', factory, pairHash, initCodeHash)
// Prefix the factory address with 0xff.
mstore(0, or(factory, 0xff0000000000000000000000000000000000000000))
mstore(0x20, pairHash)
mstore(0x40, initCodeHash)
// Compute the CREATE2 pair address and clean the upper bits.
pair := and(keccak256(0x0b, 0x55), 0xffffffffffffffffffffffffffffffffffffffff)
// Restore the free memory pointer.
mstore(0x40, fmp)
}
}

/// @notice Calculates the v2 address for a pair and fetches the reserves for each token
Expand All @@ -79,7 +90,7 @@ library UniswapV2Library {
address token0;
(pair, token0) = pairAndToken0For(factory, initCodeHash, tokenA, tokenB);
(uint256 reserve0, uint256 reserve1,) = IUniswapV2Pair(pair).getReserves();
(reserveA, reserveB) = tokenA == token0 ? (reserve0, reserve1) : (reserve1, reserve0);
(reserveA, reserveB) = TernaryLib.switchIf(tokenA == token0, reserve1, reserve0);
}

/// @notice Given an input asset amount returns the maximum output amount of the other asset
Expand Down Expand Up @@ -129,21 +140,14 @@ library UniswapV2Library {
{
if (path.length < 2) revert InvalidPath();
amount = amountOut;
for (uint256 i = path.length - 1; i > 0; i--) {
uint256 reserveIn;
uint256 reserveOut;
unchecked {
for (uint256 i = path.length - 1; i > 0; --i) {
uint256 reserveIn;
uint256 reserveOut;

(pair, reserveIn, reserveOut) = pairAndReservesFor(factory, initCodeHash, path[i - 1], path[i]);
amount = getAmountIn(amount, reserveIn, reserveOut);
(pair, reserveIn, reserveOut) = pairAndReservesFor(factory, initCodeHash, path[i - 1], path[i]);
amount = getAmountIn(amount, reserveIn, reserveOut);
}
}
}

/// @notice Sorts two tokens to return token0 and token1
/// @param tokenA The first token to sort
/// @param tokenB The other token to sort
/// @return token0 The smaller token by address value
/// @return token1 The larger token by address value
function sortTokens(address tokenA, address tokenB) internal pure returns (address token0, address token1) {
(token0, token1) = tokenA < tokenB ? (tokenA, tokenB) : (tokenB, tokenA);
}
}
26 changes: 15 additions & 11 deletions contracts/modules/uniswap/v2/V2SwapRouter.sol
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import {Payments} from '../../Payments.sol';
import {Permit2Payments} from '../../Permit2Payments.sol';
import {Constants} from '../../../libraries/Constants.sol';
import {ERC20} from 'solmate/src/tokens/ERC20.sol';
import {TernaryLib} from '../TernaryLib.sol';

/// @title Router for Uniswap v2 Trades
abstract contract V2SwapRouter is RouterImmutables, Permit2Payments {
Expand All @@ -20,22 +21,25 @@ abstract contract V2SwapRouter is RouterImmutables, Permit2Payments {
if (path.length < 2) revert V2InvalidPath();

// cached to save on duplicate operations
(address token0,) = UniswapV2Library.sortTokens(path[0], path[1]);
(address token0,) = TernaryLib.sortTokens(path[0], path[1]);
uint256 finalPairIndex = path.length - 1;
uint256 penultimatePairIndex = finalPairIndex - 1;
for (uint256 i; i < finalPairIndex; i++) {
(address input, address output) = (path[i], path[i + 1]);
(uint256 reserve0, uint256 reserve1,) = IUniswapV2Pair(pair).getReserves();
(uint256 reserveInput, uint256 reserveOutput) =
input == token0 ? (reserve0, reserve1) : (reserve1, reserve0);
uint256 amountInput = ERC20(input).balanceOf(pair) - reserveInput;
uint256 amountOutput = UniswapV2Library.getAmountOut(amountInput, reserveInput, reserveOutput);
(uint256 amount0Out, uint256 amount1Out) =
input == token0 ? (uint256(0), amountOutput) : (amountOutput, uint256(0));
for (uint256 i; i < finalPairIndex; ++i) {
address input = path[i];
uint256 amount0Out;
uint256 amount1Out;
{
(uint256 reserve0, uint256 reserve1,) = IUniswapV2Pair(pair).getReserves();
(uint256 reserveInput, uint256 reserveOutput) =
TernaryLib.switchIf(input == token0, reserve1, reserve0);
uint256 amountInput = ERC20(input).balanceOf(pair) - reserveInput;
uint256 amountOutput = UniswapV2Library.getAmountOut(amountInput, reserveInput, reserveOutput);
(amount0Out, amount1Out) = TernaryLib.switchIf(input == token0, amountOutput, 0);
}
address nextPair;
(nextPair, token0) = i < penultimatePairIndex
? UniswapV2Library.pairAndToken0For(
UNISWAP_V2_FACTORY, UNISWAP_V2_PAIR_INIT_CODE_HASH, output, path[i + 2]
UNISWAP_V2_FACTORY, UNISWAP_V2_PAIR_INIT_CODE_HASH, path[i + 1], path[i + 2]
)
: (recipient, address(0));
IUniswapV2Pair(pair).swap(amount0Out, amount1Out, nextPair, new bytes(0));
Expand Down
3 changes: 1 addition & 2 deletions contracts/modules/uniswap/v3/BytesLib.sol
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
// SPDX-License-Identifier: GPL-3.0-or-later

/// @title Library for Bytes Manipulation
pragma solidity ^0.8.0;

import {Constants} from '../../../libraries/Constants.sol';

/// @title Library for Bytes Manipulation
library BytesLib {
error SliceOutOfBounds();

Expand Down
74 changes: 49 additions & 25 deletions contracts/modules/uniswap/v3/V3SwapRouter.sol
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@ import {RouterImmutables} from '../../../base/RouterImmutables.sol';
import {Permit2Payments} from '../../Permit2Payments.sol';
import {Constants} from '../../../libraries/Constants.sol';
import {ERC20} from 'solmate/src/tokens/ERC20.sol';
import {TernaryLib} from '../TernaryLib.sol';

/// @title Router for Uniswap v3 Trades
abstract contract V3SwapRouter is RouterImmutables, Permit2Payments, IUniswapV3SwapCallback {
using V3Path for bytes;
using BytesLib for bytes;
using SafeCast for uint256;
using TernaryLib for bool;

error V3InvalidSwap();
error V3TooLittleReceived();
Expand Down Expand Up @@ -93,13 +95,16 @@ abstract contract V3SwapRouter is RouterImmutables, Permit2Payments, IUniswapV3S
// the outputs of prior swaps become the inputs to subsequent ones
(int256 amount0Delta, int256 amount1Delta, bool zeroForOne) = _swap(
amountIn.toInt256(),
hasMultiplePools ? address(this) : recipient, // for intermediate swaps, this contract custodies
hasMultiplePools.ternary(address(this), recipient), // for intermediate swaps, this contract custodies
path.getFirstPool(), // only the first pool is needed
payer, // for intermediate swaps, this contract custodies
true
);

amountIn = uint256(-(zeroForOne ? amount1Delta : amount0Delta));
unchecked {
// no need to check for overflow here as it will be caught in `toInt256()`
amountIn = uint256(-zeroForOne.ternary(amount1Delta, amount0Delta));
}

// decide whether to continue or terminate
if (hasMultiplePools) {
Expand Down Expand Up @@ -131,9 +136,11 @@ abstract contract V3SwapRouter is RouterImmutables, Permit2Payments, IUniswapV3S
(int256 amount0Delta, int256 amount1Delta, bool zeroForOne) =
_swap(-amountOut.toInt256(), recipient, path, payer, false);

uint256 amountOutReceived = zeroForOne ? uint256(-amount1Delta) : uint256(-amount0Delta);

if (amountOutReceived != amountOut) revert V3InvalidAmountOut();
unchecked {
// no need to check for overflow
uint256 amountOutReceived = uint256(-zeroForOne.ternary(amount1Delta, amount0Delta));
if (amountOutReceived != amountOut) revert V3InvalidAmountOut();
}

maxAmountInCached = DEFAULT_MAX_AMOUNT_IN;
}
Expand All @@ -144,34 +151,51 @@ abstract contract V3SwapRouter is RouterImmutables, Permit2Payments, IUniswapV3S
private
returns (int256 amount0Delta, int256 amount1Delta, bool zeroForOne)
{
(address tokenIn, uint24 fee, address tokenOut) = path.decodeFirstPool();

zeroForOne = isExactIn ? tokenIn < tokenOut : tokenOut < tokenIn;
address pool;
{
(address tokenIn, uint24 fee, address tokenOut) = path.decodeFirstPool();
pool = computePoolAddress(tokenIn, tokenOut, fee);
// When isExactIn == 1, zeroForOne = tokenIn < tokenOut = !(tokenOut < tokenIn) = 1 ^ (tokenOut < tokenIn)
// When isExactIn == 0, zeroForOne = tokenOut < tokenIn = 0 ^ (tokenOut < tokenIn)
assembly {
zeroForOne := xor(isExactIn, lt(tokenOut, tokenIn))
}
}

(amount0Delta, amount1Delta) = IUniswapV3Pool(computePoolAddress(tokenIn, tokenOut, fee)).swap(
(amount0Delta, amount1Delta) = IUniswapV3Pool(pool).swap(
recipient,
zeroForOne,
amount,
(zeroForOne ? MIN_SQRT_RATIO + 1 : MAX_SQRT_RATIO - 1),
uint160(zeroForOne.ternary(MIN_SQRT_RATIO + 1, MAX_SQRT_RATIO - 1)),
abi.encode(path, payer)
);
}

function computePoolAddress(address tokenA, address tokenB, uint24 fee) private view returns (address pool) {
if (tokenA > tokenB) (tokenA, tokenB) = (tokenB, tokenA);
pool = address(
uint160(
uint256(
keccak256(
abi.encodePacked(
hex'ff',
UNISWAP_V3_FACTORY,
keccak256(abi.encode(tokenA, tokenB, fee)),
UNISWAP_V3_POOL_INIT_CODE_HASH
)
)
)
)
);
address factory = UNISWAP_V3_FACTORY;
bytes32 initCodeHash = UNISWAP_V3_POOL_INIT_CODE_HASH;
// accomplishes the following:
// address(keccak256(abi.encodePacked(hex'ff', factory, keccak256(abi.encode(tokenA, tokenB, fee)), initCodeHash)))
assembly ("memory-safe") {
// Cache the free memory pointer.
let fmp := mload(0x40)
// Hash the pool key.
// Equivalent to `if (tokenA > tokenB) (tokenA, tokenB) = (tokenB, tokenA)`
let diff := mul(xor(tokenA, tokenB), lt(tokenB, tokenA))
// poolHash = keccak256(abi.encode(tokenA, tokenB, fee))
mstore(0, xor(tokenA, diff))
mstore(0x20, xor(tokenB, diff))
mstore(0x40, fee)
let poolHash := keccak256(0, 0x60)
// abi.encodePacked(hex'ff', factory, poolHash, initCodeHash)
// Prefix the factory address with 0xff.
mstore(0, or(factory, 0xff0000000000000000000000000000000000000000))
mstore(0x20, poolHash)
mstore(0x40, initCodeHash)
// Compute the CREATE2 pool address and clean the upper bits.
pool := and(keccak256(0x0b, 0x55), 0xffffffffffffffffffffffffffffffffffffffff)
// Restore the free memory pointer.
mstore(0x40, fmp)
}
}
}
3 changes: 0 additions & 3 deletions test/foundry-tests/UniswapV2.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@ import {Constants} from '../../contracts/libraries/Constants.sol';
import {Commands} from '../../contracts/libraries/Commands.sol';
import {RouterParameters} from '../../contracts/base/RouterImmutables.sol';

import '@openzeppelin/contracts/token/ERC721/IERC721Receiver.sol';
import '@openzeppelin/contracts/token/ERC1155/IERC1155Receiver.sol';

abstract contract UniswapV2Test is Test {
address constant RECIPIENT = address(10);
uint256 constant AMOUNT = 1 ether;
Expand Down
Loading