diff --git a/Sources/WasmInterpreter/Heap.swift b/Sources/WasmInterpreter/Heap.swift index dd27294..2d70774 100644 --- a/Sources/WasmInterpreter/Heap.swift +++ b/Sources/WasmInterpreter/Heap.swift @@ -3,4 +3,8 @@ import Foundation public struct Heap { public let pointer: UnsafeMutablePointer public let size: Int + + public func isValid(offset: Int, length: Int) -> Bool { + (offset + length) <= size + } } diff --git a/Sources/WasmInterpreter/WasmInterpreter.swift b/Sources/WasmInterpreter/WasmInterpreter.swift index 997ad5a..179643d 100644 --- a/Sources/WasmInterpreter/WasmInterpreter.swift +++ b/Sources/WasmInterpreter/WasmInterpreter.swift @@ -63,7 +63,9 @@ public final class WasmInterpreter { public func dataFromHeap(offset: Int, length: Int) throws -> Data { let heap = try self.heap() - guard offset + length < heap.size else { throw WasmInterpreterError.invalidMemoryAccess } + + guard heap.isValid(offset: offset, length: length) + else { throw WasmInterpreterError.invalidMemoryAccess } return Data(bytes: heap.pointer.advanced(by: offset), count: length) } @@ -85,7 +87,9 @@ public final class WasmInterpreter { public func valuesFromHeap(offset: Int, length: Int) throws -> [T] { let heap = try self.heap() - guard offset + length < heap.size else { throw WasmInterpreterError.invalidMemoryAccess } + + guard heap.isValid(offset: offset, length: length) + else { throw WasmInterpreterError.invalidMemoryAccess } return heap.pointer .advanced(by: offset) @@ -99,7 +103,9 @@ public final class WasmInterpreter { public func writeToHeap(data: Data, offset: Int) throws { let heap = try self.heap() - guard offset + data.count < heap.size else { throw WasmInterpreterError.invalidMemoryAccess } + + guard heap.isValid(offset: offset, length: data.count) + else { throw WasmInterpreterError.invalidMemoryAccess } try data.withUnsafeBytes { (rawPointer: UnsafeRawBufferPointer) -> Void in guard let pointer = rawPointer.bindMemory(to: UInt8.self).baseAddress diff --git a/Tests/WasmInterpreterTests/Wasm Modules/MemoryModule.swift b/Tests/WasmInterpreterTests/Wasm Modules/MemoryModule.swift index 15d4bdc..a59932c 100644 --- a/Tests/WasmInterpreterTests/Wasm Modules/MemoryModule.swift +++ b/Tests/WasmInterpreterTests/Wasm Modules/MemoryModule.swift @@ -8,6 +8,10 @@ public struct MemoryModule { _vm = try WasmInterpreter(module: MemoryModule.wasm) } + func heapSize() throws -> Int { + try _vm.heap().size + } + func string(at offset: Int, length: Int) throws -> String { try _vm.stringFromHeap(offset: offset, length: length) } diff --git a/Tests/WasmInterpreterTests/WasmInterpreterTests.swift b/Tests/WasmInterpreterTests/WasmInterpreterTests.swift index cc0234d..14ed3f3 100644 --- a/Tests/WasmInterpreterTests/WasmInterpreterTests.swift +++ b/Tests/WasmInterpreterTests/WasmInterpreterTests.swift @@ -52,11 +52,41 @@ final class WasmInterpreterTests: XCTestCase { XCTAssertEqual("👋", try mod.string(at: 17, length: "👋".utf8.count)) } + func testAccessingInvalidMemoryAddresses() throws { + let mod = try MemoryModule() + let size = try mod.heapSize() + + let message = "Hello" + + let validOffset = size - message.utf8.count + XCTAssertNoThrow(try mod.write(message, to: validOffset)) + XCTAssertEqual( + message, + try mod.string(at: validOffset, length: message.utf8.count) + ) + + let invalidOffset = size - message.utf8.count + 1 + XCTAssertThrowsError(try mod.write(message, to: invalidOffset)) { error in + guard let wasmError = error as? WasmInterpreterError + else { return XCTFail() } + + guard case .invalidMemoryAccess = wasmError + else { return XCTFail() } + } + + // Ensure memory hasn't been modified + XCTAssertEqual( + message, + try mod.string(at: validOffset, length: message.utf8.count) + ) + } + static var allTests = [ ("testCallingTwoFunctionsWithSameImplementation", testCallingTwoFunctionsWithSameImplementation), ("testPassingAndReturning32BitValues", testPassingAndReturning32BitValues), ("testPassingAndReturning64BitValues", testPassingAndReturning64BitValues), ("testUsingImportedFunction", testUsingImportedFunction), ("testAccessingAndModifyingHeapMemory", testAccessingAndModifyingHeapMemory), + ("testAccessingInvalidMemoryAddresses", testAccessingInvalidMemoryAddresses), ] }