Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[FP16] Implement load and store instructions. #6796

Merged
merged 4 commits into from
Aug 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ endif()
# Compiler setup.

include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/third_party/FP16/include)
if(BUILD_LLVM_DWARF)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/third_party/llvm-project/include)
endif()
Expand Down
5 changes: 5 additions & 0 deletions LICENSE
Original file line number Diff line number Diff line change
Expand Up @@ -199,3 +199,8 @@
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

==============================================================================

The FP16 project is used in this repo and is distributed under the MIT
license; see third_party/FP16/LICENSE.
2 changes: 2 additions & 0 deletions scripts/gen-s-parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
("i32.load", "makeLoad(Type::i32, /*signed=*/false, 4, /*isAtomic=*/false)"),
("i64.load", "makeLoad(Type::i64, /*signed=*/false, 8, /*isAtomic=*/false)"),
("f32.load", "makeLoad(Type::f32, /*signed=*/false, 4, /*isAtomic=*/false)"),
("f32.load_f16", "makeLoad(Type::f32, /*signed=*/false, 2, /*isAtomic=*/false)"),
("f64.load", "makeLoad(Type::f64, /*signed=*/false, 8, /*isAtomic=*/false)"),
("i32.load8_s", "makeLoad(Type::i32, /*signed=*/true, 1, /*isAtomic=*/false)"),
("i32.load8_u", "makeLoad(Type::i32, /*signed=*/false, 1, /*isAtomic=*/false)"),
Expand All @@ -60,6 +61,7 @@
("i32.store", "makeStore(Type::i32, 4, /*isAtomic=*/false)"),
("i64.store", "makeStore(Type::i64, 8, /*isAtomic=*/false)"),
("f32.store", "makeStore(Type::f32, 4, /*isAtomic=*/false)"),
("f32.store_f16", "makeStore(Type::f32, 2, /*isAtomic=*/false)"),
("f64.store", "makeStore(Type::f64, 8, /*isAtomic=*/false)"),
("i32.store8", "makeStore(Type::i32, 1, /*isAtomic=*/false)"),
("i32.store16", "makeStore(Type::i32, 2, /*isAtomic=*/false)"),
Expand Down
42 changes: 32 additions & 10 deletions src/gen-s-parser.inc
Original file line number Diff line number Diff line change
Expand Up @@ -454,12 +454,23 @@ switch (buf[0]) {
return Ok{};
}
goto parse_error;
case 'o':
if (op == "f32.load"sv) {
CHECK_ERR(makeLoad(ctx, pos, annotations, Type::f32, /*signed=*/false, 4, /*isAtomic=*/false));
return Ok{};
case 'o': {
switch (buf[8]) {
case '\0':
if (op == "f32.load"sv) {
CHECK_ERR(makeLoad(ctx, pos, annotations, Type::f32, /*signed=*/false, 4, /*isAtomic=*/false));
return Ok{};
}
goto parse_error;
case '_':
if (op == "f32.load_f16"sv) {
CHECK_ERR(makeLoad(ctx, pos, annotations, Type::f32, /*signed=*/false, 2, /*isAtomic=*/false));
return Ok{};
}
goto parse_error;
default: goto parse_error;
}
goto parse_error;
}
case 't':
if (op == "f32.lt"sv) {
CHECK_ERR(makeBinary(ctx, pos, annotations, BinaryOp::LtFloat32));
Expand Down Expand Up @@ -529,12 +540,23 @@ switch (buf[0]) {
return Ok{};
}
goto parse_error;
case 't':
if (op == "f32.store"sv) {
CHECK_ERR(makeStore(ctx, pos, annotations, Type::f32, 4, /*isAtomic=*/false));
return Ok{};
case 't': {
switch (buf[9]) {
case '\0':
if (op == "f32.store"sv) {
CHECK_ERR(makeStore(ctx, pos, annotations, Type::f32, 4, /*isAtomic=*/false));
return Ok{};
}
goto parse_error;
case '_':
if (op == "f32.store_f16"sv) {
CHECK_ERR(makeStore(ctx, pos, annotations, Type::f32, 2, /*isAtomic=*/false));
return Ok{};
}
goto parse_error;
default: goto parse_error;
}
goto parse_error;
}
case 'u':
if (op == "f32.sub"sv) {
CHECK_ERR(makeBinary(ctx, pos, annotations, BinaryOp::SubFloat32));
Expand Down
16 changes: 13 additions & 3 deletions src/passes/Print.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -548,13 +548,19 @@ struct PrintExpressionContents
if (curr->bytes == 1) {
o << '8';
} else if (curr->bytes == 2) {
o << "16";
if (curr->type == Type::f32) {
o << "_f16";
} else {
o << "16";
}
} else if (curr->bytes == 4) {
o << "32";
} else {
abort();
}
o << (curr->signed_ ? "_s" : "_u");
if (curr->type != Type::f32) {
o << (curr->signed_ ? "_s" : "_u");
}
}
restoreNormalColor(o);
printMemoryName(curr->memory, o, wasm);
Expand All @@ -575,7 +581,11 @@ struct PrintExpressionContents
if (curr->bytes == 1) {
o << '8';
} else if (curr->bytes == 2) {
o << "16";
if (curr->valueType == Type::f32) {
o << "_f16";
} else {
o << "16";
}
} else if (curr->bytes == 4) {
o << "32";
} else {
Expand Down
12 changes: 10 additions & 2 deletions src/wasm-binary.h
Original file line number Diff line number Diff line change
Expand Up @@ -1051,6 +1051,10 @@ enum ASTNodes {
I16x8DotI8x16I7x16S = 0x112,
I32x4DotI8x16I7x16AddS = 0x113,

// half precision opcodes
F32_F16LoadMem = 0x30,
F32_F16StoreMem = 0x31,

// bulk memory opcodes

MemoryInit = 0x08,
Expand Down Expand Up @@ -1703,8 +1707,12 @@ class WasmBinaryReader {
void visitLocalSet(LocalSet* curr, uint8_t code);
void visitGlobalGet(GlobalGet* curr);
void visitGlobalSet(GlobalSet* curr);
bool maybeVisitLoad(Expression*& out, uint8_t code, bool isAtomic);
bool maybeVisitStore(Expression*& out, uint8_t code, bool isAtomic);
bool maybeVisitLoad(Expression*& out,
uint8_t code,
std::optional<BinaryConsts::ASTNodes> prefix);
bool maybeVisitStore(Expression*& out,
uint8_t code,
std::optional<BinaryConsts::ASTNodes> prefix);
bool maybeVisitNontrappingTrunc(Expression*& out, uint32_t code);
bool maybeVisitAtomicRMW(Expression*& out, uint8_t code);
bool maybeVisitAtomicCmpxchg(Expression*& out, uint8_t code);
Expand Down
37 changes: 33 additions & 4 deletions src/wasm-interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <sstream>
#include <variant>

#include "fp16.h"
#include "ir/intrinsics.h"
#include "ir/module-utils.h"
#include "support/bits.h"
Expand Down Expand Up @@ -2540,8 +2541,22 @@ class ModuleRunnerBase : public ExpressionRunner<SubType> {
}
break;
}
case Type::f32:
return Literal(load32u(addr, memory)).castToF32();
case Type::f32: {
switch (load->bytes) {
case 2: {
// Convert the loaded float16 to float32 and return its binary
// representation as an f32 literal.
return Literal(bit_cast<int32_t>(
fp16_ieee_to_fp32_value(load16u(addr, memory))))
.castToF32();
}
case 4:
return Literal(load32u(addr, memory)).castToF32();
default:
WASM_UNREACHABLE("invalid size");
}
break;
}
case Type::f64:
return Literal(load64u(addr, memory)).castToF64();
case Type::v128:
Expand Down Expand Up @@ -2590,9 +2605,23 @@ class ModuleRunnerBase : public ExpressionRunner<SubType> {
break;
}
// write floats carefully, ensuring all bits reach memory
case Type::f32:
store32(addr, value.reinterpreti32(), memory);
case Type::f32: {
switch (store->bytes) {
case 2: {
float f32 = bit_cast<float>(value.reinterpreti32());
// Convert the float32 to float16 and store the binary
// representation.
store16(addr, fp16_ieee_from_fp32_value(f32), memory);
break;
}
case 4:
store32(addr, value.reinterpreti32(), memory);
break;
default:
WASM_UNREACHABLE("invalid store size");
}
break;
}
case Type::f64:
store64(addr, value.reinterpreti64(), memory);
break;
Expand Down
65 changes: 49 additions & 16 deletions src/wasm/wasm-binary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4137,10 +4137,10 @@ BinaryConsts::ASTNodes WasmBinaryReader::readExpression(Expression*& curr) {
}
case BinaryConsts::AtomicPrefix: {
code = static_cast<uint8_t>(getU32LEB());
if (maybeVisitLoad(curr, code, /*isAtomic=*/true)) {
if (maybeVisitLoad(curr, code, BinaryConsts::AtomicPrefix)) {
break;
}
if (maybeVisitStore(curr, code, /*isAtomic=*/true)) {
if (maybeVisitStore(curr, code, BinaryConsts::AtomicPrefix)) {
break;
}
if (maybeVisitAtomicRMW(curr, code)) {
Expand Down Expand Up @@ -4190,6 +4190,12 @@ BinaryConsts::ASTNodes WasmBinaryReader::readExpression(Expression*& curr) {
if (maybeVisitTableCopy(curr, opcode)) {
break;
}
if (maybeVisitLoad(curr, opcode, BinaryConsts::MiscPrefix)) {
break;
}
if (maybeVisitStore(curr, opcode, BinaryConsts::MiscPrefix)) {
break;
}
throwError("invalid code after misc prefix: " + std::to_string(opcode));
break;
}
Expand Down Expand Up @@ -4330,10 +4336,10 @@ BinaryConsts::ASTNodes WasmBinaryReader::readExpression(Expression*& curr) {
if (maybeVisitConst(curr, code)) {
break;
}
if (maybeVisitLoad(curr, code, /*isAtomic=*/false)) {
if (maybeVisitLoad(curr, code, /*prefix=*/std::nullopt)) {
break;
}
if (maybeVisitStore(curr, code, /*isAtomic=*/false)) {
if (maybeVisitStore(curr, code, /*prefix=*/std::nullopt)) {
break;
}
throwError("bad node code " + std::to_string(code));
Expand Down Expand Up @@ -4709,14 +4715,15 @@ Index WasmBinaryReader::readMemoryAccess(Address& alignment, Address& offset) {
return memIdx;
}

bool WasmBinaryReader::maybeVisitLoad(Expression*& out,
uint8_t code,
bool isAtomic) {
bool WasmBinaryReader::maybeVisitLoad(
Expression*& out,
uint8_t code,
std::optional<BinaryConsts::ASTNodes> prefix) {
Load* curr;
auto allocate = [&]() {
curr = allocator.alloc<Load>();
};
if (!isAtomic) {
if (!prefix) {
switch (code) {
case BinaryConsts::I32LoadMem8S:
allocate();
Expand Down Expand Up @@ -4797,7 +4804,7 @@ bool WasmBinaryReader::maybeVisitLoad(Expression*& out,
return false;
}
BYN_TRACE("zz node: Load\n");
} else {
} else if (prefix == BinaryConsts::AtomicPrefix) {
switch (code) {
case BinaryConsts::I32AtomicLoad8U:
allocate();
Expand Down Expand Up @@ -4838,9 +4845,22 @@ bool WasmBinaryReader::maybeVisitLoad(Expression*& out,
return false;
}
BYN_TRACE("zz node: AtomicLoad\n");
} else if (prefix == BinaryConsts::MiscPrefix) {
switch (code) {
case BinaryConsts::F32_F16LoadMem:
allocate();
curr->bytes = 2;
curr->type = Type::f32;
break;
default:
return false;
}
BYN_TRACE("zz node: Load\n");
} else {
return false;
}

curr->isAtomic = isAtomic;
curr->isAtomic = prefix == BinaryConsts::AtomicPrefix;
Index memIdx = readMemoryAccess(curr->align, curr->offset);
memoryRefs[memIdx].push_back(&curr->memory);
curr->ptr = popNonVoidExpression();
Expand All @@ -4849,11 +4869,12 @@ bool WasmBinaryReader::maybeVisitLoad(Expression*& out,
return true;
}

bool WasmBinaryReader::maybeVisitStore(Expression*& out,
uint8_t code,
bool isAtomic) {
bool WasmBinaryReader::maybeVisitStore(
Expression*& out,
uint8_t code,
std::optional<BinaryConsts::ASTNodes> prefix) {
Store* curr;
if (!isAtomic) {
if (!prefix) {
switch (code) {
case BinaryConsts::I32StoreMem8:
curr = allocator.alloc<Store>();
Expand Down Expand Up @@ -4903,7 +4924,7 @@ bool WasmBinaryReader::maybeVisitStore(Expression*& out,
default:
return false;
}
} else {
} else if (prefix == BinaryConsts::AtomicPrefix) {
switch (code) {
case BinaryConsts::I32AtomicStore8:
curr = allocator.alloc<Store>();
Expand Down Expand Up @@ -4943,9 +4964,21 @@ bool WasmBinaryReader::maybeVisitStore(Expression*& out,
default:
return false;
}
} else if (prefix == BinaryConsts::MiscPrefix) {
switch (code) {
case BinaryConsts::F32_F16StoreMem:
curr = allocator.alloc<Store>();
curr->bytes = 2;
curr->valueType = Type::f32;
break;
default:
return false;
}
} else {
return false;
}

curr->isAtomic = isAtomic;
curr->isAtomic = prefix == BinaryConsts::AtomicPrefix;
BYN_TRACE("zz node: Store\n");
Index memIdx = readMemoryAccess(curr->align, curr->offset);
memoryRefs[memIdx].push_back(&curr->memory);
Expand Down
30 changes: 26 additions & 4 deletions src/wasm/wasm-stack.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -258,9 +258,20 @@ void BinaryInstWriter::visitLoad(Load* curr) {
}
break;
}
case Type::f32:
o << int8_t(BinaryConsts::F32LoadMem);
case Type::f32: {
switch (curr->bytes) {
case 2:
o << int8_t(BinaryConsts::MiscPrefix)
<< U32LEB(BinaryConsts::F32_F16LoadMem);
break;
case 4:
o << int8_t(BinaryConsts::F32LoadMem);
break;
default:
WASM_UNREACHABLE("invalid load size");
}
break;
}
case Type::f64:
o << int8_t(BinaryConsts::F64LoadMem);
break;
Expand Down Expand Up @@ -359,9 +370,20 @@ void BinaryInstWriter::visitStore(Store* curr) {
}
break;
}
case Type::f32:
o << int8_t(BinaryConsts::F32StoreMem);
case Type::f32: {
switch (curr->bytes) {
case 2:
o << int8_t(BinaryConsts::MiscPrefix)
<< U32LEB(BinaryConsts::F32_F16StoreMem);
break;
case 4:
o << int8_t(BinaryConsts::F32StoreMem);
break;
default:
WASM_UNREACHABLE("invalid store size");
}
break;
}
case Type::f64:
o << int8_t(BinaryConsts::F64StoreMem);
break;
Expand Down
Loading
Loading