Skip to content

Commit

Permalink
[FP16] Implement load and store instructions. (#6796)
Browse files Browse the repository at this point in the history
  • Loading branch information
brendandahl authored Aug 6, 2024
1 parent d5a5425 commit 0c26948
Show file tree
Hide file tree
Showing 18 changed files with 952 additions and 41 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ endif()
# Compiler setup.

include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/third_party/FP16/include)
if(BUILD_LLVM_DWARF)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/third_party/llvm-project/include)
endif()
Expand Down
5 changes: 5 additions & 0 deletions LICENSE
Original file line number Diff line number Diff line change
Expand Up @@ -199,3 +199,8 @@
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

==============================================================================

This repo uses the FP16 project, which is licensed under the MIT license; see
third_party/FP16/LICENSE.
2 changes: 2 additions & 0 deletions scripts/gen-s-parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
("i32.load", "makeLoad(Type::i32, /*signed=*/false, 4, /*isAtomic=*/false)"),
("i64.load", "makeLoad(Type::i64, /*signed=*/false, 8, /*isAtomic=*/false)"),
("f32.load", "makeLoad(Type::f32, /*signed=*/false, 4, /*isAtomic=*/false)"),
("f32.load_f16", "makeLoad(Type::f32, /*signed=*/false, 2, /*isAtomic=*/false)"),
("f64.load", "makeLoad(Type::f64, /*signed=*/false, 8, /*isAtomic=*/false)"),
("i32.load8_s", "makeLoad(Type::i32, /*signed=*/true, 1, /*isAtomic=*/false)"),
("i32.load8_u", "makeLoad(Type::i32, /*signed=*/false, 1, /*isAtomic=*/false)"),
Expand All @@ -60,6 +61,7 @@
("i32.store", "makeStore(Type::i32, 4, /*isAtomic=*/false)"),
("i64.store", "makeStore(Type::i64, 8, /*isAtomic=*/false)"),
("f32.store", "makeStore(Type::f32, 4, /*isAtomic=*/false)"),
("f32.store_f16", "makeStore(Type::f32, 2, /*isAtomic=*/false)"),
("f64.store", "makeStore(Type::f64, 8, /*isAtomic=*/false)"),
("i32.store8", "makeStore(Type::i32, 1, /*isAtomic=*/false)"),
("i32.store16", "makeStore(Type::i32, 2, /*isAtomic=*/false)"),
Expand Down
42 changes: 32 additions & 10 deletions src/gen-s-parser.inc
Original file line number Diff line number Diff line change
Expand Up @@ -454,12 +454,23 @@ switch (buf[0]) {
return Ok{};
}
goto parse_error;
case 'o':
if (op == "f32.load"sv) {
CHECK_ERR(makeLoad(ctx, pos, annotations, Type::f32, /*signed=*/false, 4, /*isAtomic=*/false));
return Ok{};
case 'o': {
switch (buf[8]) {
case '\0':
if (op == "f32.load"sv) {
CHECK_ERR(makeLoad(ctx, pos, annotations, Type::f32, /*signed=*/false, 4, /*isAtomic=*/false));
return Ok{};
}
goto parse_error;
case '_':
if (op == "f32.load_f16"sv) {
CHECK_ERR(makeLoad(ctx, pos, annotations, Type::f32, /*signed=*/false, 2, /*isAtomic=*/false));
return Ok{};
}
goto parse_error;
default: goto parse_error;
}
goto parse_error;
}
case 't':
if (op == "f32.lt"sv) {
CHECK_ERR(makeBinary(ctx, pos, annotations, BinaryOp::LtFloat32));
Expand Down Expand Up @@ -529,12 +540,23 @@ switch (buf[0]) {
return Ok{};
}
goto parse_error;
case 't':
if (op == "f32.store"sv) {
CHECK_ERR(makeStore(ctx, pos, annotations, Type::f32, 4, /*isAtomic=*/false));
return Ok{};
case 't': {
switch (buf[9]) {
case '\0':
if (op == "f32.store"sv) {
CHECK_ERR(makeStore(ctx, pos, annotations, Type::f32, 4, /*isAtomic=*/false));
return Ok{};
}
goto parse_error;
case '_':
if (op == "f32.store_f16"sv) {
CHECK_ERR(makeStore(ctx, pos, annotations, Type::f32, 2, /*isAtomic=*/false));
return Ok{};
}
goto parse_error;
default: goto parse_error;
}
goto parse_error;
}
case 'u':
if (op == "f32.sub"sv) {
CHECK_ERR(makeBinary(ctx, pos, annotations, BinaryOp::SubFloat32));
Expand Down
16 changes: 13 additions & 3 deletions src/passes/Print.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -548,13 +548,19 @@ struct PrintExpressionContents
if (curr->bytes == 1) {
o << '8';
} else if (curr->bytes == 2) {
o << "16";
if (curr->type == Type::f32) {
o << "_f16";
} else {
o << "16";
}
} else if (curr->bytes == 4) {
o << "32";
} else {
abort();
}
o << (curr->signed_ ? "_s" : "_u");
if (curr->type != Type::f32) {
o << (curr->signed_ ? "_s" : "_u");
}
}
restoreNormalColor(o);
printMemoryName(curr->memory, o, wasm);
Expand All @@ -575,7 +581,11 @@ struct PrintExpressionContents
if (curr->bytes == 1) {
o << '8';
} else if (curr->bytes == 2) {
o << "16";
if (curr->valueType == Type::f32) {
o << "_f16";
} else {
o << "16";
}
} else if (curr->bytes == 4) {
o << "32";
} else {
Expand Down
12 changes: 10 additions & 2 deletions src/wasm-binary.h
Original file line number Diff line number Diff line change
Expand Up @@ -1051,6 +1051,10 @@ enum ASTNodes {
I16x8DotI8x16I7x16S = 0x112,
I32x4DotI8x16I7x16AddS = 0x113,

// half precision opcodes
F32_F16LoadMem = 0x30,
F32_F16StoreMem = 0x31,

// bulk memory opcodes

MemoryInit = 0x08,
Expand Down Expand Up @@ -1703,8 +1707,12 @@ class WasmBinaryReader {
void visitLocalSet(LocalSet* curr, uint8_t code);
void visitGlobalGet(GlobalGet* curr);
void visitGlobalSet(GlobalSet* curr);
bool maybeVisitLoad(Expression*& out, uint8_t code, bool isAtomic);
bool maybeVisitStore(Expression*& out, uint8_t code, bool isAtomic);
bool maybeVisitLoad(Expression*& out,
uint8_t code,
std::optional<BinaryConsts::ASTNodes> prefix);
bool maybeVisitStore(Expression*& out,
uint8_t code,
std::optional<BinaryConsts::ASTNodes> prefix);
bool maybeVisitNontrappingTrunc(Expression*& out, uint32_t code);
bool maybeVisitAtomicRMW(Expression*& out, uint8_t code);
bool maybeVisitAtomicCmpxchg(Expression*& out, uint8_t code);
Expand Down
37 changes: 33 additions & 4 deletions src/wasm-interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <sstream>
#include <variant>

#include "fp16.h"
#include "ir/intrinsics.h"
#include "ir/module-utils.h"
#include "support/bits.h"
Expand Down Expand Up @@ -2540,8 +2541,22 @@ class ModuleRunnerBase : public ExpressionRunner<SubType> {
}
break;
}
case Type::f32:
return Literal(load32u(addr, memory)).castToF32();
case Type::f32: {
switch (load->bytes) {
case 2: {
// Convert the loaded float16 to float32 and build the literal from its
// binary representation.
return Literal(bit_cast<int32_t>(
fp16_ieee_to_fp32_value(load16u(addr, memory))))
.castToF32();
}
case 4:
return Literal(load32u(addr, memory)).castToF32();
default:
WASM_UNREACHABLE("invalid size");
}
break;
}
case Type::f64:
return Literal(load64u(addr, memory)).castToF64();
case Type::v128:
Expand Down Expand Up @@ -2590,9 +2605,23 @@ class ModuleRunnerBase : public ExpressionRunner<SubType> {
break;
}
// write floats carefully, ensuring all bits reach memory
case Type::f32:
store32(addr, value.reinterpreti32(), memory);
case Type::f32: {
switch (store->bytes) {
case 2: {
float f32 = bit_cast<float>(value.reinterpreti32());
// Convert the float32 to float16 and store the binary
// representation.
store16(addr, fp16_ieee_from_fp32_value(f32), memory);
break;
}
case 4:
store32(addr, value.reinterpreti32(), memory);
break;
default:
WASM_UNREACHABLE("invalid store size");
}
break;
}
case Type::f64:
store64(addr, value.reinterpreti64(), memory);
break;
Expand Down
65 changes: 49 additions & 16 deletions src/wasm/wasm-binary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4145,10 +4145,10 @@ BinaryConsts::ASTNodes WasmBinaryReader::readExpression(Expression*& curr) {
}
case BinaryConsts::AtomicPrefix: {
code = static_cast<uint8_t>(getU32LEB());
if (maybeVisitLoad(curr, code, /*isAtomic=*/true)) {
if (maybeVisitLoad(curr, code, BinaryConsts::AtomicPrefix)) {
break;
}
if (maybeVisitStore(curr, code, /*isAtomic=*/true)) {
if (maybeVisitStore(curr, code, BinaryConsts::AtomicPrefix)) {
break;
}
if (maybeVisitAtomicRMW(curr, code)) {
Expand Down Expand Up @@ -4198,6 +4198,12 @@ BinaryConsts::ASTNodes WasmBinaryReader::readExpression(Expression*& curr) {
if (maybeVisitTableCopy(curr, opcode)) {
break;
}
if (maybeVisitLoad(curr, opcode, BinaryConsts::MiscPrefix)) {
break;
}
if (maybeVisitStore(curr, opcode, BinaryConsts::MiscPrefix)) {
break;
}
throwError("invalid code after misc prefix: " + std::to_string(opcode));
break;
}
Expand Down Expand Up @@ -4338,10 +4344,10 @@ BinaryConsts::ASTNodes WasmBinaryReader::readExpression(Expression*& curr) {
if (maybeVisitConst(curr, code)) {
break;
}
if (maybeVisitLoad(curr, code, /*isAtomic=*/false)) {
if (maybeVisitLoad(curr, code, /*prefix=*/std::nullopt)) {
break;
}
if (maybeVisitStore(curr, code, /*isAtomic=*/false)) {
if (maybeVisitStore(curr, code, /*prefix=*/std::nullopt)) {
break;
}
throwError("bad node code " + std::to_string(code));
Expand Down Expand Up @@ -4717,14 +4723,15 @@ Index WasmBinaryReader::readMemoryAccess(Address& alignment, Address& offset) {
return memIdx;
}

bool WasmBinaryReader::maybeVisitLoad(Expression*& out,
uint8_t code,
bool isAtomic) {
bool WasmBinaryReader::maybeVisitLoad(
Expression*& out,
uint8_t code,
std::optional<BinaryConsts::ASTNodes> prefix) {
Load* curr;
auto allocate = [&]() {
curr = allocator.alloc<Load>();
};
if (!isAtomic) {
if (!prefix) {
switch (code) {
case BinaryConsts::I32LoadMem8S:
allocate();
Expand Down Expand Up @@ -4805,7 +4812,7 @@ bool WasmBinaryReader::maybeVisitLoad(Expression*& out,
return false;
}
BYN_TRACE("zz node: Load\n");
} else {
} else if (prefix == BinaryConsts::AtomicPrefix) {
switch (code) {
case BinaryConsts::I32AtomicLoad8U:
allocate();
Expand Down Expand Up @@ -4846,9 +4853,22 @@ bool WasmBinaryReader::maybeVisitLoad(Expression*& out,
return false;
}
BYN_TRACE("zz node: AtomicLoad\n");
} else if (prefix == BinaryConsts::MiscPrefix) {
switch (code) {
case BinaryConsts::F32_F16LoadMem:
allocate();
curr->bytes = 2;
curr->type = Type::f32;
break;
default:
return false;
}
BYN_TRACE("zz node: Load\n");
} else {
return false;
}

curr->isAtomic = isAtomic;
curr->isAtomic = prefix == BinaryConsts::AtomicPrefix;
Index memIdx = readMemoryAccess(curr->align, curr->offset);
memoryRefs[memIdx].push_back(&curr->memory);
curr->ptr = popNonVoidExpression();
Expand All @@ -4857,11 +4877,12 @@ bool WasmBinaryReader::maybeVisitLoad(Expression*& out,
return true;
}

bool WasmBinaryReader::maybeVisitStore(Expression*& out,
uint8_t code,
bool isAtomic) {
bool WasmBinaryReader::maybeVisitStore(
Expression*& out,
uint8_t code,
std::optional<BinaryConsts::ASTNodes> prefix) {
Store* curr;
if (!isAtomic) {
if (!prefix) {
switch (code) {
case BinaryConsts::I32StoreMem8:
curr = allocator.alloc<Store>();
Expand Down Expand Up @@ -4911,7 +4932,7 @@ bool WasmBinaryReader::maybeVisitStore(Expression*& out,
default:
return false;
}
} else {
} else if (prefix == BinaryConsts::AtomicPrefix) {
switch (code) {
case BinaryConsts::I32AtomicStore8:
curr = allocator.alloc<Store>();
Expand Down Expand Up @@ -4951,9 +4972,21 @@ bool WasmBinaryReader::maybeVisitStore(Expression*& out,
default:
return false;
}
} else if (prefix == BinaryConsts::MiscPrefix) {
switch (code) {
case BinaryConsts::F32_F16StoreMem:
curr = allocator.alloc<Store>();
curr->bytes = 2;
curr->valueType = Type::f32;
break;
default:
return false;
}
} else {
return false;
}

curr->isAtomic = isAtomic;
curr->isAtomic = prefix == BinaryConsts::AtomicPrefix;
BYN_TRACE("zz node: Store\n");
Index memIdx = readMemoryAccess(curr->align, curr->offset);
memoryRefs[memIdx].push_back(&curr->memory);
Expand Down
30 changes: 26 additions & 4 deletions src/wasm/wasm-stack.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -258,9 +258,20 @@ void BinaryInstWriter::visitLoad(Load* curr) {
}
break;
}
case Type::f32:
o << int8_t(BinaryConsts::F32LoadMem);
case Type::f32: {
switch (curr->bytes) {
case 2:
o << int8_t(BinaryConsts::MiscPrefix)
<< U32LEB(BinaryConsts::F32_F16LoadMem);
break;
case 4:
o << int8_t(BinaryConsts::F32LoadMem);
break;
default:
WASM_UNREACHABLE("invalid load size");
}
break;
}
case Type::f64:
o << int8_t(BinaryConsts::F64LoadMem);
break;
Expand Down Expand Up @@ -359,9 +370,20 @@ void BinaryInstWriter::visitStore(Store* curr) {
}
break;
}
case Type::f32:
o << int8_t(BinaryConsts::F32StoreMem);
case Type::f32: {
switch (curr->bytes) {
case 2:
o << int8_t(BinaryConsts::MiscPrefix)
<< U32LEB(BinaryConsts::F32_F16StoreMem);
break;
case 4:
o << int8_t(BinaryConsts::F32StoreMem);
break;
default:
WASM_UNREACHABLE("invalid store size");
}
break;
}
case Type::f64:
o << int8_t(BinaryConsts::F64StoreMem);
break;
Expand Down
Loading

0 comments on commit 0c26948

Please sign in to comment.