diff --git a/src/examples/CCTPHookWrapper.sol b/src/examples/CCTPHookWrapper.sol index 27174b1..82e926a 100644 --- a/src/examples/CCTPHookWrapper.sol +++ b/src/examples/CCTPHookWrapper.sol @@ -21,23 +21,24 @@ import {IReceiverV2} from "../interfaces/v2/IReceiverV2.sol"; import {TypedMemView} from "@memview-sol/contracts/TypedMemView.sol"; import {MessageV2} from "../messages/v2/MessageV2.sol"; import {BurnMessageV2} from "../messages/v2/BurnMessageV2.sol"; +import {Ownable2Step} from "../roles/Ownable2Step.sol"; /** * @title CCTPHookWrapper * @notice A sample wrapper around CCTP v2 that relays a message and * optionally executes the hook contained in the Burn Message. - * @dev This is intended to only work with CCTP v2 message formats and interfaces. + * @dev Intended to only work with CCTP v2 message formats and interfaces. */ -contract CCTPHookWrapper { - // ============ State Variables ============ +contract CCTPHookWrapper is Ownable2Step { + // ============ Constants ============ // Address of the local message transmitter IReceiverV2 public immutable messageTransmitter; // The supported Message Format version - uint32 public immutable supportedMessageVersion; + uint32 public constant supportedMessageVersion = 1; // The supported Message Body version - uint32 public immutable supportedMessageBodyVersion; + uint32 public constant supportedMessageBodyVersion = 1; // Byte-length of an address uint256 internal constant ADDRESS_BYTE_LENGTH = 20; @@ -46,34 +47,17 @@ contract CCTPHookWrapper { using TypedMemView for bytes; using TypedMemView for bytes29; - // ============ Modifiers ============ - /** - * @notice A modifier to enable access control - * @dev Can be overridden to customize the behavior - */ - modifier onlyAllowed() virtual { - _; - } - // ============ Constructor ============ /** * @param _messageTransmitter The address of the local message transmitter - * @param _messageVersion The required CCTP message version. For CCTP v2, this is 1. - * @param _messageBodyVersion The required message body (Burn Message) version. For CCTP v2, this is 1. */ - constructor( - address _messageTransmitter, - uint32 _messageVersion, - uint32 _messageBodyVersion - ) { + constructor(address _messageTransmitter) Ownable2Step() { require( _messageTransmitter != address(0), "Message transmitter is the zero address" ); messageTransmitter = IReceiverV2(_messageTransmitter); - supportedMessageVersion = _messageVersion; - supportedMessageBodyVersion = _messageBodyVersion; } // ============ External Functions ============ @@ -89,19 +73,12 @@ contract CCTPHookWrapper { * The hook handler will call the target address with the hookCallData, even if hookCallData * is zero-length. Additional data about the burn message is not passed in this call. * - * WARNING: this implementation does NOT enforce atomicity in the hook call. If atomicity is - * required, a new wrapper contract can be created, possibly by overriding this behavior in `_handleHook`, - * or by introducing a different format for the hook data that includes more information about - * the desired handling. + * @dev Reverts if not called by the Owner. Due to the lack of atomicity with the hook call, permissionless relay of messages containing hooks via + * an implementation like this contract should be carefully considered, as a malicious caller could use a low gas attack to consume + * the message's nonce without executing the hook. * - * WARNING: in a permissionless context, it is important not to view this wrapper implementation as a trusted - * caller of a hook, as others can craft messages containing hooks that look identical, that are - * similarly executed from this wrapper, either by setting this contract as the destination caller, - * or by setting the destination caller to be bytes32(0). Alternate implementations may extract more information - * from the burn message, such as the mintRecipient or the amount, to include in the hook call to allow recipients - * to further filter their receiving actions. - * - * WARNING: re-entrant behavior is allowed in this implementation. Relay() can be overridden to disable this. + * WARNING: this implementation does NOT enforce atomicity in the hook call. This is to prevent a failed hook call + * from preventing relay of a message if this contract is set as the destinationCaller. * * @dev Reverts if the receiveMessage() call to the local message transmitter reverts, or returns false. * @param message The message to relay, as bytes @@ -118,73 +95,66 @@ contract CCTPHookWrapper { ) external virtual - onlyAllowed returns ( bool relaySuccess, bool hookSuccess, bytes memory hookReturnData ) { - bytes29 _msg = message.ref(0); - bytes29 _msgBody = MessageV2._getMessageBody(_msg); + _checkOwner(); - // Perform message validation - _validateMessage(_msg, _msgBody); + // Validate message + bytes29 _msg = message.ref(0); + MessageV2._validateMessageFormat(_msg); + require( + MessageV2._getVersion(_msg) == supportedMessageVersion, + "Invalid message version" + ); - // Relay message + // Validate burn message + bytes29 _msgBody = MessageV2._getMessageBody(_msg); + BurnMessageV2._validateBurnMessageFormat(_msgBody); require( - messageTransmitter.receiveMessage(message, attestation), - "Receive message failed" + BurnMessageV2._getVersion(_msgBody) == supportedMessageBodyVersion, + "Invalid message body version" ); - relaySuccess = true; + // Relay message + relaySuccess = messageTransmitter.receiveMessage(message, attestation); + require(relaySuccess, "Receive message failed"); - // Handle hook + // Handle hook if present bytes29 _hookData = BurnMessageV2._getHookData(_msgBody); - (hookSuccess, hookReturnData) = _handleHook(_hookData); + if (_hookData.isValid()) { + uint256 _hookDataLength = _hookData.len(); + if (_hookDataLength >= ADDRESS_BYTE_LENGTH) { + address _target = _hookData.indexAddress(0); + bytes memory _hookCalldata = _hookData + .postfix(_hookDataLength - ADDRESS_BYTE_LENGTH, 0) + .clone(); + + (hookSuccess, hookReturnData) = _executeHook( + _target, + _hookCalldata + ); + } + } } // ============ Internal Functions ============ - /** - * @notice Validates a message and its message body - * @dev Can be overridden to customize the validation - * @dev Reverts if the message format version or message body version - * do not match the supported versions. - */ - function _validateMessage( - bytes29 _message, - bytes29 _messageBody - ) internal virtual { - require( - MessageV2._getVersion(_message) == supportedMessageVersion, - "Invalid message version" - ); - require( - BurnMessageV2._getVersion(_messageBody) == - supportedMessageBodyVersion, - "Invalid message body version" - ); - } - /** * @notice Handles hook data by executing a call to a target address - * @dev Can be overridden to customize the execution behavior - * @param _hookData The hook data contained in the Burn Message + * @dev Can be overridden to customize execution behavior + * @dev Does not revert if the CALL to the hook target fails + * @param _hookTarget The target address of the hook + * @param _hookCalldata The hook calldata * @return _success True if the call to the encoded hook target succeeds * @return _returnData The data returned from the call to the hook target */ - function _handleHook( - bytes29 _hookData + function _executeHook( + address _hookTarget, + bytes memory _hookCalldata ) internal virtual returns (bool _success, bytes memory _returnData) { - uint256 _hookDataLength = _hookData.len(); - - if (_hookDataLength >= ADDRESS_BYTE_LENGTH) { - address _target = _hookData.indexAddress(0); - bytes memory _hookCalldata = _hookData - .postfix(_hookDataLength - ADDRESS_BYTE_LENGTH, 0) - .clone(); - - (_success, _returnData) = address(_target).call(_hookCalldata); - } + (_success, _returnData) = address(_hookTarget).call(_hookCalldata); } } diff --git a/src/messages/v2/BurnMessageV2.sol b/src/messages/v2/BurnMessageV2.sol index e9bc81a..703dc06 100644 --- a/src/messages/v2/BurnMessageV2.sol +++ b/src/messages/v2/BurnMessageV2.sol @@ -152,7 +152,7 @@ library BurnMessageV2 { require(_message.isValid(), "Malformed message"); require( _message.len() >= HOOK_DATA_INDEX, - "Invalid message: too short" + "Invalid burn message: too short" ); } } diff --git a/test/examples/CCTPHookWrapper.t.sol b/test/examples/CCTPHookWrapper.t.sol index dd46d15..78aac7b 100644 --- a/test/examples/CCTPHookWrapper.t.sol +++ b/test/examples/CCTPHookWrapper.t.sol @@ -24,26 +24,32 @@ import {MessageV2} from "../../src/messages/v2/MessageV2.sol"; import {BurnMessageV2} from "../../src/messages/v2/BurnMessageV2.sol"; import {MockHookTarget} from "../mocks/v2/MockHookTarget.sol"; import {Test} from "forge-std/Test.sol"; +import {TypedMemView} from "@memview-sol/contracts/TypedMemView.sol"; contract CCTPHookWrapperTest is Test { + // Libraries + + using TypedMemView for bytes; + using TypedMemView for bytes29; + // Test events event HookReceived(uint256 paramOne, uint256 paramTwo); // Test constants - uint32 messageVersion = 1; - uint32 messageBodyVersion = 1; + uint32 v2MessageVersion = 1; + uint32 v2MessageBodyVersion = 1; + + address wrapperOwner = address(123); address localMessageTransmitter = address(10); MockHookTarget hookTarget; CCTPHookWrapper wrapper; function setUp() public { - wrapper = new CCTPHookWrapper( - localMessageTransmitter, - messageVersion, - messageBodyVersion - ); + vm.prank(wrapperOwner); + wrapper = new CCTPHookWrapper(localMessageTransmitter); + hookTarget = new MockHookTarget(); } @@ -51,81 +57,116 @@ contract CCTPHookWrapperTest is Test { function testInitialization__revertsIfMessageTransmitterIsZero() public { vm.expectRevert("Message transmitter is the zero address"); - new CCTPHookWrapper(address(0), messageVersion, messageBodyVersion); + new CCTPHookWrapper(address(0)); } function testInitialization__setsTheMessageTransmitter( address _messageTransmitter ) public { vm.assume(_messageTransmitter != address(0)); - CCTPHookWrapper _wrapper = new CCTPHookWrapper( - _messageTransmitter, - messageVersion, - messageBodyVersion - ); + CCTPHookWrapper _wrapper = new CCTPHookWrapper(_messageTransmitter); assertEq(address(_wrapper.messageTransmitter()), _messageTransmitter); } - function testInitialization__setsTheMessageVersion( - uint32 _messageVersion - ) public { - CCTPHookWrapper _wrapper = new CCTPHookWrapper( - localMessageTransmitter, - _messageVersion, - messageBodyVersion - ); + function testInitialization__usesTheV2MessageVersion() public view { assertEq( - uint256(address(_wrapper.supportedMessageVersion())), - uint256(_messageVersion) + uint256(address(wrapper.supportedMessageVersion())), + uint256(v2MessageVersion) ); } - function testInitialization__setsTheMessageBodyVersion( - uint32 _messageBodyVersion - ) public { - CCTPHookWrapper _wrapper = new CCTPHookWrapper( - localMessageTransmitter, - messageVersion, - _messageBodyVersion - ); + function testInitialization__usesTheV2MessageBodyVersion() public view { assertEq( - uint256(address(_wrapper.supportedMessageBodyVersion())), - uint256(_messageBodyVersion) + uint256(address(wrapper.supportedMessageBodyVersion())), + uint256(v2MessageBodyVersion) ); } + function testRelay__revertsIfNotCalledByOwner( + address _randomAddress, + bytes calldata _randomBytes + ) public { + vm.assume(_randomAddress != wrapperOwner); + + vm.expectRevert("Ownable: caller is not the owner"); + vm.prank(_randomAddress); + wrapper.relay(_randomBytes, bytes("")); + } + function testRelay__revertsIfMessageFormatVersionIsInvalid( uint32 _messageVersion ) public { - vm.assume(_messageVersion != messageVersion); + vm.assume(_messageVersion != v2MessageVersion); vm.expectRevert("Invalid message version"); bytes memory _message = _createMessage( _messageVersion, - messageBodyVersion, + v2MessageBodyVersion, bytes("") ); + + vm.prank(wrapperOwner); wrapper.relay(_message, bytes("")); } function testRelay__revertsIfMessageBodyVersionIsInvalid( uint32 _messageBodyVersion ) public { - vm.assume(_messageBodyVersion != messageBodyVersion); + vm.assume(_messageBodyVersion != v2MessageBodyVersion); vm.expectRevert("Invalid message body version"); bytes memory _message = _createMessage( - messageVersion, + v2MessageVersion, _messageBodyVersion, bytes("") ); + + vm.prank(wrapperOwner); wrapper.relay(_message, bytes("")); } + function testRelay__revertsIfMessageValidationFails() public { + bytes memory _message = _createMessage( + v2MessageVersion, + v2MessageBodyVersion, + bytes("") + ); + + // Slice the message to make it fail validation + bytes memory _truncatedMessage = _message + .ref(0) + .slice(0, 147, 0) + .clone(); // See: MessageV2#MESSAGE_BODY_INDEX + + vm.expectRevert("Invalid message: too short"); + + vm.prank(wrapperOwner); + wrapper.relay(_truncatedMessage, bytes("")); + } + + function testRelay__revertsIfMessageBodyValidationFails() public { + bytes memory _message = _createMessage( + v2MessageVersion, + v2MessageBodyVersion, + bytes("") + ); + + // Slice the message to make it fail validation + bytes memory _truncatedMessage = _message + .ref(0) + .slice(0, 375, 0) + .clone(); // See: BurnMessageV2#HOOK_DATA_INDEX (148 + 228 = 376) + + vm.expectRevert("Invalid burn message: too short"); + + vm.prank(wrapperOwner); + wrapper.relay(_truncatedMessage, bytes("")); + } + function testRelay__revertsIfMessageTransmitterCallReverts() public { bytes memory _message = _createMessage( - messageVersion, - messageBodyVersion, + v2MessageVersion, + v2MessageBodyVersion, bytes("") ); @@ -137,17 +178,19 @@ contract CCTPHookWrapperTest is Test { _message, bytes("") ), - "Testing: token minter failed" + "Testing: message transmitter failed" ); vm.expectRevert(); + + vm.prank(wrapperOwner); wrapper.relay(_message, bytes("")); } function testRelay__revertsIfMessageTransmitterReturnsFalse() public { bytes memory _message = _createMessage( - messageVersion, - messageBodyVersion, + v2MessageVersion, + v2MessageBodyVersion, bytes("") ); @@ -173,13 +216,15 @@ contract CCTPHookWrapperTest is Test { ); vm.expectRevert("Receive message failed"); + + vm.prank(wrapperOwner); wrapper.relay(_message, bytes("")); } function testRelay__succeedsWithNoHook() public { bytes memory _message = _createMessage( - messageVersion, - messageBodyVersion, + v2MessageVersion, + v2MessageBodyVersion, bytes("") ); @@ -193,6 +238,7 @@ contract CCTPHookWrapperTest is Test { abi.encode(true) ); + vm.prank(wrapperOwner); ( bool _relaySuccess, bool _hookSuccess, @@ -210,8 +256,8 @@ contract CCTPHookWrapperTest is Test { MockHookTarget.failingHook.selector ); bytes memory _message = _createMessage( - messageVersion, - messageBodyVersion, + v2MessageVersion, + v2MessageBodyVersion, abi.encodePacked(address(hookTarget), _failingHookCalldata) ); @@ -227,6 +273,7 @@ contract CCTPHookWrapperTest is Test { ); // Call wrapper + vm.prank(wrapperOwner); ( bool _relaySuccess, bool _hookSuccess, @@ -238,12 +285,48 @@ contract CCTPHookWrapperTest is Test { assertEq(_getRevertMsg(_returnData), "Hook failure"); } + function testRelay__succeedsAndIgnoresHooksLessThanRequiredLength( + bytes calldata randomBytes + ) public { + vm.assume(randomBytes.length > 20); + // Prepare a message with hookData less than required length (20 bytes) + bytes memory _shortCallData = randomBytes[:19]; + bytes memory _message = _createMessage( + v2MessageVersion, + v2MessageBodyVersion, + _shortCallData + ); + + // Mock successful call to MessageTransmitter + vm.mockCall( + localMessageTransmitter, + abi.encodeWithSelector( + IReceiver.receiveMessage.selector, + _message, + bytes("") + ), + abi.encode(true) + ); + + // Call wrapper + vm.prank(wrapperOwner); + ( + bool _relaySuccess, + bool _hookSuccess, + bytes memory _returnData + ) = wrapper.relay(_message, bytes("")); + + assertTrue(_relaySuccess); + assertFalse(_hookSuccess); + assertEq(_returnData.length, 0); + } + function testRelay__succeedsWithCallToEOAHookTarget( bytes calldata _hookCalldata ) public { bytes memory _message = _createMessage( - messageVersion, - messageBodyVersion, + v2MessageVersion, + v2MessageBodyVersion, abi.encodePacked(address(12345), _hookCalldata) ); @@ -259,6 +342,7 @@ contract CCTPHookWrapperTest is Test { ); // Call wrapper + vm.prank(wrapperOwner); ( bool _relaySuccess, bool _hookSuccess, @@ -271,7 +355,7 @@ contract CCTPHookWrapperTest is Test { } function testRelay__succeedsWithSucceedingHook() public { - // Prepare a message with hookCalldata that will fail + // Prepare a message with hookCalldata that will succeed uint256 _expectedReturnData = 12; bytes memory _succeedingHookCallData = abi.encodeWithSelector( MockHookTarget.succeedingHook.selector, @@ -279,8 +363,8 @@ contract CCTPHookWrapperTest is Test { 7 ); bytes memory _message = _createMessage( - messageVersion, - messageBodyVersion, + v2MessageVersion, + v2MessageBodyVersion, abi.encodePacked(address(hookTarget), _succeedingHookCallData) ); @@ -299,6 +383,7 @@ contract CCTPHookWrapperTest is Test { emit HookReceived(5, 7); // Call wrapper + vm.prank(wrapperOwner); ( bool _relaySuccess, bool _hookSuccess, @@ -319,7 +404,7 @@ contract CCTPHookWrapperTest is Test { ) internal pure returns (bytes memory) { return abi.encodePacked( - _messageVersion, + _messageVersion, // messageVersion uint32(0), // sourceDomain uint32(0), // destinationDomain bytes32(0), // nonce @@ -338,15 +423,15 @@ contract CCTPHookWrapperTest is Test { ) internal pure returns (bytes memory) { return abi.encodePacked( - _burnMessageVersion, - bytes32(0), - bytes32(0), - uint256(0), - bytes32(0), - uint256(0), - uint256(0), - uint256(0), - _hookData + _burnMessageVersion, // messageBodyVersion + bytes32(0), // burnToken + bytes32(0), // mintRecipient + uint256(0), // amount + bytes32(0), // messageSender + uint256(0), // maxFee + uint256(0), // feeExecuted + uint256(0), // expirationBlock + _hookData // hookData ); } diff --git a/test/messages/v2/BurnMessageV2.t.sol b/test/messages/v2/BurnMessageV2.t.sol index 85559f4..5561322 100644 --- a/test/messages/v2/BurnMessageV2.t.sol +++ b/test/messages/v2/BurnMessageV2.t.sol @@ -97,7 +97,7 @@ contract BurnMessageV2Test is Test { // Lop off the hookData bytes, and then one more _m = _m.slice(0, _m.len() - _hookData.length - 1, 0); - vm.expectRevert("Invalid message: too short"); + vm.expectRevert("Invalid burn message: too short"); _m._validateBurnMessageFormat(); } diff --git a/test/v2/TokenMessengerV2.t.sol b/test/v2/TokenMessengerV2.t.sol index c9e0997..fa8552d 100644 --- a/test/v2/TokenMessengerV2.t.sol +++ b/test/v2/TokenMessengerV2.t.sol @@ -1445,7 +1445,7 @@ contract TokenMessengerV2Test is BaseTokenMessengerTest { vm.assume(_messageBody.length < 228); vm.prank(localMessageTransmitter); - vm.expectRevert("Invalid message: too short"); + vm.expectRevert("Invalid burn message: too short"); localTokenMessenger.handleReceiveFinalizedMessage( remoteDomain, remoteTokenMessengerAddr, @@ -2073,7 +2073,7 @@ contract TokenMessengerV2Test is BaseTokenMessengerTest { ); vm.prank(localMessageTransmitter); - vm.expectRevert("Invalid message: too short"); + vm.expectRevert("Invalid burn message: too short"); localTokenMessenger.handleReceiveUnfinalizedMessage( remoteDomain, remoteTokenMessengerAddr,