Skip to content

Commit

Permalink
Add ResultAccuracy to ExpOp (#2694)
Browse files Browse the repository at this point in the history
Implementation of RFC: #2592

For ExpOp.

TODO: Modify spec.md

---------

Co-authored-by: Rachel Han <[email protected]>
  • Loading branch information
GleasonK and hanrach9 authored Jan 29, 2025
1 parent 8993ef7 commit 7c50d4e
Show file tree
Hide file tree
Showing 37 changed files with 3,838 additions and 32 deletions.
2 changes: 1 addition & 1 deletion BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -1547,7 +1547,7 @@ gentbl_cc_library(
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "stablehlo/dialect/VhloAttrs.td",
td_file = "stablehlo/dialect/VhloEnums.td",
deps = [
":vhlo_ops_td_files",
],
Expand Down
23 changes: 23 additions & 0 deletions stablehlo/dialect/AssemblyFormat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -860,6 +860,29 @@ ParseResult parseCustomCallTarget(AsmParser& parser, StringAttr& target) {
return parser.parseSymbolName(target);
}

void printResultAccuracyAttr(AsmPrinter& odsPrinter, APFloat atol, APFloat rtol,
int64_t ulps, Attribute mode) {
odsPrinter << "<";
if (!atol.isZero()) {
odsPrinter << "atol = ";
odsPrinter.printFloat(atol);
odsPrinter << ", ";
}
if (!rtol.isZero()) {
odsPrinter << "rtol = ";
odsPrinter.printFloat(rtol);
odsPrinter << ", ";
}
if (ulps != 0) {
odsPrinter << "ulps = ";
odsPrinter << ulps;
odsPrinter << ", ";
}
odsPrinter << "mode = ";
odsPrinter.printAttribute(mode);
odsPrinter << ">";
}

void printTypeExtensions(BoundedAttrInterface attr, DialectAsmPrinter& os) {
os << "bounds<";
llvm::interleaveComma(attr.getBounds(), os,
Expand Down
59 changes: 59 additions & 0 deletions stablehlo/dialect/AssemblyFormat.h
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,65 @@ ParseResult parseDotDimensionNumbers(AsmParser& parser, AttrTy& target) {
return success();
}

// ResultAccuracyAttr - Custom printing and parsing for ResultAccuracyAttr.
//
// ResultAccuractAttr ::= `<` OptAtolAccuracy OptRtolAccuracy
// OptUlpAccuracy ModeAccuracy `>`
// OptAtolAccuracy ::= `atol` `=` APFloat `, ` | eps
// OptRtolAccuracy ::= `rtol` `=` APFloat `, ` | eps
// OptUlpAccuracy ::= `ulps` `=` int64_t `, ` | eps
// ModeAccuracy ::= `mode` `=` ResultAccuracyModeAttr
void printResultAccuracyAttr(AsmPrinter& odsPrinter, APFloat atol, APFloat rtol,
int64_t ulps, Attribute mode);

template <typename AttrTy, typename ModeTy>
Attribute parseResultAccuracyAttr(AsmParser& parser, Type type) {
APFloat resultAtol = APFloat::getZero(APFloat::IEEEdouble());
APFloat resultRtol = APFloat::getZero(APFloat::IEEEdouble());
int64_t resultUlps = 0;

// Parse literal '<'
if (parser.parseLess()) return {};

// OptAtolAccuracy
if (succeeded(parser.parseOptionalKeyword("atol"))) {
double value;
if (parser.parseEqual() || parser.parseFloat(value) || parser.parseComma())
return {};
resultAtol = APFloat(value);
}

// OptRtolAccuracy
if (succeeded(parser.parseOptionalKeyword("rtol"))) {
double value;
if (parser.parseEqual() || parser.parseFloat(value) || parser.parseComma())
return {};
resultRtol = APFloat(value);
}

// OptUlpAccuracy
if (succeeded(parser.parseOptionalKeyword("ulps"))) {
int64_t value;
if (parser.parseEqual() || parser.parseInteger(value) ||
parser.parseComma())
return {};
resultUlps = value;
}

// ModeAccuracy
ModeTy modeAttr;
if (parser.parseKeyword("mode") || parser.parseEqual() ||
parser.parseAttribute(modeAttr)) {
return {};
}

// Parse literal '>'
if (parser.parseGreater()) return {};
return parser.getChecked<AttrTy>(parser.getCurrentLocation(),
parser.getContext(), resultAtol, resultRtol,
resultUlps, modeAttr);
}

} // namespace hlo
} // namespace mlir

Expand Down
2 changes: 1 addition & 1 deletion stablehlo/dialect/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ mlir_tablegen(VhloEnums.cpp.inc -gen-enum-defs)
set(LLVM_TARGET_DEFINITIONS VhloOps.td)
mlir_tablegen(VhloAttrs.h.inc -gen-attrdef-decls)
mlir_tablegen(VhloAttrs.cpp.inc -gen-attrdef-defs)
set(LLVM_TARGET_DEFINITIONS VhloAttrs.td)
set(LLVM_TARGET_DEFINITIONS VhloEnums.td)
mlir_tablegen(VhloAttrInterfaces.h.inc -gen-attr-interface-decls)
mlir_tablegen(VhloAttrInterfaces.cpp.inc -gen-attr-interface-defs)
set(LLVM_TARGET_DEFINITIONS VhloTypes.td)
Expand Down
15 changes: 15 additions & 0 deletions stablehlo/dialect/StablehloAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License.

include "mlir/IR/OpBase.td"
include "mlir/IR/TensorEncoding.td"
include "stablehlo/dialect/StablehloTypes.td"

def StableHLO_Dims : ArrayRefParameter<"int64_t", "Dimension"> {
let parser = "parseDimSizes($_parser)";
Expand Down Expand Up @@ -209,4 +210,18 @@ def StableHLO_ConvDimensionNumbers : AttrDef<StableHLO_Dialect, "ConvDimensionNu
let hasCustomAssemblyFormat = 1;
}

def StableHLO_ResultAccuracyAttr : AttrDef<StableHLO_Dialect, "ResultAccuracy"> {
let mnemonic = "result_accuracy";
let summary = "The requested accuracy for transcendental unary ops.";
let parameters = (ins
"APFloat":$atol,
"APFloat":$rtol,
"int64_t":$ulps,
StableHLO_ResultAccuracyModeAttr:$mode
);
let hasCustomAssemblyFormat = 1;
let genVerifyDecl = 1;
let constBuilderCall = "ResultAccuracyAttr::get($_builder.getContext(), APFloat(0.0), APFloat(0.0), 0, ResultAccuracyModeAttr::get($_builder.getContext(), $0))";
}

#endif // STABLEHLO_DIALECT_STABLEHLO_ATTRS
86 changes: 79 additions & 7 deletions stablehlo/dialect/StablehloBytecode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License.
#include <cstdint>
#include <memory>

#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/TypeSwitch.h"
Expand Down Expand Up @@ -180,6 +181,18 @@ enum AttributeCode {
/// allowImpreciseAccumulation : svarint
/// }
kDotAlgorithmAttr = 15,

// ResultAccuracyModeAttr {
// mode: varint (encoded enum)
// }
kResultAccuracyModeAttr = 16,

// ResultAccuracyAttr {
// atol: APFloat
// rtol: APFloat
// ulps: svarint
// }
kResultAccuracyAttr = 17,
};

/// This enum contains marker codes used to indicate which type is
Expand Down Expand Up @@ -241,6 +254,10 @@ class StablehloBytecodeInterface : public BytecodeDialectInterface {
OutputOperandAliasAttr readOutputOperandAliasAttr(
DialectBytecodeReader &reader) const;
PrecisionAttr readPrecisionAttr(DialectBytecodeReader &reader) const;
ResultAccuracyAttr readResultAccuracyAttr(
DialectBytecodeReader &reader) const;
ResultAccuracyModeAttr readResultAccuracyModeAttr(
DialectBytecodeReader &reader) const;
RngAlgorithmAttr readRngAlgorithmAttr(DialectBytecodeReader &reader) const;
RngDistributionAttr readRngDistributionAttr(
DialectBytecodeReader &reader) const;
Expand All @@ -264,6 +281,8 @@ class StablehloBytecodeInterface : public BytecodeDialectInterface {
DialectBytecodeWriter &writer) const;
void write(OutputOperandAliasAttr attr, DialectBytecodeWriter &writer) const;
void write(PrecisionAttr attr, DialectBytecodeWriter &writer) const;
void write(ResultAccuracyAttr attr, DialectBytecodeWriter &writer) const;
void write(ResultAccuracyModeAttr attr, DialectBytecodeWriter &writer) const;
void write(RngAlgorithmAttr attr, DialectBytecodeWriter &writer) const;
void write(RngDistributionAttr attr, DialectBytecodeWriter &writer) const;
void write(ScatterDimensionNumbersAttr attr,
Expand Down Expand Up @@ -327,6 +346,10 @@ Attribute StablehloBytecodeInterface::readAttribute(
return readOutputOperandAliasAttr(reader);
case stablehlo_encoding::kPrecisionAttr:
return readPrecisionAttr(reader);
case stablehlo_encoding::kResultAccuracyAttr:
return readResultAccuracyAttr(reader);
case stablehlo_encoding::kResultAccuracyModeAttr:
return readResultAccuracyModeAttr(reader);
case stablehlo_encoding::kRngAlgorithmAttr:
return readRngAlgorithmAttr(reader);
case stablehlo_encoding::kRngDistributionAttr:
Expand All @@ -352,13 +375,13 @@ LogicalResult StablehloBytecodeInterface::writeAttribute(
.Case<ChannelHandleAttr, ComparisonDirectionAttr, ComparisonTypeAttr,
ConvDimensionNumbersAttr, DotAlgorithmAttr, DotDimensionNumbersAttr,
FftTypeAttr, GatherDimensionNumbersAttr, OutputOperandAliasAttr,
PrecisionAttr, RngAlgorithmAttr, RngDistributionAttr,
ScatterDimensionNumbersAttr, TransposeAttr, TypeExtensionsAttr>(
[&](auto attr) {
LOG_WRITE_CALL;
write(attr, writer);
return success();
})
PrecisionAttr, ResultAccuracyAttr, ResultAccuracyModeAttr,
RngAlgorithmAttr, RngDistributionAttr, ScatterDimensionNumbersAttr,
TransposeAttr, TypeExtensionsAttr>([&](auto attr) {
LOG_WRITE_CALL;
write(attr, writer);
return success();
})
.Default([&](Attribute) {
LOG_NOT_IMPLEMENTED;
return failure();
Expand Down Expand Up @@ -806,6 +829,55 @@ void StablehloBytecodeInterface::writeVersion(
}
}

//===----------------------------------------------------------------------===//
// ResultAccuracyModeAttr

ResultAccuracyModeAttr StablehloBytecodeInterface::readResultAccuracyModeAttr(
DialectBytecodeReader &reader) const {
LOG_READ_CALL;
return hlo::bytecode::readEnumAttribute<ResultAccuracyModeAttr>(
reader, getContext(),
[](uint32_t val) { return symbolizeResultAccuracyMode(val); });
}

void StablehloBytecodeInterface::write(ResultAccuracyModeAttr attr,
DialectBytecodeWriter &writer) const {
writer.writeVarInt(stablehlo_encoding::kResultAccuracyModeAttr);
hlo::bytecode::writeEnumAttribute<ResultAccuracyMode>(attr, writer);
}

//===----------------------------------------------------------------------===//
// ResultAccuracyAttr

ResultAccuracyAttr StablehloBytecodeInterface::readResultAccuracyAttr(
DialectBytecodeReader &reader) const {
LOG_READ_CALL;
FailureOr<APFloat> atol;
FailureOr<APFloat> rtol;
int64_t ulps;
ResultAccuracyModeAttr mode;
if (failed(atol =
reader.readAPFloatWithKnownSemantics(APFloat::IEEEdouble())) ||
failed(rtol =
reader.readAPFloatWithKnownSemantics(APFloat::IEEEdouble())) ||
failed(reader.readSignedVarInt(ulps)) ||
failed(reader.readAttribute(mode))) {
mlir::emitWarning(mlir::UnknownLoc::get(getContext()))
<< "failed to read APFloat for atol";
return ResultAccuracyAttr();
}
return ResultAccuracyAttr::get(getContext(), *atol, *rtol, ulps, mode);
}

void StablehloBytecodeInterface::write(ResultAccuracyAttr attr,
DialectBytecodeWriter &writer) const {
writer.writeVarInt(stablehlo_encoding::kResultAccuracyAttr);
writer.writeAPFloatWithKnownSemantics(attr.getAtol());
writer.writeAPFloatWithKnownSemantics(attr.getRtol());
writer.writeSignedVarInt(attr.getUlps());
writer.writeAttribute(attr.getMode());
}

} // namespace

void addBytecodeInterface(StablehloDialect *dialect) {
Expand Down
23 changes: 23 additions & 0 deletions stablehlo/dialect/StablehloEnums.td
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,29 @@ def StableHLO_PrecisionAttr : EnumAttr<StableHLO_Dialect, StableHLO_Precision, "
def StableHLO_PrecisionConfigAttr:
TypedArrayAttrBase<StableHLO_PrecisionAttr, "Precision Config attribute">;

//===----------------------------------------------------------------------===//
// Result Accuracy enum definitions.
//===----------------------------------------------------------------------===//

def STABLEHLO_RESULT_ACCURACY_DEFAULT : I32EnumAttrCase<"DEFAULT", 0>;
def STABLEHLO_RESULT_ACCURACY_HIGHEST : I32EnumAttrCase<"HIGHEST", 1>;
def STABLEHLO_RESULT_ACCURACY_TOLERANCE: I32EnumAttrCase<"TOLERANCE", 2>;

def StableHLO_ResultAccuracyMode : I32EnumAttr<"ResultAccuracyMode",
"XLA result accuracy mode.",
[
STABLEHLO_RESULT_ACCURACY_DEFAULT,
STABLEHLO_RESULT_ACCURACY_HIGHEST,
STABLEHLO_RESULT_ACCURACY_TOLERANCE
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::stablehlo";
}

def StableHLO_ResultAccuracyModeAttr : EnumAttr<StableHLO_Dialect, StableHLO_ResultAccuracyMode, "result_accuracy_mode"> {
let assemblyFormat = "`<` $value `>`";
}

//===----------------------------------------------------------------------===//
// Fast Fourier Transform Type enum definitions.
//===----------------------------------------------------------------------===//
Expand Down
38 changes: 38 additions & 0 deletions stablehlo/dialect/StablehloOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ limitations under the License.
#include "stablehlo/dialect/AssemblyFormat.h"
#include "stablehlo/dialect/Base.h"
#include "stablehlo/dialect/StablehloBytecode.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "stablehlo/dialect/StablehloOps.h.inc"
#include "stablehlo/dialect/TypeInference.h"

Expand Down Expand Up @@ -792,6 +793,29 @@ LogicalResult DotAlgorithmAttr::verify(
allowImpreciseAccumulation);
}

// ===----------------------------------------------------------------------===//
// ExpOp
//===----------------------------------------------------------------------===//

LogicalResult ResultAccuracyAttr::verify(
::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, APFloat atol,
APFloat rtol, int64_t ulps, ResultAccuracyModeAttr mode) {
return hlo::verifyResultAccuracyAttr(
emitError, atol, rtol, ulps,
stringifyResultAccuracyMode(mode.getValue()));
}

LogicalResult ExpOp::verify() {
if (auto attr = getResultAccuracyAttr()) {
if (failed(ResultAccuracyAttr::verify([&] { return emitError(); },
attr.getAtol(), attr.getRtol(),
attr.getUlps(), attr.getMode()))) {
return failure();
}
}
return success();
}

//===----------------------------------------------------------------------===//
// FftOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -3127,6 +3151,20 @@ Attribute DotDimensionNumbersAttr::parse(AsmParser& parser, Type type) {
lhsContractingDimensions, rhsContractingDimensions);
}

// ===----------------------------------------------------------------------===//
// Custom unary op
// ===----------------------------------------------------------------------===//

void ResultAccuracyAttr::print(AsmPrinter& odsPrinter) const {
hlo::printResultAccuracyAttr(odsPrinter, getAtol(), getRtol(), getUlps(),
getMode());
}

Attribute ResultAccuracyAttr::parse(AsmParser& parser, Type type) {
return hlo::parseResultAccuracyAttr<ResultAccuracyAttr,
ResultAccuracyModeAttr>(parser, type);
}

namespace {
enum NonSpatialDim : int64_t {
IOBatch = -1, // Input or output batch dimension
Expand Down
17 changes: 17 additions & 0 deletions stablehlo/dialect/StablehloOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,23 @@ def StableHLO_ExpOp: StableHLO_UnaryElementwiseOp<"exponential",
%result = stablehlo.exponential %operand : tensor<2x2xf64>
```
}];
let arguments = (ins HLO_FpComplexOrQuantizedIntTensor:$operand,
DefaultValuedOptionalAttr<StableHLO_ResultAccuracyAttr, "::mlir::stablehlo::ResultAccuracyMode::DEFAULT">:$result_accuracy);
let results = (outs HLO_FpComplexOrQuantizedIntTensor:$result);
let extraClassDeclaration = commonClassDeclaration # [{
LogicalResult reifyReturnTypeShapes(
OpBuilder& builder, ValueRange operands,
SmallVectorImpl<Value>& reifiedReturnShapes) {
return ::mlir::hlo::deriveShapeFromOperand(&builder, getOperation(),
operands.front(),
&reifiedReturnShapes);
}
}];
let hasVerifier = 1;

let assemblyFormat = [{
$operand attr-dict `:` custom<SameOperandsAndResultType>(type($operand), type($result))
}];
}

def StableHLO_Expm1Op: StableHLO_UnaryElementwiseOp<"exponential_minus_one",
Expand Down
Loading

0 comments on commit 7c50d4e

Please sign in to comment.