diff --git a/BUILD.bazel b/BUILD.bazel index 866e9715071..40cc04aedda 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -1547,7 +1547,7 @@ gentbl_cc_library( ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "stablehlo/dialect/VhloAttrs.td", + td_file = "stablehlo/dialect/VhloEnums.td", deps = [ ":vhlo_ops_td_files", ], diff --git a/stablehlo/dialect/AssemblyFormat.cpp b/stablehlo/dialect/AssemblyFormat.cpp index 85c4beedfa2..6e830ef85cc 100644 --- a/stablehlo/dialect/AssemblyFormat.cpp +++ b/stablehlo/dialect/AssemblyFormat.cpp @@ -860,6 +860,29 @@ ParseResult parseCustomCallTarget(AsmParser& parser, StringAttr& target) { return parser.parseSymbolName(target); } +void printResultAccuracyAttr(AsmPrinter& odsPrinter, APFloat atol, APFloat rtol, + int64_t ulps, Attribute mode) { + odsPrinter << "<"; + if (!atol.isZero()) { + odsPrinter << "atol = "; + odsPrinter.printFloat(atol); + odsPrinter << ", "; + } + if (!rtol.isZero()) { + odsPrinter << "rtol = "; + odsPrinter.printFloat(rtol); + odsPrinter << ", "; + } + if (ulps != 0) { + odsPrinter << "ulps = "; + odsPrinter << ulps; + odsPrinter << ", "; + } + odsPrinter << "mode = "; + odsPrinter.printAttribute(mode); + odsPrinter << ">"; +} + void printTypeExtensions(BoundedAttrInterface attr, DialectAsmPrinter& os) { os << "bounds<"; llvm::interleaveComma(attr.getBounds(), os, diff --git a/stablehlo/dialect/AssemblyFormat.h b/stablehlo/dialect/AssemblyFormat.h index e472041176f..02ec3821c5b 100644 --- a/stablehlo/dialect/AssemblyFormat.h +++ b/stablehlo/dialect/AssemblyFormat.h @@ -378,6 +378,65 @@ ParseResult parseDotDimensionNumbers(AsmParser& parser, AttrTy& target) { return success(); } +// ResultAccuracyAttr - Custom printing and parsing for ResultAccuracyAttr. +// +// ResultAccuractAttr ::= `<` OptAtolAccuracy OptRtolAccuracy +// OptUlpAccuracy ModeAccuracy `>` +// OptAtolAccuracy ::= `atol` `=` APFloat `, ` | eps +// OptRtolAccuracy ::= `rtol` `=` APFloat `, ` | eps +// OptUlpAccuracy ::= `ulps` `=` int64_t `, ` | eps +// ModeAccuracy ::= `mode` `=` ResultAccuracyModeAttr +void printResultAccuracyAttr(AsmPrinter& odsPrinter, APFloat atol, APFloat rtol, + int64_t ulps, Attribute mode); + +template +Attribute parseResultAccuracyAttr(AsmParser& parser, Type type) { + APFloat resultAtol = APFloat::getZero(APFloat::IEEEdouble()); + APFloat resultRtol = APFloat::getZero(APFloat::IEEEdouble()); + int64_t resultUlps = 0; + + // Parse literal '<' + if (parser.parseLess()) return {}; + + // OptAtolAccuracy + if (succeeded(parser.parseOptionalKeyword("atol"))) { + double value; + if (parser.parseEqual() || parser.parseFloat(value) || parser.parseComma()) + return {}; + resultAtol = APFloat(value); + } + + // OptRtolAccuracy + if (succeeded(parser.parseOptionalKeyword("rtol"))) { + double value; + if (parser.parseEqual() || parser.parseFloat(value) || parser.parseComma()) + return {}; + resultRtol = APFloat(value); + } + + // OptUlpAccuracy + if (succeeded(parser.parseOptionalKeyword("ulps"))) { + int64_t value; + if (parser.parseEqual() || parser.parseInteger(value) || + parser.parseComma()) + return {}; + resultUlps = value; + } + + // ModeAccuracy + ModeTy modeAttr; + if (parser.parseKeyword("mode") || parser.parseEqual() || + parser.parseAttribute(modeAttr)) { + return {}; + } + + // Parse literal '>' + if (parser.parseGreater()) return {}; + return parser.getChecked(parser.getCurrentLocation(), + parser.getContext(), resultAtol, resultRtol, + resultUlps, modeAttr); +} + } // namespace hlo } // namespace mlir diff --git a/stablehlo/dialect/CMakeLists.txt b/stablehlo/dialect/CMakeLists.txt index 051160fd7dd..8312878ea04 100644 --- a/stablehlo/dialect/CMakeLists.txt +++ b/stablehlo/dialect/CMakeLists.txt @@ -190,7 +190,7 @@ mlir_tablegen(VhloEnums.cpp.inc -gen-enum-defs) set(LLVM_TARGET_DEFINITIONS VhloOps.td) mlir_tablegen(VhloAttrs.h.inc -gen-attrdef-decls) mlir_tablegen(VhloAttrs.cpp.inc -gen-attrdef-defs) -set(LLVM_TARGET_DEFINITIONS VhloAttrs.td) +set(LLVM_TARGET_DEFINITIONS VhloEnums.td) mlir_tablegen(VhloAttrInterfaces.h.inc -gen-attr-interface-decls) mlir_tablegen(VhloAttrInterfaces.cpp.inc -gen-attr-interface-defs) set(LLVM_TARGET_DEFINITIONS VhloTypes.td) diff --git a/stablehlo/dialect/StablehloAttrs.td b/stablehlo/dialect/StablehloAttrs.td index debc7d879b4..09aa0e0e071 100644 --- a/stablehlo/dialect/StablehloAttrs.td +++ b/stablehlo/dialect/StablehloAttrs.td @@ -19,6 +19,7 @@ limitations under the License. include "mlir/IR/OpBase.td" include "mlir/IR/TensorEncoding.td" +include "stablehlo/dialect/StablehloTypes.td" def StableHLO_Dims : ArrayRefParameter<"int64_t", "Dimension"> { let parser = "parseDimSizes($_parser)"; @@ -209,4 +210,18 @@ def StableHLO_ConvDimensionNumbers : AttrDef { + let mnemonic = "result_accuracy"; + let summary = "The requested accuracy for transcendental unary ops."; + let parameters = (ins + "APFloat":$atol, + "APFloat":$rtol, + "int64_t":$ulps, + StableHLO_ResultAccuracyModeAttr:$mode + ); + let hasCustomAssemblyFormat = 1; + let genVerifyDecl = 1; + let constBuilderCall = "ResultAccuracyAttr::get($_builder.getContext(), APFloat(0.0), APFloat(0.0), 0, ResultAccuracyModeAttr::get($_builder.getContext(), $0))"; +} + #endif // STABLEHLO_DIALECT_STABLEHLO_ATTRS diff --git a/stablehlo/dialect/StablehloBytecode.cpp b/stablehlo/dialect/StablehloBytecode.cpp index fd36d0d5df4..b8db2ba51aa 100644 --- a/stablehlo/dialect/StablehloBytecode.cpp +++ b/stablehlo/dialect/StablehloBytecode.cpp @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "llvm/ADT/APFloat.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/TypeSwitch.h" @@ -180,6 +181,18 @@ enum AttributeCode { /// allowImpreciseAccumulation : svarint /// } kDotAlgorithmAttr = 15, + + // ResultAccuracyModeAttr { + // mode: varint (encoded enum) + // } + kResultAccuracyModeAttr = 16, + + // ResultAccuracyAttr { + // atol: APFloat + // rtol: APFloat + // ulps: svarint + // } + kResultAccuracyAttr = 17, }; /// This enum contains marker codes used to indicate which type is @@ -241,6 +254,10 @@ class StablehloBytecodeInterface : public BytecodeDialectInterface { OutputOperandAliasAttr readOutputOperandAliasAttr( DialectBytecodeReader &reader) const; PrecisionAttr readPrecisionAttr(DialectBytecodeReader &reader) const; + ResultAccuracyAttr readResultAccuracyAttr( + DialectBytecodeReader &reader) const; + ResultAccuracyModeAttr readResultAccuracyModeAttr( + DialectBytecodeReader &reader) const; RngAlgorithmAttr readRngAlgorithmAttr(DialectBytecodeReader &reader) const; RngDistributionAttr readRngDistributionAttr( DialectBytecodeReader &reader) const; @@ -264,6 +281,8 @@ class StablehloBytecodeInterface : public BytecodeDialectInterface { DialectBytecodeWriter &writer) const; void write(OutputOperandAliasAttr attr, DialectBytecodeWriter &writer) const; void write(PrecisionAttr attr, DialectBytecodeWriter &writer) const; + void write(ResultAccuracyAttr attr, DialectBytecodeWriter &writer) const; + void write(ResultAccuracyModeAttr attr, DialectBytecodeWriter &writer) const; void write(RngAlgorithmAttr attr, DialectBytecodeWriter &writer) const; void write(RngDistributionAttr attr, DialectBytecodeWriter &writer) const; void write(ScatterDimensionNumbersAttr attr, @@ -327,6 +346,10 @@ Attribute StablehloBytecodeInterface::readAttribute( return readOutputOperandAliasAttr(reader); case stablehlo_encoding::kPrecisionAttr: return readPrecisionAttr(reader); + case stablehlo_encoding::kResultAccuracyAttr: + return readResultAccuracyAttr(reader); + case stablehlo_encoding::kResultAccuracyModeAttr: + return readResultAccuracyModeAttr(reader); case stablehlo_encoding::kRngAlgorithmAttr: return readRngAlgorithmAttr(reader); case stablehlo_encoding::kRngDistributionAttr: @@ -352,13 +375,13 @@ LogicalResult StablehloBytecodeInterface::writeAttribute( .Case( - [&](auto attr) { - LOG_WRITE_CALL; - write(attr, writer); - return success(); - }) + PrecisionAttr, ResultAccuracyAttr, ResultAccuracyModeAttr, + RngAlgorithmAttr, RngDistributionAttr, ScatterDimensionNumbersAttr, + TransposeAttr, TypeExtensionsAttr>([&](auto attr) { + LOG_WRITE_CALL; + write(attr, writer); + return success(); + }) .Default([&](Attribute) { LOG_NOT_IMPLEMENTED; return failure(); @@ -806,6 +829,55 @@ void StablehloBytecodeInterface::writeVersion( } } +//===----------------------------------------------------------------------===// +// ResultAccuracyModeAttr + +ResultAccuracyModeAttr StablehloBytecodeInterface::readResultAccuracyModeAttr( + DialectBytecodeReader &reader) const { + LOG_READ_CALL; + return hlo::bytecode::readEnumAttribute( + reader, getContext(), + [](uint32_t val) { return symbolizeResultAccuracyMode(val); }); +} + +void StablehloBytecodeInterface::write(ResultAccuracyModeAttr attr, + DialectBytecodeWriter &writer) const { + writer.writeVarInt(stablehlo_encoding::kResultAccuracyModeAttr); + hlo::bytecode::writeEnumAttribute(attr, writer); +} + +//===----------------------------------------------------------------------===// +// ResultAccuracyAttr + +ResultAccuracyAttr StablehloBytecodeInterface::readResultAccuracyAttr( + DialectBytecodeReader &reader) const { + LOG_READ_CALL; + FailureOr atol; + FailureOr rtol; + int64_t ulps; + ResultAccuracyModeAttr mode; + if (failed(atol = + reader.readAPFloatWithKnownSemantics(APFloat::IEEEdouble())) || + failed(rtol = + reader.readAPFloatWithKnownSemantics(APFloat::IEEEdouble())) || + failed(reader.readSignedVarInt(ulps)) || + failed(reader.readAttribute(mode))) { + mlir::emitWarning(mlir::UnknownLoc::get(getContext())) + << "failed to read APFloat for atol"; + return ResultAccuracyAttr(); + } + return ResultAccuracyAttr::get(getContext(), *atol, *rtol, ulps, mode); +} + +void StablehloBytecodeInterface::write(ResultAccuracyAttr attr, + DialectBytecodeWriter &writer) const { + writer.writeVarInt(stablehlo_encoding::kResultAccuracyAttr); + writer.writeAPFloatWithKnownSemantics(attr.getAtol()); + writer.writeAPFloatWithKnownSemantics(attr.getRtol()); + writer.writeSignedVarInt(attr.getUlps()); + writer.writeAttribute(attr.getMode()); +} + } // namespace void addBytecodeInterface(StablehloDialect *dialect) { diff --git a/stablehlo/dialect/StablehloEnums.td b/stablehlo/dialect/StablehloEnums.td index 65db5b5404c..f69f94a2c87 100644 --- a/stablehlo/dialect/StablehloEnums.td +++ b/stablehlo/dialect/StablehloEnums.td @@ -46,6 +46,29 @@ def StableHLO_PrecisionAttr : EnumAttr; +//===----------------------------------------------------------------------===// +// Result Accuracy enum definitions. +//===----------------------------------------------------------------------===// + +def STABLEHLO_RESULT_ACCURACY_DEFAULT : I32EnumAttrCase<"DEFAULT", 0>; +def STABLEHLO_RESULT_ACCURACY_HIGHEST : I32EnumAttrCase<"HIGHEST", 1>; +def STABLEHLO_RESULT_ACCURACY_TOLERANCE: I32EnumAttrCase<"TOLERANCE", 2>; + +def StableHLO_ResultAccuracyMode : I32EnumAttr<"ResultAccuracyMode", + "XLA result accuracy mode.", + [ + STABLEHLO_RESULT_ACCURACY_DEFAULT, + STABLEHLO_RESULT_ACCURACY_HIGHEST, + STABLEHLO_RESULT_ACCURACY_TOLERANCE + ]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::stablehlo"; +} + +def StableHLO_ResultAccuracyModeAttr : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + //===----------------------------------------------------------------------===// // Fast Fourier Transform Type enum definitions. //===----------------------------------------------------------------------===// diff --git a/stablehlo/dialect/StablehloOps.cpp b/stablehlo/dialect/StablehloOps.cpp index f32589ef9df..f6a4c0d5466 100644 --- a/stablehlo/dialect/StablehloOps.cpp +++ b/stablehlo/dialect/StablehloOps.cpp @@ -87,6 +87,7 @@ limitations under the License. #include "stablehlo/dialect/AssemblyFormat.h" #include "stablehlo/dialect/Base.h" #include "stablehlo/dialect/StablehloBytecode.h" +#include "stablehlo/dialect/StablehloOps.h" #include "stablehlo/dialect/StablehloOps.h.inc" #include "stablehlo/dialect/TypeInference.h" @@ -792,6 +793,29 @@ LogicalResult DotAlgorithmAttr::verify( allowImpreciseAccumulation); } +// ===----------------------------------------------------------------------===// +// ExpOp +//===----------------------------------------------------------------------===// + +LogicalResult ResultAccuracyAttr::verify( + ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, APFloat atol, + APFloat rtol, int64_t ulps, ResultAccuracyModeAttr mode) { + return hlo::verifyResultAccuracyAttr( + emitError, atol, rtol, ulps, + stringifyResultAccuracyMode(mode.getValue())); +} + +LogicalResult ExpOp::verify() { + if (auto attr = getResultAccuracyAttr()) { + if (failed(ResultAccuracyAttr::verify([&] { return emitError(); }, + attr.getAtol(), attr.getRtol(), + attr.getUlps(), attr.getMode()))) { + return failure(); + } + } + return success(); +} + //===----------------------------------------------------------------------===// // FftOp //===----------------------------------------------------------------------===// @@ -3127,6 +3151,20 @@ Attribute DotDimensionNumbersAttr::parse(AsmParser& parser, Type type) { lhsContractingDimensions, rhsContractingDimensions); } +// ===----------------------------------------------------------------------===// +// Custom unary op +// ===----------------------------------------------------------------------===// + +void ResultAccuracyAttr::print(AsmPrinter& odsPrinter) const { + hlo::printResultAccuracyAttr(odsPrinter, getAtol(), getRtol(), getUlps(), + getMode()); +} + +Attribute ResultAccuracyAttr::parse(AsmParser& parser, Type type) { + return hlo::parseResultAccuracyAttr(parser, type); +} + namespace { enum NonSpatialDim : int64_t { IOBatch = -1, // Input or output batch dimension diff --git a/stablehlo/dialect/StablehloOps.td b/stablehlo/dialect/StablehloOps.td index 9d393fa1c43..a83b696433b 100644 --- a/stablehlo/dialect/StablehloOps.td +++ b/stablehlo/dialect/StablehloOps.td @@ -328,6 +328,23 @@ def StableHLO_ExpOp: StableHLO_UnaryElementwiseOp<"exponential", %result = stablehlo.exponential %operand : tensor<2x2xf64> ``` }]; + let arguments = (ins HLO_FpComplexOrQuantizedIntTensor:$operand, + DefaultValuedOptionalAttr:$result_accuracy); + let results = (outs HLO_FpComplexOrQuantizedIntTensor:$result); + let extraClassDeclaration = commonClassDeclaration # [{ + LogicalResult reifyReturnTypeShapes( + OpBuilder& builder, ValueRange operands, + SmallVectorImpl& reifiedReturnShapes) { + return ::mlir::hlo::deriveShapeFromOperand(&builder, getOperation(), + operands.front(), + &reifiedReturnShapes); + } + }]; + let hasVerifier = 1; + + let assemblyFormat = [{ + $operand attr-dict `:` custom(type($operand), type($result)) + }]; } def StableHLO_Expm1Op: StableHLO_UnaryElementwiseOp<"exponential_minus_one", diff --git a/stablehlo/dialect/TypeInference.cpp b/stablehlo/dialect/TypeInference.cpp index 83ba24af635..daabe04d9a0 100644 --- a/stablehlo/dialect/TypeInference.cpp +++ b/stablehlo/dialect/TypeInference.cpp @@ -5057,5 +5057,30 @@ LogicalResult verifyWhileOp(std::optional location, return success(); } +LogicalResult verifyResultAccuracyCombination( + ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, APFloat atol, + APFloat rtol, int64_t ulps, StringRef mode) { + if (mode == "DEFAULT" || mode == "HIGHEST") { + bool all_zero = atol.isZero() && rtol.isZero() && ulps == 0; + if (!all_zero) { + return emitError() + << "Invalid tolerances for ResultAccuracyAttr with mode " << mode + << ", must be all zero."; + } + } + return success(); +} + +LogicalResult verifyResultAccuracyAttr( + ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, APFloat atol, + APFloat rtol, int64_t ulps, StringRef mode) { + if (atol.isNegative() || rtol.isNegative() || ulps < 0) + return emitError() << "Negative tolerance"; + if (failed( + verifyResultAccuracyCombination(emitError, atol, rtol, ulps, mode))) + return failure(); + return success(); +} + } // end namespace hlo } // end namespace mlir diff --git a/stablehlo/dialect/TypeInference.h b/stablehlo/dialect/TypeInference.h index 9c622acd761..d34e75ca357 100644 --- a/stablehlo/dialect/TypeInference.h +++ b/stablehlo/dialect/TypeInference.h @@ -26,6 +26,7 @@ limitations under the License. #include "mlir/IR/SymbolTable.h" #include "mlir/IR/Types.h" #include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" #include "stablehlo/dialect/Base.h" @@ -596,6 +597,14 @@ LogicalResult verifyUniformQuantizeOp(std::optional location, LogicalResult verifyWhileOp(std::optional location, ValueRange operand, Region& cond, Region& body); + +LogicalResult verifyResultAccuracyCombination( + ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, APFloat atol, + APFloat rtol, int64_t ulps, StringRef mode); + +LogicalResult verifyResultAccuracyAttr( + ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, APFloat atol, + APFloat rtol, int64_t ulps, StringRef mode); } // end namespace hlo } // end namespace mlir diff --git a/stablehlo/dialect/Version.h b/stablehlo/dialect/Version.h index 97597dedd97..85558cf2895 100644 --- a/stablehlo/dialect/Version.h +++ b/stablehlo/dialect/Version.h @@ -38,7 +38,7 @@ class Version { static FailureOr fromString(llvm::StringRef versionRef); /// Return a Version representing the current VHLO dialect version. - static Version getCurrentVersion() { return Version(1, 8, 12); } + static Version getCurrentVersion() { return Version(1, 9, 0); } /// Return a Version representing the minimum supported VHLO dialect version. static Version getMinimumVersion() { return Version(0, 9, 0); } diff --git a/stablehlo/dialect/VhloAttrs.td b/stablehlo/dialect/VhloAttrs.td index 8fd9ae592e2..f190bb268af 100644 --- a/stablehlo/dialect/VhloAttrs.td +++ b/stablehlo/dialect/VhloAttrs.td @@ -21,18 +21,8 @@ include "mlir/IR/AttrTypeBase.td" include "stablehlo/dialect/VhloBase.td" include "stablehlo/dialect/VhloDialect.td" include "stablehlo/dialect/VhloTypes.td" +include "stablehlo/dialect/VhloEnums.td" -def VHLO_VersionedAttrInterface : AttrInterface<"VersionedAttrInterface"> { - let cppNamespace = "::mlir::vhlo"; - let methods = [ - InterfaceMethod< - "Returns the minimum version of the VHLO dialect an attribute is supported in.", - "mlir::vhlo::Version", "getMinVersion">, - InterfaceMethod< - "Returns the maximum version (inclusive) of the VHLO dialect an attribute is supported in.", - "mlir::vhlo::Version", "getMaxVersion">, - ]; -} class VHLO_AttrDef : AttrDef { @@ -190,4 +180,27 @@ def VHLO_TypeExtensionsAttrV1 : VHLO_AttrDef<"TypeExtensionsV1", "0.9.0", "curre let assemblyFormat = "`<` struct(params) `>`"; } + +def VHLO_ResultAccuracyAttrV1 : VHLO_AttrDef<"ResultAccuracyV1", "1.9.0", "current"> { + let mnemonic = "result_accuracy_v1"; + let summary = "The requested accuracy for transcendental unary ops."; + let parameters = (ins + VHLO_APFloatV1:$atol, + VHLO_APFloatV1:$rtol, + "int64_t":$ulps, + "mlir::Attribute":$mode + ); + let assemblyFormat = "`<` struct(params) `>`"; + let genVerifyDecl = 1; + let extraClassDefinition = [{ + LogicalResult ResultAccuracyV1Attr::verify( + llvm::function_ref errFn, + APFloat atol, APFloat rtol, int64_t ulps, + mlir::Attribute mode) { + if (!isFromVhlo(mode)) return errFn() << "expected VHLO result accuracy mode"; + return success(); + } + }]; +} + #endif // STABLEHLO_DIALECT_VHLO_ATTRS diff --git a/stablehlo/dialect/VhloBytecode.cpp b/stablehlo/dialect/VhloBytecode.cpp index b8410bb71d3..2d1086f68b7 100644 --- a/stablehlo/dialect/VhloBytecode.cpp +++ b/stablehlo/dialect/VhloBytecode.cpp @@ -178,6 +178,18 @@ enum AttributeCode { /// bounds : svarint[] /// } kTypeExtensionsV1Attr = 18, + + // ResultAccuracyModeV1Attr { + // mode: varint (encoded enum) + // } + kResultAccuracyModeV1Attr = 19, + + // ResultAccuracyV1Attr { + // atol: APFloat + // rtol: APFloat + // ulps: svarint + // } + kResultAccuracyV1Attr = 20, }; /// This enum contains marker codes used to indicate which type is @@ -433,6 +445,10 @@ class VhloBytecodeInterface : public BytecodeDialectInterface { TypeV1Attr readTypeV1Attr(DialectBytecodeReader &reader) const; TypeExtensionsV1Attr readTypeExtensionsV1Attr( DialectBytecodeReader &reader) const; + ResultAccuracyModeV1Attr readResultAccuracyModeV1Attr( + DialectBytecodeReader &reader) const; + ResultAccuracyV1Attr readResultAccuracyV1Attr( + DialectBytecodeReader &reader) const; // TO ADD ATTRIBUTE: Include a write method for each attribute in VHLO // Ex: void write(SomeAttr attr, DialectBytecodeWriter &writer) const; @@ -457,6 +473,9 @@ class VhloBytecodeInterface : public BytecodeDialectInterface { void write(TransposeV1Attr attr, DialectBytecodeWriter &writer) const; void write(TypeV1Attr attr, DialectBytecodeWriter &writer) const; void write(TypeExtensionsV1Attr attr, DialectBytecodeWriter &writer) const; + void write(ResultAccuracyModeV1Attr attr, + DialectBytecodeWriter &writer) const; + void write(ResultAccuracyV1Attr attr, DialectBytecodeWriter &writer) const; //===--------------------------------------------------------------------===// // Types @@ -541,6 +560,10 @@ Attribute VhloBytecodeInterface::readAttribute( return readTypeV1Attr(reader); case vhlo_encoding::kTypeExtensionsV1Attr: return readTypeExtensionsV1Attr(reader); + case vhlo_encoding::kResultAccuracyModeV1Attr: + return readResultAccuracyModeV1Attr(reader); + case vhlo_encoding::kResultAccuracyV1Attr: + return readResultAccuracyV1Attr(reader); default: reader.emitError() << "unknown vhlo attribute code: " << code; return Attribute(); @@ -558,7 +581,8 @@ LogicalResult VhloBytecodeInterface::writeAttribute( FftTypeV1Attr, FloatV1Attr, IntegerV1Attr, OutputOperandAliasV1Attr, PrecisionV1Attr, RngAlgorithmV1Attr, RngDistributionV1Attr, StringV1Attr, TensorV1Attr, TransposeV1Attr, TypeV1Attr, - TypeExtensionsV1Attr>([&](auto attr) { + TypeExtensionsV1Attr, ResultAccuracyV1Attr, + ResultAccuracyModeV1Attr>([&](auto attr) { LOG_WRITE_CALL; write(attr, writer); return success(); @@ -1450,6 +1474,55 @@ void VhloBytecodeInterface::write(UnrankedTensorV1Type type, writer.writeType(type.getElementType()); } +//===----------------------------------------------------------------------===// +// ResultAccuracyModeAttr + +ResultAccuracyModeV1Attr VhloBytecodeInterface::readResultAccuracyModeV1Attr( + DialectBytecodeReader &reader) const { + LOG_READ_CALL; + return hlo::bytecode::readEnumAttribute( + reader, getContext(), + [](uint32_t val) { return symbolizeResultAccuracyModeV1(val); }); +} + +void VhloBytecodeInterface::write(ResultAccuracyModeV1Attr attr, + DialectBytecodeWriter &writer) const { + writer.writeVarInt(vhlo_encoding::kResultAccuracyModeV1Attr); + hlo::bytecode::writeEnumAttribute(attr, writer); +} + +//===----------------------------------------------------------------------===// +// ResultAccuracyAttr + +ResultAccuracyV1Attr VhloBytecodeInterface::readResultAccuracyV1Attr( + DialectBytecodeReader &reader) const { + LOG_READ_CALL; + FailureOr atol; + FailureOr rtol; + int64_t ulps; + ResultAccuracyModeV1Attr mode; + if (failed(atol = + reader.readAPFloatWithKnownSemantics(APFloat::IEEEdouble())) || + failed(rtol = + reader.readAPFloatWithKnownSemantics(APFloat::IEEEdouble())) || + failed(reader.readSignedVarInt(ulps)) || + failed(reader.readAttribute(mode))) { + mlir::emitWarning(mlir::UnknownLoc::get(getContext())) + << "failed to read APFloat for atol"; + return ResultAccuracyV1Attr(); + } + return ResultAccuracyV1Attr::get(getContext(), *atol, *rtol, ulps, mode); +} + +void VhloBytecodeInterface::write(ResultAccuracyV1Attr attr, + DialectBytecodeWriter &writer) const { + writer.writeVarInt(vhlo_encoding::kResultAccuracyV1Attr); + writer.writeAPFloatWithKnownSemantics(attr.getAtol()); + writer.writeAPFloatWithKnownSemantics(attr.getRtol()); + writer.writeSignedVarInt(attr.getUlps()); + writer.writeAttribute(attr.getMode()); +} + } // namespace void addBytecodeInterface(VhloDialect *dialect) { diff --git a/stablehlo/dialect/VhloDialect.td b/stablehlo/dialect/VhloDialect.td index c7a05727908..75f33ae2557 100644 --- a/stablehlo/dialect/VhloDialect.td +++ b/stablehlo/dialect/VhloDialect.td @@ -47,6 +47,7 @@ def VHLO_Dialect : Dialect { 1.6.0: Add DotAlgorithm specificaiton to `dot_general`. 1.7.0: Introduce `f8E4M3` and `f8E3M4` types. 1.8.0: Introduce `f4E2M1FN`, `f6E2M3FN`, `f6E3M2FN` and `f8E8M0FNU` types. + 1.9.0: Add `ResultAccuracy` attribute to `exp` op. }]; let useDefaultAttributePrinterParser = 0; diff --git a/stablehlo/dialect/VhloEnums.td b/stablehlo/dialect/VhloEnums.td index 0df09f39ccd..cf08ee7b790 100644 --- a/stablehlo/dialect/VhloEnums.td +++ b/stablehlo/dialect/VhloEnums.td @@ -20,7 +20,20 @@ limitations under the License. include "mlir/IR/EnumAttr.td" include "mlir/IR/PatternBase.td" include "stablehlo/dialect/VhloBase.td" -include "stablehlo/dialect/VhloAttrs.td" +include "stablehlo/dialect/VhloDialect.td" +include "mlir/IR/AttrTypeBase.td" + +def VHLO_VersionedAttrInterface : AttrInterface<"VersionedAttrInterface"> { + let cppNamespace = "::mlir::vhlo"; + let methods = [ + InterfaceMethod< + "Returns the minimum version of the VHLO dialect an attribute is supported in.", + "mlir::vhlo::Version", "getMinVersion">, + InterfaceMethod< + "Returns the maximum version (inclusive) of the VHLO dialect an attribute is supported in.", + "mlir::vhlo::Version", "getMaxVersion">, + ]; +} class VHLO_I32EnumAttr cases> : I32EnumAttr { @@ -198,4 +211,23 @@ def VHLO_TransposeV1 : VHLO_I32EnumAttr<"TransposeV1", [ def VHLO_TransposeAttrV1 : VHLO_EnumAttr; +//===----------------------------------------------------------------------===// +// ResultAccuracyMode +//===----------------------------------------------------------------------===// + +def VHLO_RESULT_V1_ACCURACY_DEFAULT : I32EnumAttrCase<"DEFAULT", 0>; +def VHLO_RESULT_V1_ACCURACY_HIGHEST : I32EnumAttrCase<"HIGHEST", 1>; +def VHLO_RESULT_V1_ACCURACY_TOLERANCE: I32EnumAttrCase<"TOLERANCE", 2>; + +def VHLO_ResultAccuracyModeV1 : VHLO_I32EnumAttr<"ResultAccuracyModeV1", + [ + VHLO_RESULT_V1_ACCURACY_DEFAULT, + VHLO_RESULT_V1_ACCURACY_HIGHEST, + VHLO_RESULT_V1_ACCURACY_TOLERANCE + ]> {} + +def VHLO_ResultAccuracyModeV1Attr + : VHLO_EnumAttr; + + #endif // STABLEHLO_DIALECT_VHLO_ENUMS diff --git a/stablehlo/dialect/VhloOps.td b/stablehlo/dialect/VhloOps.td index 7a67a5fb7ab..43aa02bf032 100644 --- a/stablehlo/dialect/VhloOps.td +++ b/stablehlo/dialect/VhloOps.td @@ -618,11 +618,18 @@ def VHLO_Expm1OpV1 : VHLO_Op<"exponential_minus_one_v1", "0.9.0", "current"> { let results = (outs VHLO_AnyType:$result); } -def VHLO_ExpOpV1 : VHLO_Op<"exponential_v1", "0.9.0", "current"> { +def VHLO_ExpOpV1 : VHLO_Op<"exponential_v1", "0.9.0", "1.8.0"> { let arguments = (ins VHLO_AnyType:$operand); let results = (outs VHLO_AnyType:$result); } +def VHLO_ExpOpV2 : VHLO_Op<"exponential_v2", "1.9.0", "current"> { + let arguments = (ins + VHLO_AnyType:$operand, + VHLO_AnyAttr:$result_accuracy); + let results = (outs VHLO_AnyType:$result); +} + def VHLO_FftOpV1 : VHLO_Op<"fft_v1", "0.9.0", "current"> { let arguments = (ins VHLO_AnyType:$operand, diff --git a/stablehlo/integrations/c/StablehloAttributes.cpp b/stablehlo/integrations/c/StablehloAttributes.cpp index 4f888c3d4bc..3be3dd29023 100644 --- a/stablehlo/integrations/c/StablehloAttributes.cpp +++ b/stablehlo/integrations/c/StablehloAttributes.cpp @@ -16,6 +16,7 @@ limitations under the License. #include #include +#include "llvm/ADT/APFloat.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/Support/Casting.h" #include "llvm/Support/ErrorHandling.h" @@ -687,3 +688,69 @@ int64_t stablehloTypeExtensionsGetBoundsElem(MlirAttribute attr, intptr_t pos) { return llvm::cast(unwrap(attr)) .getBounds()[pos]; } + +//===----------------------------------------------------------------------===// +// ResultAccuracyModeAttr +//===----------------------------------------------------------------------===// + +MlirAttribute stablehloResultAccuracyModeAttrGet(MlirContext ctx, + MlirStringRef value) { + std::optional accuracyMode = + mlir::stablehlo::symbolizeResultAccuracyMode(unwrap(value)); + if (!accuracyMode) llvm::report_fatal_error("Invalid value."); + return wrap(mlir::stablehlo::ResultAccuracyModeAttr::get( + unwrap(ctx), accuracyMode.value())); +} + +bool stablehloAttributeIsAResultAccuracyModeAttr(MlirAttribute attr) { + return llvm::isa(unwrap(attr)); +} + +MlirStringRef stablehloResultAccuracyModeAttrGetValue(MlirAttribute attr) { + return wrap(mlir::stablehlo::stringifyResultAccuracyMode( + llvm::cast(unwrap(attr)) + .getValue())); +} +//===----------------------------------------------------------------------===// +// ResultAccuracyAttr +//===----------------------------------------------------------------------===// + +MlirAttribute stablehloResultAccuracyAttrGet(MlirContext ctx, double atol, + double rtol, int64_t ulps, + MlirStringRef mode) { + std::optional accuracyMode = + mlir::stablehlo::symbolizeResultAccuracyMode(unwrap(mode)); + if (!accuracyMode) llvm::report_fatal_error("Invalid value."); + mlir::stablehlo::ResultAccuracyModeAttr modeAttr = + mlir::stablehlo::ResultAccuracyModeAttr::get(unwrap(ctx), + accuracyMode.value()); + return wrap(mlir::stablehlo::ResultAccuracyAttr::get( + unwrap(ctx), llvm::APFloat(atol), llvm::APFloat(rtol), ulps, modeAttr)); +} + +bool stablehloAttributeIsAResultAccuracyAttr(MlirAttribute attr) { + return llvm::isa(unwrap(attr)); +} + +double stablehloResultAccuracyAttrGetAtol(MlirAttribute attr) { + llvm::APFloat result = + llvm::cast(unwrap(attr)).getAtol(); + return result.convertToDouble(); +} + +double stablehloResultAccuracyAttrGetRtol(MlirAttribute attr) { + llvm::APFloat result = + llvm::cast(unwrap(attr)).getRtol(); + return result.convertToDouble(); +} + +int64_t stablehloResultAccuracyAttrGetUlps(MlirAttribute attr) { + return llvm::cast(unwrap(attr)) + .getUlps(); +} + +MlirAttribute stablehloResultAccuracyAttrGetMode(MlirAttribute attr) { + mlir::stablehlo::ResultAccuracyModeAttr modeAttr = + llvm::cast(unwrap(attr)).getMode(); + return wrap(modeAttr); +} diff --git a/stablehlo/integrations/c/StablehloAttributes.h b/stablehlo/integrations/c/StablehloAttributes.h index 897bfaa1a48..8ea0653b2d5 100644 --- a/stablehlo/integrations/c/StablehloAttributes.h +++ b/stablehlo/integrations/c/StablehloAttributes.h @@ -13,6 +13,7 @@ limitations under the License. #ifndef STABLEHLO_INTEGRATIONS_C_STABLEHLO_ATTRIBUTES_H #define STABLEHLO_INTEGRATIONS_C_STABLEHLO_ATTRIBUTES_H +#include #include #include @@ -376,6 +377,42 @@ stablehloTypeExtensionsGetBoundsSize(MlirAttribute attr); MLIR_CAPI_EXPORTED int64_t stablehloTypeExtensionsGetBoundsElem(MlirAttribute attr, intptr_t pos); +// ===---------------------------------------------------------------------===// +// ResultAccuracyModeAttr +//===----------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED MlirAttribute +stablehloResultAccuracyModeAttrGet(MlirContext ctx, MlirStringRef value); + +MLIR_CAPI_EXPORTED bool stablehloAttributeIsAResultAccuracyModeAttr( + MlirAttribute attr); + +MLIR_CAPI_EXPORTED MlirStringRef +stablehloResultAccuracyModeAttrGetValue(MlirAttribute attr); + +// ===---------------------------------------------------------------------===// +// ResultAccuracyAttr +//===----------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED MlirAttribute +stablehloResultAccuracyAttrGet(MlirContext ctx, double atol, double rtol, + int64_t ulps, MlirStringRef value); + +MLIR_CAPI_EXPORTED bool stablehloAttributeIsAResultAccuracyAttr( + MlirAttribute attr); + +MLIR_CAPI_EXPORTED double stablehloResultAccuracyAttrGetAtol( + MlirAttribute attr); + +MLIR_CAPI_EXPORTED double stablehloResultAccuracyAttrGetRtol( + MlirAttribute attr); + +MLIR_CAPI_EXPORTED int64_t +stablehloResultAccuracyAttrGetUlps(MlirAttribute attr); + +MLIR_CAPI_EXPORTED MlirAttribute +stablehloResultAccuracyAttrGetMode(MlirAttribute attr); + #ifdef __cplusplus } #endif diff --git a/stablehlo/integrations/python/StablehloModule.cpp b/stablehlo/integrations/python/StablehloModule.cpp index 08c6ebd6cfe..6c297b624d8 100644 --- a/stablehlo/integrations/python/StablehloModule.cpp +++ b/stablehlo/integrations/python/StablehloModule.cpp @@ -599,6 +599,50 @@ NB_MODULE(_stablehlo, m) { stablehloTypeExtensionsGetBoundsElem); }); + mlir::python::nanobind_adaptors::mlir_attribute_subclass( + m, "ResultAccuracyAttr", stablehloAttributeIsAResultAccuracyAttr) + .def_classmethod( + "get", + [](nb::object cls, double atol, double rtol, int64_t ulps, + const std::string &mode, MlirContext ctx) { + return cls(stablehloResultAccuracyAttrGet( + ctx, atol, rtol, ulps, + mlirStringRefCreate(mode.c_str(), mode.size()))); + }, + nb::arg("cls"), nb::arg("atol"), nb::arg("rtol"), nb::arg("ulps"), + nb::arg("mode"), nb::arg("context") = nb::none(), + "Creates a ResultAccuracyAttr with the given values.") + .def_property_readonly("atol", + [](MlirAttribute self) { + return stablehloResultAccuracyAttrGetAtol(self); + }) + .def_property_readonly("rtol", + [](MlirAttribute self) { + return stablehloResultAccuracyAttrGetRtol(self); + }) + .def_property_readonly("ulps", + [](MlirAttribute self) { + return stablehloResultAccuracyAttrGetUlps(self); + }) + .def_property_readonly("mode", [](MlirAttribute self) { + return toPyString(stablehloResultAccuracyModeAttrGetValue( + stablehloResultAccuracyAttrGetMode(self))); + }); + + mlir::python::nanobind_adaptors::mlir_attribute_subclass( + m, "ResultAccuracyModeAttr", stablehloAttributeIsAResultAccuracyModeAttr) + .def_classmethod( + "get", + [](nb::object cls, const std::string &value, MlirContext ctx) { + return cls(stablehloResultAccuracyModeAttrGet( + ctx, mlirStringRefCreate(value.c_str(), value.size()))); + }, + nb::arg("cls"), nb::arg("value"), nb::arg("context") = nb::none(), + "Creates a ResultAccuracyModeAttr with the given values.") + .def_property_readonly("value", [](MlirAttribute self) { + return toPyString(stablehloResultAccuracyModeAttrGetValue(self)); + }); + // // StableHLO APIs // diff --git a/stablehlo/integrations/python/tests/stablehlo.py b/stablehlo/integrations/python/tests/stablehlo.py index 39a0cfd3538..2a9aac32a0a 100644 --- a/stablehlo/integrations/python/tests/stablehlo.py +++ b/stablehlo/integrations/python/tests/stablehlo.py @@ -386,3 +386,24 @@ def test_register_passes(): cloned_module = module.operation.clone() pipeline.run(cloned_module.operation) assert str(module) == str(cloned_module) + + +@run +def test_result_accuracy_attr_default(): + attr = stablehlo.ResultAccuracyAttr.get(atol=0, rtol=0, ulps=0, mode="DEFAULT") + assert attr is not None + assert attr.mode == "DEFAULT" + assert attr.atol == 0 + assert attr.rtol == 0 + assert attr.ulps == 0 + +@run +def test_result_accuracy_attr_tolerance(): + attr = stablehlo.ResultAccuracyAttr.get(atol=1e-5, rtol=1.0, + ulps=2, mode="TOLERANCE") + assert attr is not None + assert attr.mode == "TOLERANCE" + assert attr.atol == 1e-5 + assert attr.rtol == 1.0 + assert attr.ulps == 2 + diff --git a/stablehlo/tests/ops_stablehlo.mlir b/stablehlo/tests/ops_stablehlo.mlir index a5aa2f359dc..7bf113c52c7 100644 --- a/stablehlo/tests/ops_stablehlo.mlir +++ b/stablehlo/tests/ops_stablehlo.mlir @@ -1779,6 +1779,30 @@ func.func @dot_bad_precision_config(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi3 // ----- +// CHECK-LABEL: func @exponential_result_accuracy +func.func @exponential_result_accuracy(%arg0: tensor) -> tensor { + %0 = "stablehlo.exponential"(%arg0) {result_accuracy = #stablehlo.result_accuracy>} : (tensor) -> tensor + func.return %0: tensor +} + +// ----- + +// CHECK-LABEL: func @exponential_result_accuracy_tol +func.func @exponential_result_accuracy_tol(%arg0: tensor) -> tensor { + %0 = "stablehlo.exponential"(%arg0) {result_accuracy = #stablehlo.result_accuracy>} : (tensor) -> tensor + func.return %0: tensor +} + +// ----- + +func.func @exponential_result_accuracy_tol(%arg0: tensor) -> tensor { + // expected-error@+1 {{Invalid tolerances for ResultAccuracyAttr with mode HIGHEST, must be all zero.}} + %0 = "stablehlo.exponential"(%arg0) {result_accuracy = #stablehlo.result_accuracy>} : (tensor) -> tensor + func.return %0: tensor +} + +// ----- + func.func @dot_more_dynamic_output_type(%arg0: tensor<3xf32>, %arg1: tensor) -> tensor { %0 = "stablehlo.dot"(%arg0, %arg1) : (tensor<3xf32>, tensor) -> tensor func.return %0 : tensor diff --git a/stablehlo/tests/ops_stablehlo_roundtrip.mlir b/stablehlo/tests/ops_stablehlo_roundtrip.mlir index 7be3c843938..504321c29aa 100644 --- a/stablehlo/tests/ops_stablehlo_roundtrip.mlir +++ b/stablehlo/tests/ops_stablehlo_roundtrip.mlir @@ -766,6 +766,11 @@ func.func @test_unary_cbrt(%arg: tensor<3x4xf32>) -> tensor<3x4xf32> { func.return %0 : tensor<3x4xf32> } +func.func @test_unary_result_accuracy(%arg0: tensor<2xf32>) -> tensor<2xf32> { + %exp = "stablehlo.exponential"(%arg0) {result_accuracy = #stablehlo.result_accuracy>} : (tensor<2xf32>) -> tensor<2xf32> + func.return %exp : tensor<2xf32> +} + func.func @test_unary_round_nearest_even(%arg0: tensor<2xf32>) -> tensor<2xf32> { %0 = "stablehlo.round_nearest_even"(%arg0) {} : (tensor<2xf32>) -> tensor<2xf32> func.return %0 : tensor<2xf32> diff --git a/stablehlo/tests/print_stablehlo.mlir b/stablehlo/tests/print_stablehlo.mlir index bb4b711c851..c58204871a7 100644 --- a/stablehlo/tests/print_stablehlo.mlir +++ b/stablehlo/tests/print_stablehlo.mlir @@ -406,3 +406,16 @@ func.func @slice(%arg0: tensor<3x8xf32>, %arg1: tensor<8xf32>) %slice6 = stablehlo.slice %arg0 [1:3:1, 4:8:2] : (tensor<3x8xf32>) -> tensor<2x2xf32> return %slice1, %slice2, %slice3, %slice4, %slice5, %slice6 : tensor<1xf32>, tensor<2xf32>, tensor<1xf32>, tensor<1xf32>, tensor<2x2xf32>, tensor<2x2xf32> } + +func.func @result_accuracy_default() -> () attributes { + // CHECK: mode.default = #stablehlo.result_accuracy> + // CHECK: mode.highest = #stablehlo.result_accuracy> + // CHECK: mode.tolerance_full = #stablehlo.result_accuracy> + // CHECK: mode.tolerance_partial = #stablehlo.result_accuracy> + mode.default = #stablehlo.result_accuracy>, + mode.highest = #stablehlo.result_accuracy>, + mode.tolerance_full = #stablehlo.result_accuracy>, + mode.tolerance_partial = #stablehlo.result_accuracy> +} { + func.return +} diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.1_9_0.mlir b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.1_9_0.mlir new file mode 100644 index 00000000000..c72aba44b45 --- /dev/null +++ b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.1_9_0.mlir @@ -0,0 +1,2966 @@ +// RUN: stablehlo-opt --mlir-print-op-generic %s.bc | FileCheck %s +// RUN: stablehlo-translate --deserialize %s.bc | stablehlo-translate --serialize --target=1.9.0 | stablehlo-opt --mlir-print-op-generic | FileCheck %s +// RUN: stablehlo-translate --deserialize %s.bc | stablehlo-opt > %t.0 +// RUN: stablehlo-opt --strip-debuginfo %s > %t.1 +// RUN: diff %t.0 %t.1 +// RUN: stablehlo-translate --serialize --target=1.9.0 --strip-debuginfo %s > %t.2 +// RUN: diff %s.bc %t.2 +// RUN: stablehlo-opt --stablehlo-legalize-to-vhlo -emit-bytecode -debug-only=vhlo-bytecode %s 2>&1 | FileCheck --check-prefix=CHECK-WARN %s +// RUN: stablehlo-opt --stablehlo-legalize-to-vhlo -emit-bytecode %s | stablehlo-opt -debug-only=vhlo-bytecode 2>&1 | FileCheck --check-prefix=CHECK-WARN %s + +// CHECK-WARN-NOT: Not Implemented + +// ============ ATTRIBUTES ============ + +// CHECK-LABEL: "attr_comparison_direction_eq" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @attr_comparison_direction_eq(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + // CHECK: comparison_direction = #vhlo + comparison_direction = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_comparison_direction_ne" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @attr_comparison_direction_ne(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + // CHECK: comparison_direction = #vhlo + comparison_direction = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_comparison_direction_ge" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @attr_comparison_direction_ge(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + // CHECK: comparison_direction = #vhlo + comparison_direction = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_comparison_direction_gt" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @attr_comparison_direction_gt(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + // CHECK: comparison_direction = #vhlo + comparison_direction = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_comparison_direction_le" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @attr_comparison_direction_le(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + // CHECK: comparison_direction = #vhlo + comparison_direction = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_comparison_direction_lt" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @attr_comparison_direction_lt(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + // CHECK: comparison_direction = #vhlo + comparison_direction = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_comparison_type_notype" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @attr_comparison_type_notype(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + comparison_direction = #stablehlo + // CHECK: compare_type = #vhlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_comparison_type_float" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @attr_comparison_type_float(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + comparison_direction = #stablehlo, + // CHECK: compare_type = #vhlo, + compare_type = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_comparison_type_totalorder" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @attr_comparison_type_totalorder(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + comparison_direction = #stablehlo, + // CHECK: compare_type = #vhlo, + compare_type = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_comparison_type_signed" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @attr_comparison_type_signed(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + comparison_direction = #stablehlo, + // CHECK: compare_type = #vhlo, + compare_type = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_comparison_type_unsigned" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @attr_comparison_type_unsigned(%arg0: tensor, %arg1: tensor) -> tensor { + %0 = "stablehlo.compare"(%arg0, %arg1) { + comparison_direction = #stablehlo, + // CHECK: compare_type = #vhlo, + compare_type = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// ConvDimensionNumbers aka #stablehlo.conv is covered below. + +// CHECK-LABEL: "attr_custom_call_api_version_unspecified" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @attr_custom_call_api_version_unspecified(%arg0: tensor) -> tensor { + %0 = "stablehlo.custom_call"(%arg0) { + call_target_name = "foo", + // CHECK: api_version = #vhlo + api_version = 0 : i32 + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_custom_call_api_version_original" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @attr_custom_call_api_version_original(%arg0: tensor) -> tensor { + %0 = "stablehlo.custom_call"(%arg0) { + call_target_name = "foo", + // CHECK: api_version = #vhlo + api_version = 1 : i32 + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_custom_call_api_version_status_returning" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @attr_custom_call_api_version_status_returning(%arg0: tensor) -> tensor { + %0 = "stablehlo.custom_call"(%arg0) { + call_target_name = "foo", + // CHECK: api_version = #vhlo + api_version = 2 : i32 + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_custom_call_api_version_status_returning_unified" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @attr_custom_call_api_version_status_returning_unified(%arg0: tensor) -> tensor { + %0 = "stablehlo.custom_call"(%arg0) { + call_target_name = "foo", + // CHECK: api_version = #vhlo + api_version = 3 : i32 + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_dict" +// CHECK: #vhlo.dict_v1<{#vhlo.string_v1<"attr1"> = #vhlo.integer_v1<1 : i32>, #vhlo.string_v1<"attr2"> = #vhlo.integer_v1<2 : i32>} +func.func @attr_dict() attributes {stablehlo.attr = {attr1 = 1 : i32, attr2 = 2 : i32}} { + return +} + +// CHECK-LABEL: "attr_custom_call_api_version_typed_ffi" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +// CHECK: api_version = #vhlo +// CHECK-SAME: backend_config = #vhlo.dict_v1<{#vhlo.string_v1<"bar"> = #vhlo.integer_v1<42 : i32>}> +func.func @attr_custom_call_api_version_typed_ffi(%arg0: tensor) -> tensor { + %0 = "stablehlo.custom_call"(%arg0) { + call_target_name = "foo", + backend_config= {bar = 42 : i32}, + api_version = 4 : i32 + } : (tensor) -> tensor + func.return %0 : tensor +} + + +// CHECK-LABEL: "attr_custom_call_api_version_typed_ffi_no_backend_config" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +// CHECK: api_version = #vhlo +// CHECK-SAME: backend_config = #vhlo.dict_v1<{}> +func.func @attr_custom_call_api_version_typed_ffi_no_backend_config(%arg0: tensor) -> tensor { + %0 = "stablehlo.custom_call"(%arg0) { + call_target_name = "foo", + api_version = 4 : i32 + } : (tensor) -> tensor + func.return %0 : tensor +} + +// DotDimensionNumbers aka #stablehlo.dot is covered below. + +// CHECK-LABEL: "attr_fft_type_fft" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @attr_fft_type_fft(%arg0: tensor<16xcomplex>) -> tensor<16xcomplex> { + %0 = "stablehlo.fft"(%arg0) { + // CHECK: fft_type = #vhlo + fft_type = #stablehlo, + fft_length = array + } : (tensor<16xcomplex>) -> tensor<16xcomplex> + func.return %0 : tensor<16xcomplex> +} + +// CHECK-LABEL: "attr_fft_type_ifft" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @attr_fft_type_ifft(%arg0: tensor<16xcomplex>) -> tensor<16xcomplex> { + %0 = "stablehlo.fft"(%arg0) { + // CHECK: fft_type = #vhlo + fft_type = #stablehlo, + fft_length = array + } : (tensor<16xcomplex>) -> tensor<16xcomplex> + func.return %0 : tensor<16xcomplex> +} + +// CHECK-LABEL: "attr_fft_type_rfft" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @attr_fft_type_rfft(%arg0: tensor<16xf32>) -> tensor<9xcomplex> { + %0 = "stablehlo.fft"(%arg0) { + // CHECK: fft_type = #vhlo + fft_type = #stablehlo, + fft_length = array + } : (tensor<16xf32>) -> tensor<9xcomplex> + func.return %0 : tensor<9xcomplex> +} + +// CHECK-LABEL: "attr_fft_type_irfft" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @attr_fft_type_irfft(%arg0: tensor<9xcomplex>) -> tensor<16xf32> { + %0 = "stablehlo.fft"(%arg0) { + // CHECK: fft_type = #vhlo + fft_type = #stablehlo, + fft_length = array + } : (tensor<9xcomplex>) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} + +// CHECK-LABEL: "attr_result_accuracy_HIGHEST" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}} +func.func @attr_result_accuracy_HIGHEST(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { + %0 = "stablehlo.exponential"(%arg0) { + // CHECK: result_accuracy = #vhlo.result_accuracy_v1> + result_accuracy = #stablehlo.result_accuracy> + } : (tensor<8x16xf32>) -> tensor<8x16xf32> + func.return %0 : tensor<8x16xf32> +} + +// CHECK-LABEL: "attr_result_accuracy_TOLERANCE" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}} +func.func @attr_result_accuracy_TOLERANCE(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { + %0 = "stablehlo.exponential"(%arg0) { + // CHECK: result_accuracy = #vhlo.result_accuracy_v1> + result_accuracy = #stablehlo.result_accuracy> + } : (tensor<8x16xf32>) -> tensor<8x16xf32> + func.return %0 : tensor<8x16xf32> +} + +// CHECK-LABEL: "attr_result_accuracy_DEFAULT" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}} +func.func @attr_result_accuracy_DEFAULT(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { + %0 = "stablehlo.exponential"(%arg0) { + // CHECK: result_accuracy = #vhlo.result_accuracy_v1> + result_accuracy = #stablehlo.result_accuracy> + } : (tensor<8x16xf32>) -> tensor<8x16xf32> + func.return %0 : tensor<8x16xf32> +} + +// GatherDimensionNumbers aka #stablehlo.gather is covered below. + +// CHECK-LABEL: "attr_precision_config_default" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @attr_precision_config_default(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> { + %0 = "stablehlo.dot"(%arg0, %arg1) { + // CHECK: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]> + } : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> + func.return %0 : tensor<8x8xf32> +} + +// CHECK-LABEL: "attr_precision_config_high" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @attr_precision_config_high(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> { + %0 = "stablehlo.dot"(%arg0, %arg1) { + // CHECK: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]> + precision_config = [#stablehlo, #stablehlo] + } : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> + func.return %0 : tensor<8x8xf32> +} + +// CHECK-LABEL: "attr_precision_config_highest" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @attr_precision_config_highest(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> { + %0 = "stablehlo.dot"(%arg0, %arg1) { + // CHECK: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]> + precision_config = [#stablehlo, #stablehlo] + } : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> + func.return %0 : tensor<8x8xf32> +} + +// CHECK-LABEL: "attr_rng_algorithm_default" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @attr_rng_algorithm_default(%arg0: tensor) -> (tensor, tensor) { + %0:2 = "stablehlo.rng_bit_generator"(%arg0) { + // CHECK: rng_algorithm = #vhlo + rng_algorithm = #stablehlo + } : (tensor) -> (tensor, tensor) + func.return %0#0, %0#1 : tensor, tensor +} + +// CHECK-LABEL: "attr_rng_algorithm_three_fry" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @attr_rng_algorithm_three_fry(%arg0: tensor) -> (tensor, tensor) { + %0:2 = "stablehlo.rng_bit_generator"(%arg0) { + // CHECK: rng_algorithm = #vhlo + rng_algorithm = #stablehlo + } : (tensor) -> (tensor, tensor) + func.return %0#0, %0#1 : tensor, tensor +} + +// CHECK-LABEL: "attr_rng_algorithm_philox" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @attr_rng_algorithm_philox(%arg0: tensor) -> (tensor, tensor) { + %0:2 = "stablehlo.rng_bit_generator"(%arg0) { + // CHECK: rng_algorithm = #vhlo + rng_algorithm = #stablehlo + } : (tensor) -> (tensor, tensor) + func.return %0#0, %0#1 : tensor, tensor +} + +// CHECK-LABEL: "attr_rng_distribution_uniform" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) +func.func @attr_rng_distribution_uniform(%arg0: tensor, %arg1: tensor, %arg2: tensor<0xindex>) -> tensor { + %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) { + // CHECK: rng_distribution = #vhlo + rng_distribution = #stablehlo + } : (tensor, tensor, tensor<0xindex>) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "attr_rng_distribution_normal" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) +func.func @attr_rng_distribution_normal(%arg0: tensor, %arg1: tensor, %arg2: tensor<0xindex>) -> tensor { + %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) { + // CHECK: rng_distribution = #vhlo + rng_distribution = #stablehlo + } : (tensor, tensor, tensor<0xindex>) -> tensor + func.return %0 : tensor +} + +// ScatterDimensionNumbers aka #stablehlo.scatter is covered below. + +// CHECK-LABEL: "attr_transpose_no_transpose" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @attr_transpose_no_transpose(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>) -> tensor<16x16xf32> { + %0 = "stablehlo.triangular_solve"(%arg0, %arg1) { + left_side = true, + lower = true, + unit_diagonal = true, + // transpose_a = #vhlo, + transpose_a = #stablehlo + } : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xf32> + func.return %0 : tensor<16x16xf32> +} + +// CHECK-LABEL: "attr_transpose_transpose" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @attr_transpose_transpose(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>) -> tensor<16x16xf32> { + %0 = "stablehlo.triangular_solve"(%arg0, %arg1) { + left_side = true, + lower = true, + unit_diagonal = true, + // transpose_a = #vhlo, + transpose_a = #stablehlo + } : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xf32> + func.return %0 : tensor<16x16xf32> +} + +// CHECK-LABEL: "attr_transpose_adjoint" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @attr_transpose_adjoint(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>) -> tensor<16x16xf32> { + %0 = "stablehlo.triangular_solve"(%arg0, %arg1) { + left_side = true, + lower = true, + unit_diagonal = true, + // transpose_a = #vhlo, + transpose_a = #stablehlo + } : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xf32> + func.return %0 : tensor<16x16xf32> +} + +// TypeExtensionsAttr aka #stablehlo.type_extensions is covered below. + +// CHECK-LABEL: "attr_type_extensions_bounds" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @attr_type_extensions_bounds(%arg0: tensor>) -> tensor> { + // CHECK: "vhlo.return_v1"(%[[ARG0]]) : (!vhlo.tensor_v1>) -> () + func.return %arg0 : tensor> +} + + +// ============ DEFAULTS ============ + +// CHECK-LABEL: "default_all_gather" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @default_all_gather(%arg0: tensor<16x8xf32>) -> tensor<16x16xf32> { + // CHECK: "vhlo.all_gather_v2"(%[[ARG0]]) <{ + // CHECK-SAME: all_gather_dim = #vhlo.integer_v1<1 : i64> + // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<2x1xi64>>, + // CHECK-SAME: use_global_device_ids = #vhlo.bool_v1 + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x16x!vhlo.f32_v1> + %0 = "stablehlo.all_gather"(%arg0) { + all_gather_dim = 1 : i64, + replica_groups = dense<[[0], [1]]> : tensor<2x1xi64> + } : (tensor<16x8xf32>) -> tensor<16x16xf32> + func.return %0 : tensor<16x16xf32> +} + +// CHECK-LABEL: "default_all_gather_variadic" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @default_all_gather_variadic(%arg0: tensor<16x8xf32>, %arg1: tensor<16x8xf32>) -> (tensor<16x16xf32>, tensor<16x16xf32>) { + %0:2 = "stablehlo.all_gather"(%arg0, %arg1) { + all_gather_dim = 1 : i64, + replica_groups = dense<[[0], [1]]> : tensor<2x1xi64> + } : (tensor<16x8xf32>, tensor<16x8xf32>) -> (tensor<16x16xf32>, tensor<16x16xf32>) + func.return %0#0, %0#1 : tensor<16x16xf32>, tensor<16x16xf32> +} + +// CHECK-LABEL: "default_all_reduce" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @default_all_reduce(%arg0: tensor) -> tensor { + // CHECK: "vhlo.all_reduce_v2"(%[[ARG0]]) + // CHECK-SAME: <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<2x1xi64>>, + // CHECK-SAME: use_global_device_ids = #vhlo.bool_v1 + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.add_v1"(%[[ARG1]], %[[ARG2]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + + %0 = "stablehlo.all_reduce"(%arg0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %1 = "stablehlo.add"(%arg1, %arg2) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + replica_groups = dense<[[0], [1]]> : tensor<2x1xi64> + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "default_all_to_all" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @default_all_to_all(%arg0: tensor<4x16xf32>) -> tensor<16x4xf32> { + // CHECK: "vhlo.all_to_all_v2"(%[[ARG0]]) <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: concat_dimension = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<1x4xi64>>, + // CHECK-SAME: split_count = #vhlo.integer_v1<4 : i64> + // CHECK-SAME: split_dimension = #vhlo.integer_v1<1 : i64> + // CHECK-SAME: }> : (!vhlo.tensor_v1<4x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x4x!vhlo.f32_v1> + %0 = "stablehlo.all_to_all"(%arg0) { + split_dimension = 1 : i64, + concat_dimension = 0 : i64, + split_count = 4 : i64, + replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64> + } : (tensor<4x16xf32>) -> tensor<16x4xf32> + func.return %0 : tensor<16x4xf32> +} + +// CHECK-LABEL: "default_all_to_all_variadic" +func.func @default_all_to_all_variadic(%arg0: tensor<4x16xf32>, %arg1: tensor<5x16xf32>) -> (tensor<16x4xf32>, tensor<20x4xf32>) { + %0:2 = "stablehlo.all_to_all"(%arg0, %arg1) { + split_dimension = 1 : i64, + concat_dimension = 0 : i64, + split_count = 4 : i64, + replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, + channel_handle = #stablehlo.channel_handle + } : (tensor<4x16xf32>, tensor<5x16xf32>) -> (tensor<16x4xf32>, tensor<20x4xf32>) + func.return %0#0, %0#1 : tensor<16x4xf32>, tensor<20x4xf32> +} + +// CHECK-LABEL: "default_cholesky" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @default_cholesky(%arg0: tensor<1x16x16xf32>) -> tensor<1x16x16xf32> { + // CHECK: "vhlo.cholesky_v1"(%[[ARG0]]) <{ + // CHECK-SAME: lower = #vhlo.bool_v1 + // CHECK-SAME: }> : (!vhlo.tensor_v1<1x16x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<1x16x16x!vhlo.f32_v1> + %0 = "stablehlo.cholesky"(%arg0) : (tensor<1x16x16xf32>) -> tensor<1x16x16xf32> + func.return %0 : tensor<1x16x16xf32> +} + +// CHECK-LABEL: "default_collective_permute" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @default_collective_permute(%arg0: tensor<16x8xf32>) -> tensor<16x8xf32> { + // CHECK: "vhlo.collective_permute_v1"(%[[ARG0]]) <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME{LITERAL}: source_target_pairs = #vhlo.tensor_v1 : tensor<3x2xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x8x!vhlo.f32_v1> + %0 = "stablehlo.collective_permute"(%arg0) { + source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64> + } : (tensor<16x8xf32>) -> tensor<16x8xf32> + func.return %0 : tensor<16x8xf32> +} + +// CHECK-LABEL: "default_collective_broadcast" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @default_collective_broadcast(%arg0: tensor<16x8xf32>) -> tensor<16x8xf32> { + // CHECK: "vhlo.collective_broadcast_v1"(%[[ARG0]]) <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<1x2xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x8x!vhlo.f32_v1> + %0 = "stablehlo.collective_broadcast"(%arg0) { + replica_groups = dense<[[0, 1]]> : tensor<1x2xi64> + } : (tensor<16x8xf32>) -> tensor<16x8xf32> + func.return %0 : tensor<16x8xf32> +} + +// CHECK-LABEL: "default_compare" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @default_compare(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.compare_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: compare_type = #vhlo, + // CHECK-SAME: comparison_direction = #vhlo + // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.compare"(%arg0, %arg1) { + comparison_direction = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "default_composite" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @default_composite(%arg0: tensor) -> tensor { + // CHECK: "vhlo.composite_v1"(%[[ARG0]]) <{ + // CHECK-SAME: composite_attributes = #vhlo.dict_v1<{}> + // CHECK-SAME: decomposition = #vhlo.string_v1<"composite_target"> + // CHECK-SAME: name = #vhlo.string_v1<"stablehlo.composite_target"> + // CHECK-SAME: version = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: }> : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.composite"(%arg0) { + name = "stablehlo.composite_target", + decomposition = @composite_target + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "default_convolution" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @default_convolution(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x6x6x16xf32> { + // CHECK: "vhlo.convolution_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: batch_group_count = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: feature_group_count = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: input_batch_dimension = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: input_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: input_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: kernel_input_feature_dimension = #vhlo.integer_v1<2 : i64>, + // CHECK-SAME: kernel_output_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: kernel_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: lhs_dilation = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: output_batch_dimension = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: output_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: output_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: padding = #vhlo.tensor_v1 : tensor<2x2xi64>>, + // CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]>, + // CHECK-SAME: rhs_dilation = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: window_reversal = #vhlo.tensor_v1 : tensor<2xi1>>, + // CHECK-SAME: window_strides = #vhlo.tensor_v1 : tensor<2xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<1x8x8x207x!vhlo.f32_v1>, !vhlo.tensor_v1<3x3x207x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<1x6x6x16x!vhlo.f32_v1> + %0 = "stablehlo.convolution"(%arg0, %arg1) { + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, + feature_group_count = 1 : i64, + batch_group_count = 1 : i64 + } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x6x6x16xf32> + func.return %0 : tensor<1x6x6x16xf32> +} + +// CHECK-LABEL: "default_custom_call" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @default_custom_call(%arg0: tensor) -> tensor { + // CHECK: "vhlo.custom_call_v1"(%[[ARG0]]) <{ + // CHECK-SAME: api_version = #vhlo, + // CHECK-SAME: backend_config = #vhlo.string_v1<"">, + // CHECK-SAME: call_target_name = #vhlo.string_v1<"foo">, + // CHECK-SAME: called_computations = #vhlo.array_v1<[]>, + // CHECK-SAME: has_side_effect = #vhlo.bool_v1, + // CHECK-SAME: operand_layouts = #vhlo.array_v1<[]>, + // CHECK-SAME: output_operand_aliases = #vhlo.array_v1<[]> + // CHECK-SAME: result_layouts = #vhlo.array_v1<[]> + // CHECK-SAME: }> : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.custom_call"(%arg0) { + call_target_name = "foo" + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "default_dot_general" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @default_dot_general(%arg0: tensor<8x8x16xf32>, %arg1: tensor<8x16x8xf32>) -> tensor<8x8x8xf32> { + // CHECK: "vhlo.dot_general_v2"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: accumulation_type = #vhlo.type_v1, + // CHECK-SAME: allow_imprecise_accumulation = #vhlo.type_v1, + // CHECK-SAME: lhs_batching_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: lhs_component_count = #vhlo.type_v1, + // CHECK-SAME: lhs_contracting_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: lhs_precision_type = #vhlo.type_v1, + // CHECK-SAME: num_primitive_operations = #vhlo.type_v1, + // CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]>, + // CHECK-SAME: rhs_batching_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: rhs_component_count = #vhlo.type_v1, + // CHECK-SAME: rhs_contracting_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: rhs_precision_type = #vhlo.type_v1 + // CHECK-SAME: }> : (!vhlo.tensor_v1<8x8x16x!vhlo.f32_v1>, !vhlo.tensor_v1<8x16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<8x8x8x!vhlo.f32_v1> + %0 = "stablehlo.dot_general"(%arg0, %arg1) { + dot_dimension_numbers = #stablehlo.dot< + lhs_batching_dimensions = [0], + lhs_contracting_dimensions = [2], + rhs_batching_dimensions = [0], + rhs_contracting_dimensions = [1] + > + } : (tensor<8x8x16xf32>, tensor<8x16x8xf32>) -> tensor<8x8x8xf32> + func.return %0 : tensor<8x8x8xf32> +} + +// CHECK-LABEL: "dot_general_algorithm" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @dot_general_algorithm(%arg0: tensor<8x8x16xf32>, %arg1: tensor<8x16x8xf32>) -> tensor<8x8x8xf32> { +// CHECK: "vhlo.dot_general_v2"(%[[ARG0]], %[[ARG1]]) <{ +// CHECK-SAME: accumulation_type = #vhlo.type_v1, +// CHECK-SAME: allow_imprecise_accumulation = #vhlo.bool_v1, +// CHECK-SAME: lhs_batching_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, +// CHECK-SAME: lhs_component_count = #vhlo.integer_v1<1 : i64>, +// CHECK-SAME: lhs_contracting_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, +// CHECK-SAME: lhs_precision_type = #vhlo.type_v1, +// CHECK-SAME: num_primitive_operations = #vhlo.integer_v1<1 : i64>, +// CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]>, +// CHECK-SAME: rhs_batching_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, +// CHECK-SAME: rhs_component_count = #vhlo.integer_v1<1 : i64>, +// CHECK-SAME: rhs_contracting_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, +// CHECK-SAME: rhs_precision_type = #vhlo.type_v1 +// CHECK-SAME: }> : (!vhlo.tensor_v1<8x8x16x!vhlo.f32_v1>, !vhlo.tensor_v1<8x16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<8x8x8x!vhlo.f32_v1> + %0 = "stablehlo.dot_general"(%arg0, %arg1) { + dot_dimension_numbers = #stablehlo.dot< + lhs_batching_dimensions = [0], + lhs_contracting_dimensions = [2], + rhs_batching_dimensions = [0], + rhs_contracting_dimensions = [1] + >, + algorithm = #stablehlo.dot_algorithm< + lhs_precision_type = tf32, + rhs_precision_type = tf32, + accumulation_type = f32, + lhs_component_count = 1, + rhs_component_count = 1, + num_primitive_operations = 1, + allow_imprecise_accumulation = false + > + } : (tensor<8x8x16xf32>, tensor<8x16x8xf32>) -> tensor<8x8x8xf32> + func.return %0 : tensor<8x8x8xf32> +} + +// CHECK-LABEL: "default_dynamic_broadcast_in_dim" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @default_dynamic_broadcast_in_dim(%arg0: tensor, %arg1: tensor<2xindex>) -> tensor { + // CHECK: "vhlo.dynamic_broadcast_in_dim_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: broadcast_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: known_expanding_dimensions = #vhlo.tensor_v1 : tensor<0xi64>>, + // CHECK-SAME: known_nonexpanding_dimensions = #vhlo.tensor_v1 : tensor<0xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1<2x!vhlo.index_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %arg1) { + broadcast_dimensions = array + } : (tensor, tensor<2xindex>) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "default_dynamic_conv" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) +func.func @default_dynamic_conv(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>, %arg2: tensor<2x2xi64>) -> tensor<1x?x?x16xf32> { + // CHECK: "vhlo.dynamic_conv_v2"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) <{ + // CHECK-SAME: batch_group_count = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: feature_group_count = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: input_batch_dimension = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: input_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: input_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: kernel_input_feature_dimension = #vhlo.integer_v1<2 : i64>, + // CHECK-SAME: kernel_output_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: kernel_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: lhs_dilation = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: output_batch_dimension = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: output_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: output_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]>, + // CHECK-SAME: rhs_dilation = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: window_reversal = #vhlo.tensor_v1 : tensor<2xi1>>, + // CHECK-SAME: window_strides = #vhlo.tensor_v1 : tensor<2xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<1x8x8x207x!vhlo.f32_v1>, !vhlo.tensor_v1<3x3x207x16x!vhlo.f32_v1>, !vhlo.tensor_v1<2x2x!vhlo.i64_v1>) -> !vhlo.tensor_v1<1x?x?x16x!vhlo.f32_v1> + %0 = "stablehlo.dynamic_conv"(%arg0, %arg1, %arg2) { + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, + feature_group_count = 1 : i64, + batch_group_count = 1 : i64 + } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>, tensor<2x2xi64>) -> tensor<1x?x?x16xf32> + func.return %0 : tensor<1x?x?x16xf32> +} + +// CHECK-LABEL: "default_dynamic_gather" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) +func.func @default_dynamic_gather(%arg0 : tensor<2x4x9xf32>, %arg1 : tensor<1x5x2xi32>, %arg2 : tensor<3xi32>) -> tensor<1x5x8xf32> { + // CHECK: "vhlo.dynamic_gather_v2"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) <{ + // CHECK-SAME: collapsed_slice_dims = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: index_vector_dim = #vhlo.integer_v1<2 : i64>, + // CHECK-SAME: indices_are_sorted = #vhlo.bool_v1, + // CHECK-SAME: offset_dims = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: operand_batching_dims = #vhlo.tensor_v1 : tensor<0xi64>>, + // CHECK-SAME: start_index_map = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: start_indices_batching_dims = #vhlo.tensor_v1 : tensor<0xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<2x4x9x!vhlo.f32_v1>, !vhlo.tensor_v1<1x5x2x!vhlo.i32_v1>, !vhlo.tensor_v1<3x!vhlo.i32_v1>) -> !vhlo.tensor_v1<1x5x8x!vhlo.f32_v1> + %0 = "stablehlo.dynamic_gather"(%arg0, %arg1, %arg2) { + dimension_numbers = #stablehlo.gather< + offset_dims = [2], + collapsed_slice_dims = [0, 1], + start_index_map = [0, 1], + index_vector_dim = 2 + > + } : (tensor<2x4x9xf32>, tensor<1x5x2xi32>, tensor<3xi32>) -> tensor<1x5x8xf32> + func.return %0 : tensor<1x5x8xf32> +} + +func.func @default_func(%arg0: tensor) -> tensor { + // CHECK: "vhlo.func_v1"() <{ + // CHECK-SAME: arg_attrs = #vhlo.array_v1<[]>, + // CHECK-SAME: function_type = #vhlo.type_v1) -> !vhlo.tensor_v1>>, + // CHECK-SAME: res_attrs = #vhlo.array_v1<[]>, + // CHECK-SAME: sym_name = #vhlo.string_v1<"default_func">, + // CHECK-SAME: sym_visibility = #vhlo.string_v1<""> + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG0:.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: "vhlo.return_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : () -> () + func.return %arg0 : tensor +} + +// CHECK-LABEL: "default_gather" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @default_gather(%arg0 : tensor<2x4x9xf32>, %arg1 : tensor<1x5x2xi32>) -> tensor<1x5x1xf32> { + // CHECK: "vhlo.gather_v2"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: collapsed_slice_dims = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: index_vector_dim = #vhlo.integer_v1<2 : i64>, + // CHECK-SAME: indices_are_sorted = #vhlo.bool_v1, + // CHECK-SAME: offset_dims = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: operand_batching_dims = #vhlo.tensor_v1 : tensor<0xi64>>, + // CHECK-SAME: slice_sizes = #vhlo.tensor_v1 : tensor<3xi64>>, + // CHECK-SAME: start_index_map = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: start_indices_batching_dims = #vhlo.tensor_v1 : tensor<0xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<2x4x9x!vhlo.f32_v1>, !vhlo.tensor_v1<1x5x2x!vhlo.i32_v1>) -> !vhlo.tensor_v1<1x5x1x!vhlo.f32_v1> + %0 = "stablehlo.gather"(%arg0, %arg1) { + dimension_numbers = #stablehlo.gather< + offset_dims = [2], + collapsed_slice_dims = [0, 1], + start_index_map = [0, 1], + index_vector_dim = 2 + >, + slice_sizes = array + } : (tensor<2x4x9xf32>, tensor<1x5x2xi32>) -> tensor<1x5x1xf32> + func.return %0 : tensor<1x5x1xf32> +} + +// CHECK-LABEL: "default_infeed" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @default_infeed(%arg0: !stablehlo.token) -> (tensor, !stablehlo.token) { + // CHECK: "vhlo.infeed_v1"(%[[ARG0]]) <{ + // CHECK-SAME: infeed_config = #vhlo.string_v1<"">, + // CHECK-SAME{LITERAL}: layout = #vhlo.array_v1<[]> + // CHECK-SAME: }> : (!vhlo.token_v1) -> (!vhlo.tensor_v1, !vhlo.token_v1) + %0:2 = "stablehlo.infeed"(%arg0) : (!stablehlo.token) -> (tensor, !stablehlo.token) + func.return %0#0, %0#1 : tensor, !stablehlo.token +} + +// CHECK-LABEL: "default_outfeed" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @default_outfeed(%arg0: tensor, %arg1: !stablehlo.token) -> !stablehlo.token { + // CHECK: "vhlo.outfeed_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: outfeed_config = #vhlo.string_v1<""> + // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.token_v1) -> !vhlo.token_v1 + %0 = "stablehlo.outfeed"(%arg0, %arg1) : (tensor, !stablehlo.token) -> !stablehlo.token + func.return %0 : !stablehlo.token +} + +// CHECK-LABEL: "default_recv" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @default_recv(%arg0: !stablehlo.token) -> (tensor, !stablehlo.token) { + // CHECK: "vhlo.recv_v1"(%[[ARG0]]) <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: channel_type = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: is_host_transfer = #vhlo.bool_v1 + // CHECK-SAME: }> : (!vhlo.token_v1) -> (!vhlo.tensor_v1, !vhlo.token_v1) + %0:2 = "stablehlo.recv"(%arg0) { + channel_handle = #stablehlo.channel_handle + } : (!stablehlo.token) -> (tensor, !stablehlo.token) + func.return %0#0, %0#1 : tensor, !stablehlo.token +} + +// CHECK-LABEL: "default_send" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @default_send(%arg0: tensor, %arg1: !stablehlo.token) -> !stablehlo.token { + // CHECK: "vhlo.send_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: channel_type = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: is_host_transfer = #vhlo.bool_v1 + // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.token_v1) -> !vhlo.token_v1 + %0 = "stablehlo.send"(%arg0, %arg1) { + channel_handle = #stablehlo.channel_handle + } : (tensor, !stablehlo.token) -> !stablehlo.token + func.return %0 : !stablehlo.token +} + +// CHECK-LABEL: "default_reduce_scatter" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @default_reduce_scatter(%arg0: tensor<16xf32>) -> tensor<16xf32> { + // CHECK: "vhlo.reduce_scatter_v1"(%[[ARG0]]) <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<2x1xi64>>, + // CHECK-SAME: scatter_dimension = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: use_global_device_ids = #vhlo.bool_v1 + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.add_v1"(%[[ARG1]], %[[ARG2]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> + %0 = "stablehlo.reduce_scatter"(%arg0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %1 = "stablehlo.add"(%arg1, %arg2) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + scatter_dimension = 0 : i64, + replica_groups = dense<[[0], [1]]> : tensor<2x1xi64> + } : (tensor<16xf32>) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} + +// CHECK-LABEL: "default_reduce_window" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @default_reduce_window(%arg0: tensor<2x17x31x7xf32>, %arg1: tensor) -> tensor<2x16x30x7xf32> { + // CHECK: "vhlo.reduce_window_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: base_dilations = #vhlo.tensor_v1 : tensor<4xi64>>, + // CHECK-SAME{LITERAL}: padding = #vhlo.tensor_v1 : tensor<4x2xi64>>, + // CHECK-SAME: window_dilations = #vhlo.tensor_v1 : tensor<4xi64>>, + // CHECK-SAME: window_dimensions = #vhlo.tensor_v1 : tensor<4xi64>>, + // CHECK-SAME: window_strides = #vhlo.tensor_v1 : tensor<4xi64>> + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG2:arg.*]]: !vhlo.tensor_v1, %[[ARG3:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.maximum_v1"(%[[ARG2]], %[[ARG3]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1<2x17x31x7x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<2x16x30x7x!vhlo.f32_v1> + %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({ + ^bb0(%arg2: tensor, %arg3: tensor): + %1 = "stablehlo.maximum"(%arg2, %arg3) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + window_dimensions = array + } : (tensor<2x17x31x7xf32>, tensor) -> tensor<2x16x30x7xf32> + func.return %0 : tensor<2x16x30x7xf32> +} + +// CHECK-LABEL: "default_scatter" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) +func.func @default_scatter(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>, %arg2: tensor<10x300xf32>) -> tensor<200x100x300xf32> { + // CHECK: "vhlo.scatter_v2"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) <{ + // CHECK-SAME: index_vector_dim = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: indices_are_sorted = #vhlo.bool_v1, + // CHECK-SAME: input_batching_dims = #vhlo.tensor_v1 : tensor<0xi64>>, + // CHECK-SAME: inserted_window_dims = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: scatter_dims_to_operand_dims = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: scatter_indices_batching_dims = #vhlo.tensor_v1 : tensor<0xi64>>, + // CHECK-SAME: unique_indices = #vhlo.bool_v1, + // CHECK-SAME: update_window_dims = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG3:arg.*]]: !vhlo.tensor_v1, %[[ARG4:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.add_v1"(%[[ARG3]], %[[ARG4]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1<200x100x300x!vhlo.f32_v1>, !vhlo.tensor_v1<10x2x!vhlo.i32_v1>, !vhlo.tensor_v1<10x300x!vhlo.f32_v1>) -> !vhlo.tensor_v1<200x100x300x!vhlo.f32_v1> + %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = "stablehlo.add"(%arg3, %arg4) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + scatter_dimension_numbers = #stablehlo.scatter< + update_window_dims = [1], + inserted_window_dims = [0, 1], + scatter_dims_to_operand_dims = [0, 1], + index_vector_dim = 1 + > + } : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<10x300xf32>) -> tensor<200x100x300xf32> + func.return %0 : tensor<200x100x300xf32> +} + +// CHECK-LABEL: "default_select_and_scatter" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) +func.func @default_select_and_scatter(%arg0: tensor<10x24x24x64xf32>, %arg1: tensor<10x23x23x64xf32>, %arg2: tensor) -> tensor<10x24x24x64xf32> { + // CHECK: "vhlo.select_and_scatter_v1"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) <{ + // CHECK-SAME: padding = #vhlo.tensor_v1 : tensor<4x2xi64>>, + // CHECK-SAME: window_dimensions = #vhlo.tensor_v1 : tensor<4xi64>>, + // CHECK-SAME: window_strides = #vhlo.tensor_v1 : tensor<4xi64>> + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG31:arg.*]]: !vhlo.tensor_v1, %[[ARG41:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL11:.*]] = "vhlo.compare_v1"(%[[ARG31]], %[[ARG41]]) <{compare_type = #vhlo, comparison_direction = #vhlo}> + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL11]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }, { + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG32:arg.*]]: !vhlo.tensor_v1, %[[ARG42:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL12:.*]] = "vhlo.add_v1"(%[[ARG32]], %[[ARG42]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL12]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1<10x24x24x64x!vhlo.f32_v1>, !vhlo.tensor_v1<10x23x23x64x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<10x24x24x64x!vhlo.f32_v1> + %0 = "stablehlo.select_and_scatter"(%arg0, %arg1, %arg2) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = "stablehlo.compare"(%arg3, %arg4) {compare_type = #stablehlo, comparison_direction = #stablehlo} : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }, { + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = "stablehlo.add"(%arg3, %arg4) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + window_dimensions = array + } : (tensor<10x24x24x64xf32>, tensor<10x23x23x64xf32>, tensor) -> tensor<10x24x24x64xf32> + func.return %0 : tensor<10x24x24x64xf32> +} + +// CHECK-LABEL: "default_sort" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @default_sort(%arg0: tensor<16xf32>) -> tensor<16xf32> { + // CHECK: "vhlo.sort_v1"(%[[ARG0]]) <{ + // CHECK-SAME: dimension = #vhlo.integer_v1<-1 : i64> + // CHECK-SAME: is_stable = #vhlo.bool_v1 + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.compare_v1"(%[[ARG1]], %[[ARG2]]) <{compare_type = #vhlo, comparison_direction = #vhlo}> + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> + %0 = "stablehlo.sort"(%arg0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %1 = "stablehlo.compare"(%arg1, %arg2) {compare_type = #stablehlo, comparison_direction = #stablehlo} : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) : (tensor<16xf32>) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} + +// ============ OPS ============ + +// CHECK-LABEL: "op_abs" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_abs(%arg0: tensor) -> tensor { + // CHECK: "vhlo.abs_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.abs"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_add" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_add(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_after_all" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_after_all(%arg0: !stablehlo.token) -> !stablehlo.token { + // CHECK: "vhlo.after_all_v1"(%[[ARG0]]) : (!vhlo.token_v1) -> !vhlo.token_v1 + %0 = "stablehlo.after_all"(%arg0) : (!stablehlo.token) -> !stablehlo.token + func.return %0 : !stablehlo.token +} + +// CHECK-LABEL: "op_all_gather" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_all_gather(%arg0: tensor<16x8xf32>) -> tensor<16x16xf32> { + // CHECK: "vhlo.all_gather_v2"(%[[ARG0]]) <{ + // CHECK-SAME: all_gather_dim = #vhlo.integer_v1<1 : i64> + // CHECK-SAME: channel_id = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<2x1xi64>>, + // CHECK-SAME: use_global_device_ids = #vhlo.bool_v1 + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x16x!vhlo.f32_v1> + %0 = "stablehlo.all_gather"(%arg0) { + all_gather_dim = 1 : i64, + replica_groups = dense<[[0], [1]]> : tensor<2x1xi64>, + channel_handle = #stablehlo.channel_handle, + use_global_device_ids + } : (tensor<16x8xf32>) -> tensor<16x16xf32> + func.return %0 : tensor<16x16xf32> +} + +// CHECK-LABEL: "op_all_reduce" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_all_reduce(%arg0: tensor) -> tensor { + // CHECK: "vhlo.all_reduce_v2"(%[[ARG0]]) <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<2x1xi64>>, + // CHECK-SAME: use_global_device_ids = #vhlo.bool_v1 + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.add_v1"(%[[ARG1]], %[[ARG2]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.all_reduce"(%arg0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %1 = "stablehlo.add"(%arg1, %arg2) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + replica_groups = dense<[[0], [1]]> : tensor<2x1xi64>, + channel_handle = #stablehlo.channel_handle, + use_global_device_ids + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_all_reduce_with_promotable_types" +func.func @op_all_reduce_with_promotable_types(%operand: tensor) -> tensor { + // CHECK: "vhlo.all_reduce_v2"(%[[ARG0:.*]]) + // CHECK: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): + // CHECK: "vhlo.return_v1"(%[[VAL1:.*]]) : (!vhlo.tensor_v1) -> () + // CHECK: }) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %result = "stablehlo.all_reduce"(%operand) ({ + ^bb0(%arg0: tensor, %arg1: tensor): + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + "stablehlo.return"(%0) : (tensor) -> () + }) { + replica_groups = dense<[[0, 1]]> : tensor<1x2xi64>, + channel_handle = #stablehlo.channel_handle, + use_global_device_ids + } : (tensor) -> tensor + + func.return %result : tensor +} + +// CHECK-LABEL: "default_all_reduce_variadic" +func.func @default_all_reduce_variadic(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { + %0:2 = "stablehlo.all_reduce"(%arg0, %arg1) ({ + ^bb0(%arg2: tensor, %arg3: tensor): + %1 = "stablehlo.add"(%arg2, %arg3) : (tensor, tensor) -> (tensor) + "stablehlo.return"(%1) : (tensor) -> () + }) { + replica_groups = dense<[[0], [1]]> : tensor<2x1xi64> + } : (tensor, tensor) -> (tensor, tensor) + func.return %0#0, %0#1 : tensor, tensor +} + +// CHECK-LABEL: "op_all_to_all" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_all_to_all(%arg0: tensor<4x16xf32>) -> tensor<16x4xf32> { + // CHECK: "vhlo.all_to_all_v2"(%[[ARG0]]) <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: concat_dimension = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<1x4xi64>>, + // CHECK-SAME: split_count = #vhlo.integer_v1<4 : i64> + // CHECK-SAME: split_dimension = #vhlo.integer_v1<1 : i64> + // CHECK-SAME: }> : (!vhlo.tensor_v1<4x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x4x!vhlo.f32_v1> + %0 = "stablehlo.all_to_all"(%arg0) { + split_dimension = 1 : i64, + concat_dimension = 0 : i64, + split_count = 4 : i64, + replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, + channel_handle = #stablehlo.channel_handle + } : (tensor<4x16xf32>) -> tensor<16x4xf32> + func.return %0 : tensor<16x4xf32> +} + +// CHECK-LABEL: "op_and" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_and(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.and_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.and"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_atan2" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_atan2(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.atan2_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.atan2"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_batch_norm_grad" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}, %[[ARG3:.*]]: {{.*}}, %[[ARG4:.*]]: {{.*}}) +func.func @op_batch_norm_grad(%arg0: tensor<16x16x16x16xf32>, %arg1: tensor<16xf32>, %arg2: tensor<16xf32>, %arg3: tensor<16xf32>, %arg4: tensor<16x16x16x16xf32>) -> (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>) { + // CHECK: "vhlo.batch_norm_grad_v1"(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]]) <{ + // CHECK-SAME: epsilon = #vhlo.float_v1<1.000000e-03 : !vhlo.f32_v1>, + // CHECK-SAME: feature_index = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x16x16x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x16x16x16x!vhlo.f32_v1>) -> (!vhlo.tensor_v1<16x16x16x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>) + %0:3 = "stablehlo.batch_norm_grad"(%arg0, %arg1, %arg2, %arg3, %arg4) { + epsilon = 0.001 : f32, + feature_index = 0 : i64 + } : (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16x16x16x16xf32>) -> (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>) + func.return %0#0, %0#1, %0#2 : tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32> +} + +// CHECK-LABEL: "op_batch_norm_inference" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}, %[[ARG3:.*]]: {{.*}}, %[[ARG4:.*]]: {{.*}}) +func.func @op_batch_norm_inference(%arg0: tensor<16x16x16x16xf32>, %arg1: tensor<16xf32>, %arg2: tensor<16xf32>, %arg3: tensor<16xf32>, %arg4: tensor<16xf32>) -> tensor<16x16x16x16xf32> { + // CHECK: "vhlo.batch_norm_inference_v1"(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]]) <{ + // CHECK-SAME: epsilon = #vhlo.float_v1<1.000000e-03 : !vhlo.f32_v1>, + // CHECK-SAME: feature_index = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x16x16x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x16x16x16x!vhlo.f32_v1> + %0 = "stablehlo.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4) { + epsilon = 0.001 : f32, + feature_index = 0 : i64 + } : (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>, tensor<16xf32>) -> tensor<16x16x16x16xf32> + func.return %0 : tensor<16x16x16x16xf32> +} + +// CHECK-LABEL: "op_batch_norm_training" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) +func.func @op_batch_norm_training(%arg0: tensor<16x16x16x16xf32>, %arg1: tensor<16xf32>, %arg2: tensor<16xf32>) -> (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>) { + // CHECK: "vhlo.batch_norm_training_v1"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) <{ + // CHECK-SAME: epsilon = #vhlo.float_v1<1.000000e-03 : !vhlo.f32_v1>, + // CHECK-SAME: feature_index = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x16x16x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>) -> (!vhlo.tensor_v1<16x16x16x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x!vhlo.f32_v1>) + %0:3 = "stablehlo.batch_norm_training"(%arg0, %arg1, %arg2) { + epsilon = 0.001 : f32, + feature_index = 0 : i64 + } : (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>) -> (tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32>) + func.return %0#0, %0#1, %0#2 : tensor<16x16x16x16xf32>, tensor<16xf32>, tensor<16xf32> +} + +// CHECK-LABEL: "op_bitcast_convert" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_bitcast_convert(%arg0: tensor) -> tensor { + // CHECK: "vhlo.bitcast_convert_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.bitcast_convert"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_broadcast_in_dim" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_broadcast_in_dim(%arg0: tensor<16xf32>) -> tensor<16x16xf32> { + // CHECK: "vhlo.broadcast_in_dim_v1"(%[[ARG0]]) <{ + // CHECK-SAME: broadcast_dimensions = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x16x!vhlo.f32_v1> + %0 = "stablehlo.broadcast_in_dim"(%arg0) { + broadcast_dimensions = array + } : (tensor<16xf32>) -> tensor<16x16xf32> + func.return %0 : tensor<16x16xf32> +} + +// CHECK-LABEL: "op_broadcast" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_broadcast(%arg0: tensor<16xf32>) -> tensor<16x16xf32> { + // CHECK: "vhlo.broadcast_v1"(%[[ARG0]]) <{ + // CHECK-SAME: broadcast_sizes = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x16x!vhlo.f32_v1> + %0 = "stablehlo.broadcast"(%arg0) { + broadcast_sizes = array + } : (tensor<16xf32>) -> tensor<16x16xf32> + func.return %0 : tensor<16x16xf32> +} + +// CHECK-LABEL: "op_case" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_case(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.case_v1"(%[[ARG0]]) ({ + // CHECK-NEXT: "vhlo.return_v1"(%[[ARG1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.case"(%arg0) ({ + "stablehlo.return"(%arg1) : (tensor) -> () + }) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_cbrt" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_cbrt(%arg0: tensor) -> tensor { + // CHECK: "vhlo.cbrt_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.cbrt"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_ceil" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_ceil(%arg0: tensor) -> tensor { + // CHECK: "vhlo.ceil_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.ceil"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_cholesky" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_cholesky(%arg0: tensor<1x16x16xf32>) -> tensor<1x16x16xf32> { + // CHECK: "vhlo.cholesky_v1"(%[[ARG0]]) <{ + // CHECK-SAME: lower = #vhlo.bool_v1 + // CHECK-SAME: }> : (!vhlo.tensor_v1<1x16x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<1x16x16x!vhlo.f32_v1> + %0 = "stablehlo.cholesky"(%arg0) { + lower = true + } : (tensor<1x16x16xf32>) -> tensor<1x16x16xf32> + func.return %0 : tensor<1x16x16xf32> +} + +// CHECK-LABEL: "op_clamp" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) +func.func @op_clamp(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // CHECK: "vhlo.clamp_v1"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.clamp"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_count_leading_zeros" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_count_leading_zeros(%arg0: tensor) -> tensor { + // CHECK: "vhlo.count_leading_zeros_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.count_leading_zeros"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_collective_permute" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_collective_permute(%arg0: tensor<16x8xf32>) -> tensor<16x8xf32> { + // CHECK: "vhlo.collective_permute_v1"(%[[ARG0]]) <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME{LITERAL}: source_target_pairs = #vhlo.tensor_v1 : tensor<3x2xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x8x!vhlo.f32_v1> + %0 = "stablehlo.collective_permute"(%arg0) { + source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>, + channel_handle = #stablehlo.channel_handle + } : (tensor<16x8xf32>) -> tensor<16x8xf32> + func.return %0 : tensor<16x8xf32> +} + +// CHECK-LABEL: "op_compare" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_compare(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.compare_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: compare_type = #vhlo, + // CHECK-SAME: comparison_direction = #vhlo + // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.compare"(%arg0, %arg1) { + comparison_direction = #stablehlo, + compare_type = #stablehlo + } : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_complex" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_complex(%arg0: tensor, %arg1: tensor) -> tensor> { + // CHECK: "vhlo.complex_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1> + %0 = "stablehlo.complex"(%arg0, %arg1) : (tensor, tensor) -> tensor> + func.return %0 : tensor> +} + +// CHECK-LABEL: "op_composite" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_composite(%arg0: tensor) -> tensor { + // CHECK: "vhlo.composite_v1"(%[[ARG0]]) <{ + // CHECK-SAME: composite_attributes = #vhlo.dict_v1<{#vhlo.string_v1<"my_int"> = #vhlo.integer_v1<1 : i64>, #vhlo.string_v1<"my_string"> = #vhlo.string_v1<"foo">}> + // CHECK-SAME: decomposition = #vhlo.string_v1<"composite_target"> + // CHECK-SAME: name = #vhlo.string_v1<"stablehlo.composite_target"> + // CHECK-SAME: version = #vhlo.integer_v1<1 : i32> + // CHECK-SAME: }> : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.composite"(%arg0) { + name = "stablehlo.composite_target", + decomposition = @composite_target, + version = 1 : i32, + composite_attributes = { + my_string = "foo", + my_int = 1 : i64 + } + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_concatenate" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_concatenate(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> tensor<16xf32> { + // CHECK: "vhlo.concatenate_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: dimension = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: }> : (!vhlo.tensor_v1<8x!vhlo.f32_v1>, !vhlo.tensor_v1<8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> + %0 = "stablehlo.concatenate"(%arg0, %arg1) { + dimension = 0 : i64 + } : (tensor<8xf32>, tensor<8xf32>) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} + +// CHECK-LABEL: "op_constant" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_constant(%arg0: tensor) -> tensor { + // CHECK: "vhlo.constant_v1"() <{ + // CHECK-SAME: value = #vhlo.tensor_v1 : tensor> + // CHECK-SAME: }> : () -> !vhlo.tensor_v1 + %0 = "stablehlo.constant"() { + value = dense<0.0> : tensor + } : () -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_convert" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_convert(%arg0: tensor) -> tensor { + // CHECK: "vhlo.convert_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.convert"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_convolution" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_convolution(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x7x7x16xf32> { + // CHECK: "vhlo.convolution_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: batch_group_count = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: feature_group_count = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: input_batch_dimension = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: input_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: input_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: kernel_input_feature_dimension = #vhlo.integer_v1<2 : i64>, + // CHECK-SAME: kernel_output_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: kernel_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: lhs_dilation = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: output_batch_dimension = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: output_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: output_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: padding = #vhlo.tensor_v1 : tensor<2x2xi64>>, + // CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]>, + // CHECK-SAME: rhs_dilation = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: window_reversal = #vhlo.tensor_v1 : tensor<2xi1>>, + // CHECK-SAME: window_strides = #vhlo.tensor_v1 : tensor<2xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<1x8x8x207x!vhlo.f32_v1>, !vhlo.tensor_v1<3x3x207x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<1x7x7x16x!vhlo.f32_v1> + %0 = "stablehlo.convolution"(%arg0, %arg1) { + window_strides = array, + padding = dense<1> : tensor<2x2xi64>, + lhs_dilation = array, + rhs_dilation = array, + window_reversal = array, + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, + feature_group_count = 1 : i64, + batch_group_count = 1 : i64, + precision_config = [#stablehlo, #stablehlo] + } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x7x7x16xf32> + func.return %0 : tensor<1x7x7x16xf32> +} + +// CHECK-LABEL: "op_cosine" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_cosine(%arg0: tensor) -> tensor { + // CHECK: "vhlo.cosine_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.cosine"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_create_token" +func.func @op_create_token() -> !stablehlo.token { + // CHECK: "vhlo.create_token_v1"() : () -> !vhlo.token_v1 + %0 = "stablehlo.create_token"() : () -> !stablehlo.token + func.return %0 : !stablehlo.token +} + +// CHECK-LABEL: "op_cross_replica_sum" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_cross_replica_sum(%arg0: tensor) -> tensor { + // CHECK: "vhlo.cross-replica-sum_v1"(%[[ARG0]]) <{ + // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<2x1xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.cross-replica-sum"(%arg0) { + replica_groups = dense<[[0], [1]]> : tensor<2x1xi64> + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_custom_call" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_custom_call(%arg0: tensor) -> tensor { + // CHECK: "vhlo.custom_call_v1"(%[[ARG0]]) <{ + // CHECK-SAME: api_version = #vhlo, + // CHECK-SAME: backend_config = #vhlo.string_v1<"\08\03\1A\02">, + // CHECK-SAME: call_target_name = #vhlo.string_v1<"foo">, + // CHECK-SAME: called_computations = #vhlo.array_v1<[#vhlo.string_v1<"foo">]>, + // CHECK-SAME: has_side_effect = #vhlo.bool_v1, + // CHECK-SAME: operand_layouts = #vhlo.array_v1<[#vhlo.tensor_v1 : tensor<0xindex>>]>, + // CHECK-SAME: output_operand_aliases = #vhlo.array_v1<[ + // CHECK-SAME: #vhlo.output_operand_alias_v1< + // CHECK-SAME: outputTupleIndices = [], + // CHECK-SAME: operandIndex = 0, + // CHECK-SAME: operandTupleIndices = []>]> + // CHECK-SAME: result_layouts = #vhlo.array_v1<[#vhlo.tensor_v1 : tensor<0xindex>>]> + // CHECK-SAME: }> : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.custom_call"(%arg0) { + call_target_name = "foo", + has_side_effect = true, + backend_config = "\08\03\1A\02", + api_version = 2 : i32, + called_computations = [@foo], + operand_layouts = [dense<> : tensor<0xindex>], + output_operand_aliases = [ + #stablehlo.output_operand_alias], + result_layouts = [dense<> : tensor<0xindex>] + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_custom_call_empty_result_layout" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func public @op_custom_call_empty_result_layout(%arg0: tensor) -> tensor { + // %0 = "vhlo.custom_call_v1"(%arg0) <{>}> : (!vhlo.tensor_v1) -> !vhlo.tuple_v1<> + // CHECK: "vhlo.custom_call_v1"(%[[ARG0]]) <{ + // CHECK-SAME: api_version = #vhlo, + // CHECK-SAME: backend_config = #vhlo.string_v1<"">, + // CHECK-SAME: call_target_name = #vhlo.string_v1<"empty_output">, + // CHECK-SAME: called_computations = #vhlo.array_v1<[]>, + // CHECK-SAME: has_side_effect = #vhlo.bool_v1, + // CHECK-SAME: operand_layouts = #vhlo.array_v1<[#vhlo.tensor_v1 : tensor<0xindex>>]>, + // CHECK-SAME: output_operand_aliases = #vhlo.array_v1<[]>, + // CHECK-SAME: result_layouts = #vhlo.array_v1<[]> + // CHECK-SAME: }> : (!vhlo.tensor_v1) -> !vhlo.tuple_v1<> + %0 = "stablehlo.custom_call"(%arg0) <{ + api_version = 2 : i32, + call_target_name = "empty_output", + has_side_effect = true, + operand_layouts = [dense<> : tensor<0xindex>], + result_layouts = [] + }> : (tensor) -> tuple<> + return %arg0 : tensor +} + +// CHECK-LABEL: "op_divide" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_divide(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.divide_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.divide"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_dot_general" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_dot_general(%arg0: tensor<8x8x16xf32>, %arg1: tensor<8x16x8xf32>) -> tensor<8x8x8xf32> { + // CHECK: "vhlo.dot_general_v2"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: accumulation_type = #vhlo.type_v1, + // CHECK-SAME: allow_imprecise_accumulation = #vhlo.type_v1, + // CHECK-SAME: lhs_batching_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: lhs_component_count = #vhlo.type_v1, + // CHECK-SAME: lhs_contracting_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: lhs_precision_type = #vhlo.type_v1, + // CHECK-SAME: num_primitive_operations = #vhlo.type_v1, + // CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]>, + // CHECK-SAME: rhs_batching_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: rhs_component_count = #vhlo.type_v1, + // CHECK-SAME: rhs_contracting_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: rhs_precision_type = #vhlo.type_v1 + // CHECK-SAME: }> : (!vhlo.tensor_v1<8x8x16x!vhlo.f32_v1>, !vhlo.tensor_v1<8x16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<8x8x8x!vhlo.f32_v1> + %0 = "stablehlo.dot_general"(%arg0, %arg1) { + dot_dimension_numbers = #stablehlo.dot< + lhs_batching_dimensions = [0], + lhs_contracting_dimensions = [2], + rhs_batching_dimensions = [0], + rhs_contracting_dimensions = [1] + >, + precision_config = [#stablehlo, #stablehlo] + } : (tensor<8x8x16xf32>, tensor<8x16x8xf32>) -> tensor<8x8x8xf32> + func.return %0 : tensor<8x8x8xf32> +} + +// CHECK-LABEL: "op_dot" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_dot(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> { + // CHECK: "vhlo.dot_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]> + // CHECK-SAME: }> : (!vhlo.tensor_v1<8x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<8x8x!vhlo.f32_v1> + %0 = "stablehlo.dot"(%arg0, %arg1) { + precision_config = [#stablehlo, #stablehlo] + } : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> + func.return %0 : tensor<8x8xf32> +} + +// CHECK-LABEL: "op_dynamic_broadcast_in_dim" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_dynamic_broadcast_in_dim(%arg0: tensor, %arg1: tensor<2xindex>) -> tensor { + // CHECK: "vhlo.dynamic_broadcast_in_dim_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: broadcast_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: known_expanding_dimensions = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: known_nonexpanding_dimensions = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1<2x!vhlo.index_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.dynamic_broadcast_in_dim"(%arg0, %arg1) { + broadcast_dimensions = array, + known_expanding_dimensions = array, + known_nonexpanding_dimensions = array + } : (tensor, tensor<2xindex>) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_dynamic_conv" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) +func.func @op_dynamic_conv(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>, %arg2: tensor<2x2xi64>) -> tensor<1x?x?x16xf32> { + // CHECK: "vhlo.dynamic_conv_v2"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) <{ + // CHECK-SAME: batch_group_count = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: feature_group_count = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: input_batch_dimension = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: input_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: input_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: kernel_input_feature_dimension = #vhlo.integer_v1<2 : i64>, + // CHECK-SAME: kernel_output_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: kernel_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: lhs_dilation = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: output_batch_dimension = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: output_feature_dimension = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: output_spatial_dimensions = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: precision_config = #vhlo.array_v1<[#vhlo, #vhlo]>, + // CHECK-SAME: rhs_dilation = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: window_reversal = #vhlo.tensor_v1 : tensor<2xi1>>, + // CHECK-SAME: window_strides = #vhlo.tensor_v1 : tensor<2xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<1x8x8x207x!vhlo.f32_v1>, !vhlo.tensor_v1<3x3x207x16x!vhlo.f32_v1>, !vhlo.tensor_v1<2x2x!vhlo.i64_v1>) -> !vhlo.tensor_v1<1x?x?x16x!vhlo.f32_v1> + %0 = "stablehlo.dynamic_conv"(%arg0, %arg1, %arg2) { + window_strides = array, + lhs_dilation = array, + rhs_dilation = array, + window_reversal = array, + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]>, + feature_group_count = 1 : i64, + batch_group_count = 1 : i64, + precision_config = [#stablehlo, #stablehlo] + } : (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>, tensor<2x2xi64>) -> tensor<1x?x?x16xf32> + func.return %0 : tensor<1x?x?x16xf32> +} + +// CHECK-LABEL: "op_dynamic_gather" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) +func.func @op_dynamic_gather(%arg0 : tensor<2x4x9xf32>, %arg1 : tensor<1x5x2xi32>, %arg2 : tensor<3xi32>) -> tensor<1x5x8xf32> { + // CHECK: "vhlo.dynamic_gather_v2"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) <{ + // CHECK-SAME: collapsed_slice_dims = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: index_vector_dim = #vhlo.integer_v1<2 : i64>, + // CHECK-SAME: indices_are_sorted = #vhlo.bool_v1, + // CHECK-SAME: offset_dims = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: operand_batching_dims = #vhlo.tensor_v1 : tensor<0xi64>>, + // CHECK-SAME: start_index_map = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: start_indices_batching_dims = #vhlo.tensor_v1 : tensor<0xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<2x4x9x!vhlo.f32_v1>, !vhlo.tensor_v1<1x5x2x!vhlo.i32_v1>, !vhlo.tensor_v1<3x!vhlo.i32_v1>) -> !vhlo.tensor_v1<1x5x8x!vhlo.f32_v1> + %0 = "stablehlo.dynamic_gather"(%arg0, %arg1, %arg2) { + dimension_numbers = #stablehlo.gather< + offset_dims = [2], + collapsed_slice_dims = [0, 1], + start_index_map = [0, 1], + index_vector_dim = 2 + >, + indices_are_sorted = true + } : (tensor<2x4x9xf32>, tensor<1x5x2xi32>, tensor<3xi32>) -> tensor<1x5x8xf32> + func.return %0 : tensor<1x5x8xf32> +} + +// CHECK-LABEL: "op_dynamic_gather_with_batching_dims" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) +func.func @op_dynamic_gather_with_batching_dims(%arg0 : tensor<5x2x4x9xf32>, %arg1 : tensor<1x5x2xi32>, %arg2 : tensor<4xi32>) -> tensor<1x5x8xf32> { + // CHECK: "vhlo.dynamic_gather_v2"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) <{ + // CHECK-SAME: collapsed_slice_dims = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: index_vector_dim = #vhlo.integer_v1<2 : i64>, + // CHECK-SAME: indices_are_sorted = #vhlo.bool_v1, + // CHECK-SAME: offset_dims = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: operand_batching_dims = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: start_index_map = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: start_indices_batching_dims = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<5x2x4x9x!vhlo.f32_v1>, !vhlo.tensor_v1<1x5x2x!vhlo.i32_v1>, !vhlo.tensor_v1<4x!vhlo.i32_v1>) -> !vhlo.tensor_v1<1x5x8x!vhlo.f32_v1> + %0 = "stablehlo.dynamic_gather"(%arg0, %arg1, %arg2) { + dimension_numbers = #stablehlo.gather< + offset_dims = [2], + collapsed_slice_dims = [1, 2], + operand_batching_dims = [0], + start_indices_batching_dims = [1], + start_index_map = [1, 2], + index_vector_dim = 2 + >, + indices_are_sorted = true + } : (tensor<5x2x4x9xf32>, tensor<1x5x2xi32>, tensor<4xi32>) -> tensor<1x5x8xf32> + func.return %0 : tensor<1x5x8xf32> +} + +// CHECK-LABEL: "op_dynamic_iota" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_dynamic_iota(%arg0: tensor<1xindex>) -> tensor { + // CHECK: "vhlo.dynamic_iota_v1"(%[[ARG0]]) <{ + // CHECK-SAME: iota_dimension = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: }> : (!vhlo.tensor_v1<1x!vhlo.index_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.dynamic_iota"(%arg0) { + iota_dimension = 0 : i64 + } : (tensor<1xindex>) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_dynamic_pad" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}, %[[ARG3:.*]]: {{.*}}, %[[ARG4:.*]]: {{.*}}) +func.func @op_dynamic_pad(%arg0: tensor, %arg1: tensor, %arg2: tensor<1xindex>, %arg3: tensor<1xindex>, %arg4: tensor<1xindex>) -> tensor { + // CHECK: "vhlo.dynamic_pad_v1"(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1, !vhlo.tensor_v1<1x!vhlo.index_v1>, !vhlo.tensor_v1<1x!vhlo.index_v1>, !vhlo.tensor_v1<1x!vhlo.index_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.dynamic_pad"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor, tensor, tensor<1xindex>, tensor<1xindex>, tensor<1xindex>) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_dynamic_reshape" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_dynamic_reshape(%arg0: tensor<16xf32>, %arg1: tensor<2xindex>) -> tensor { + // CHECK: "vhlo.dynamic_reshape_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<2x!vhlo.index_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.dynamic_reshape"(%arg0, %arg1) : (tensor<16xf32>, tensor<2xindex>) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_dynamic_slice" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_dynamic_slice(%arg0: tensor<16xf32>, %arg1: tensor) -> tensor<4xf32> { + // CHECK: "vhlo.dynamic_slice_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: slice_sizes = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<4x!vhlo.f32_v1> + %0 = "stablehlo.dynamic_slice"(%arg0, %arg1) { + slice_sizes = array + } : (tensor<16xf32>, tensor) -> tensor<4xf32> + func.return %0 : tensor<4xf32> +} + +// CHECK-LABEL: "op_dynamic_update_slice" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) +func.func @op_dynamic_update_slice(%arg0: tensor<16xf32>, %arg1: tensor<4xf32>, %arg2: tensor) -> tensor<16xf32> { + // CHECK: "vhlo.dynamic_update_slice_v1"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1<4x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> + %0 = "stablehlo.dynamic_update_slice"(%arg0, %arg1, %arg2) : (tensor<16xf32>, tensor<4xf32>, tensor) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} + +// CHECK-LABEL: "op_einsum" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_einsum(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> { + // CHECK: "vhlo.einsum_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: einsum_config = #vhlo.string_v1<"ab,bc->ac"> + // CHECK-SAME: }> : (!vhlo.tensor_v1<8x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<8x8x!vhlo.f32_v1> + %0 = "stablehlo.einsum"(%arg0, %arg1) { + einsum_config = "ab,bc->ac" + } : (tensor<8x16xf32>, tensor<16x8xf32>) -> tensor<8x8xf32> + func.return %0 : tensor<8x8xf32> +} + +// CHECK-LABEL: "op_exponential_minus_one" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_exponential_minus_one(%arg0: tensor) -> tensor { + // CHECK: "vhlo.exponential_minus_one_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.exponential_minus_one"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_exponential" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_exponential(%arg0: tensor) -> tensor { + // CHECK: "vhlo.exponential_v2"(%[[ARG0]]) <{result_accuracy = #vhlo.result_accuracy_v1>}> : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.exponential"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_fft" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_fft(%arg0: tensor<16xcomplex>) -> tensor<16xcomplex> { + // CHECK: "vhlo.fft_v1"(%[[ARG0]]) <{ + // CHECK-SAME: fft_length = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: fft_type = #vhlo + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x!vhlo.complex_v1>) -> !vhlo.tensor_v1<16x!vhlo.complex_v1> + %0 = "stablehlo.fft"(%arg0) { + fft_type = #stablehlo, + fft_length = array + } : (tensor<16xcomplex>) -> tensor<16xcomplex> + func.return %0 : tensor<16xcomplex> +} + +// CHECK-LABEL: "op_floor" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_floor(%arg0: tensor) -> tensor { + // CHECK: "vhlo.floor_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.floor"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +func.func private @op_func(%arg0: tensor {stablehlo.arg = "0"}) -> (tensor {stablehlo.result = "0"}) { + // CHECK: "vhlo.func_v1"() <{ + // CHECK-SAME: arg_attrs = #vhlo.array_v1<[#vhlo.dict_v1<{#vhlo.string_v1<"stablehlo.arg"> = #vhlo.string_v1<"0">}>]>, + // CHECK-SAME: function_type = #vhlo.type_v1) -> !vhlo.tensor_v1>>, + // CHECK-SAME: res_attrs = #vhlo.array_v1<[#vhlo.dict_v1<{#vhlo.string_v1<"stablehlo.result"> = #vhlo.string_v1<"0">}>]>, + // CHECK-SAME: sym_name = #vhlo.string_v1<"op_func">, + // CHECK-SAME: sym_visibility = #vhlo.string_v1<"private"> + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG0:.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: "vhlo.return_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : () -> () + + func.return %arg0 : tensor +} + +// CHECK-LABEL: "op_gather" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_gather(%arg0 : tensor<2x4x9xf32>, %arg1 : tensor<1x5x2xi32>) -> tensor<1x5x1xf32> { + // CHECK: "vhlo.gather_v2"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: collapsed_slice_dims = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: index_vector_dim = #vhlo.integer_v1<2 : i64>, + // CHECK-SAME: indices_are_sorted = #vhlo.bool_v1, + // CHECK-SAME: offset_dims = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: operand_batching_dims = #vhlo.tensor_v1 : tensor<0xi64>>, + // CHECK-SAME: slice_sizes = #vhlo.tensor_v1 : tensor<3xi64>>, + // CHECK-SAME: start_index_map = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: start_indices_batching_dims = #vhlo.tensor_v1 : tensor<0xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<2x4x9x!vhlo.f32_v1>, !vhlo.tensor_v1<1x5x2x!vhlo.i32_v1>) -> !vhlo.tensor_v1<1x5x1x!vhlo.f32_v1> + %0 = "stablehlo.gather"(%arg0, %arg1) { + dimension_numbers = #stablehlo.gather< + offset_dims = [2], + collapsed_slice_dims = [0, 1], + start_index_map = [0, 1], + index_vector_dim = 2 + >, + slice_sizes = array, + indices_are_sorted = true + } : (tensor<2x4x9xf32>, tensor<1x5x2xi32>) -> tensor<1x5x1xf32> + func.return %0 : tensor<1x5x1xf32> +} + +// CHECK-LABEL: "op_gather_with_batching_dims" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_gather_with_batching_dims(%arg0 : tensor<5x2x4x9xf32>, %arg1 : tensor<1x5x2xi32>) -> tensor<1x5x1xf32> { + // CHECK: "vhlo.gather_v2"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: collapsed_slice_dims = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: index_vector_dim = #vhlo.integer_v1<2 : i64>, + // CHECK-SAME: indices_are_sorted = #vhlo.bool_v1, + // CHECK-SAME: offset_dims = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: operand_batching_dims = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: slice_sizes = #vhlo.tensor_v1 : tensor<4xi64>>, + // CHECK-SAME: start_index_map = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: start_indices_batching_dims = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<5x2x4x9x!vhlo.f32_v1>, !vhlo.tensor_v1<1x5x2x!vhlo.i32_v1>) -> !vhlo.tensor_v1<1x5x1x!vhlo.f32_v1> + %0 = "stablehlo.gather"(%arg0, %arg1) { + dimension_numbers = #stablehlo.gather< + offset_dims = [2], + collapsed_slice_dims = [1, 2], + operand_batching_dims = [0], + start_indices_batching_dims = [1], + start_index_map = [1, 2], + index_vector_dim = 2 + >, + slice_sizes = array, + indices_are_sorted = true + } : (tensor<5x2x4x9xf32>, tensor<1x5x2xi32>) -> tensor<1x5x1xf32> + func.return %0 : tensor<1x5x1xf32> +} + +// CHECK-LABEL: "op_get_dimension_size" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_get_dimension_size(%arg0: tensor) -> tensor { + // CHECK: "vhlo.get_dimension_size_v1"(%[[ARG0]]) <{ + // CHECK-SAME: dimension = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: }> : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.get_dimension_size"(%arg0) { + dimension = 0 : i64 + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_get_tuple_element" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_get_tuple_element(%arg0: tuple, tensor>) -> tensor { + // CHECK: "vhlo.get_tuple_element_v1"(%[[ARG0]]) <{ + // CHECK-SAME: index = #vhlo.integer_v1<0 : i32> + // CHECK-SAME: }> : (!vhlo.tuple_v1, !vhlo.tensor_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.get_tuple_element"(%arg0) { + index = 0 : i32 + } : (tuple, tensor>) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_if" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) +func.func @op_if(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // CHECK: "vhlo.if_v1"(%[[ARG0]]) ({ + // CHECK-NEXT: "vhlo.return_v1"(%[[ARG1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }, { + // CHECK-NEXT: "vhlo.return_v1"(%[[ARG2]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.if"(%arg0) ({ + "stablehlo.return"(%arg1) : (tensor) -> () + }, { + "stablehlo.return"(%arg2) : (tensor) -> () + }) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_imag" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_imag(%arg0: tensor>) -> tensor { + // CHECK: "vhlo.imag_v1"(%[[ARG0]]) : (!vhlo.tensor_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.imag"(%arg0) : (tensor>) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_infeed" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_infeed(%arg0: !stablehlo.token) -> (tensor, !stablehlo.token) { + // CHECK: "vhlo.infeed_v1"(%[[ARG0]]) <{ + // CHECK-SAME: infeed_config = #vhlo.string_v1<"foo">, + // CHECK-SAME{LITERAL}: layout = #vhlo.array_v1<[#vhlo.array_v1<[]>]> + // CHECK-SAME: }> : (!vhlo.token_v1) -> (!vhlo.tensor_v1, !vhlo.token_v1) + %0:2 = "stablehlo.infeed"(%arg0) { + infeed_config = "foo", + layout = [[]] + } : (!stablehlo.token) -> (tensor, !stablehlo.token) + func.return %0#0, %0#1 : tensor, !stablehlo.token +} + +// CHECK-LABEL: "op_iota" +func.func @op_iota() -> tensor<16xf32> { + // CHECK: "vhlo.iota_v1"() <{ + // CHECK-SAME: iota_dimension = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: }> : () -> !vhlo.tensor_v1<16x!vhlo.f32_v1> + %0 = "stablehlo.iota"() { + iota_dimension = 0 : i64 + } : () -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} + +// CHECK-LABEL: "op_is_finite" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_is_finite(%arg0: tensor) -> tensor { + // CHECK: "vhlo.is_finite_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.is_finite"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_log" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_log(%arg0: tensor) -> tensor { + // CHECK: "vhlo.log_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.log"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_log_plus_one" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_log_plus_one(%arg0: tensor) -> tensor { + // CHECK: "vhlo.log_plus_one_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.log_plus_one"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_logistic" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_logistic(%arg0: tensor) -> tensor { + // CHECK: "vhlo.logistic_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.logistic"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_map" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_map(%arg0: tensor<16xf32>) -> tensor<16xf32> { + // CHECK: "vhlo.map_v1"(%[[ARG0]]) <{ + // CHECK-SAME: dimensions = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.abs_v1"(%[[ARG1]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> + %0 = "stablehlo.map"(%arg0) ({ + ^bb0(%arg1: tensor): + %1 = "stablehlo.abs"(%arg1) : (tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + dimensions = array + } : (tensor<16xf32>) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} + +// CHECK-LABEL: "op_maximum" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_maximum(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.maximum_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.maximum"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_minimum" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_minimum(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.minimum_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.minimum"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_multiply" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_multiply(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.multiply_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.multiply"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_negate" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_negate(%arg0: tensor) -> tensor { + // CHECK: "vhlo.negate_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.negate"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_not" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_not(%arg0: tensor) -> tensor { + // CHECK: "vhlo.not_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.not"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_optimization_barrier" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_optimization_barrier(%arg0: tensor) -> tensor { + // CHECK: "vhlo.optimization_barrier_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.optimization_barrier"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_or" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_or(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.or_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.or"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_outfeed" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_outfeed(%arg0: tensor, %arg1: !stablehlo.token) -> !stablehlo.token { + // CHECK: "vhlo.outfeed_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: outfeed_config = #vhlo.string_v1<"foo"> + // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.token_v1) -> !vhlo.token_v1 + %0 = "stablehlo.outfeed"(%arg0, %arg1) { + outfeed_config = "foo" + } : (tensor, !stablehlo.token) -> !stablehlo.token + func.return %0 : !stablehlo.token +} + +// CHECK-LABEL: "op_pad" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_pad(%arg0: tensor<8xf32>, %arg1: tensor) -> tensor<16xf32> { + // CHECK: "vhlo.pad_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: edge_padding_high = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: edge_padding_low = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: interior_padding = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<8x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> + %0 = "stablehlo.pad"(%arg0, %arg1) { + edge_padding_high = array, + edge_padding_low = array, + interior_padding = array + } : (tensor<8xf32>, tensor) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} + +// CHECK-LABEL: "op_popcnt" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_popcnt(%arg0: tensor) -> tensor { + // CHECK: "vhlo.popcnt_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.popcnt"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_power" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_power(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.power_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.power"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_real_dynamic_slice" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}, %[[ARG3:.*]]: {{.*}}) +func.func @op_real_dynamic_slice(%arg0: tensor, %arg1: tensor<1xindex>, %arg2: tensor<1xindex>, %arg3: tensor<1xindex>) -> tensor { + // CHECK: "vhlo.real_dynamic_slice_v1"(%[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1<1x!vhlo.index_v1>, !vhlo.tensor_v1<1x!vhlo.index_v1>, !vhlo.tensor_v1<1x!vhlo.index_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.real_dynamic_slice"(%arg0, %arg1, %arg2, %arg3) : (tensor, tensor<1xindex>, tensor<1xindex>, tensor<1xindex>) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_real" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_real(%arg0: tensor>) -> tensor { + // CHECK: "vhlo.real_v1"(%[[ARG0]]) : (!vhlo.tensor_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.real"(%arg0) : (tensor>) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_recv" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_recv(%arg0: !stablehlo.token) -> (tensor, !stablehlo.token) { + // CHECK: "vhlo.recv_v1"(%[[ARG0]]) <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: channel_type = #vhlo.integer_v1<3 : i64>, + // CHECK-SAME: is_host_transfer = #vhlo.bool_v1 + // CHECK-SAME: }> : (!vhlo.token_v1) -> (!vhlo.tensor_v1, !vhlo.token_v1) + %0:2 = "stablehlo.recv"(%arg0) { + channel_handle = #stablehlo.channel_handle, + is_host_transfer = true + } : (!stablehlo.token) -> (tensor, !stablehlo.token) + func.return %0#0, %0#1 : tensor, !stablehlo.token +} + +// CHECK-LABEL: "op_reduce" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_reduce(%arg0: tensor<16xf32>, %arg1: tensor) -> tensor { + // CHECK: "vhlo.reduce_v1"(%[[ARG0]], %[[ARG1]]) + // CHECK: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): + // CHECK: "vhlo.return_v1"(%[[VAL1:.*]]) : (!vhlo.tensor_v1) -> () + // CHECK: }) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.reduce"(%arg0, %arg1) ({ + ^bb0(%arg2: tensor, %arg3: tensor): + %1 = "stablehlo.add"(%arg2, %arg3) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + dimensions = array + } : (tensor<16xf32>, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_reduce_precision" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_reduce_precision(%arg0: tensor) -> tensor { + // CHECK: "vhlo.reduce_precision_v1"(%[[ARG0]]) <{ + // CHECK-SAME: exponent_bits = #vhlo.integer_v1<8 : i32> + // CHECK-SAME: mantissa_bits = #vhlo.integer_v1<10 : i32> + // CHECK-SAME: }> : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.reduce_precision"(%arg0) { + exponent_bits = 8 : i32, + mantissa_bits = 10 : i32 + } : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK_lABEL: "op_reduce_with_promotable_types" +func.func @op_reduce_with_promotable_types(%arg0: tensor<4x4xf32>, %arg1 : tensor) + -> (tensor<4xf64>) { + // CHECK: "vhlo.reduce_v1"(%[[ARG0:.*]], %[[ARG1:.*]]) + // CHECK: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): + // CHECK: "vhlo.return_v1"(%[[VAL1:.*]]) : (!vhlo.tensor_v1) -> () + // CHECK: }) : (!vhlo.tensor_v1<4x4x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<4x!vhlo.f64_v1> + %0 = "stablehlo.reduce"(%arg0, %arg1) ({ + ^bb0(%arg2: tensor, %arg3: tensor ): + %1 = "stablehlo.add"(%arg2, %arg3) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + + }) {dimensions = array} : (tensor<4x4xf32>, tensor) -> tensor<4xf64> + + func.return %0: tensor<4xf64> +} + +// CHECK-LABEL: "op_reduce_scatter" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_reduce_scatter(%arg0: tensor<16xf32>) -> tensor<16xf32> { + // CHECK: "vhlo.reduce_scatter_v1"(%[[ARG0]]) <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME{LITERAL}: replica_groups = #vhlo.tensor_v1 : tensor<2x1xi64>>, + // CHECK-SAME: scatter_dimension = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: use_global_device_ids = #vhlo.bool_v1 + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.add_v1"(%[[ARG1]], %[[ARG2]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> + %0 = "stablehlo.reduce_scatter"(%arg0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %1 = "stablehlo.add"(%arg1, %arg2) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + scatter_dimension = 0 : i64, + replica_groups = dense<[[0], [1]]> : tensor<2x1xi64>, + channel_handle = #stablehlo.channel_handle, + use_global_device_ids + } : (tensor<16xf32>) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} + +// CHECK_lABEL: "op_reduce_scatter_with_promotable_types" +func.func @op_reduce_scatter_with_promotable_types(%data: tensor<4x16xf32>) -> tensor<4x4xf64> { + // CHECK: "vhlo.reduce_scatter_v1"(%[[ARG0:.*]]) + // CHECK: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): + // CHECK: "vhlo.return_v1"(%[[VAL1:.*]]) : (!vhlo.tensor_v1) -> () + // CHECK: }) : (!vhlo.tensor_v1<4x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<4x4x!vhlo.f64_v1> + %0 = "stablehlo.reduce_scatter"(%data) ({ + ^bb0(%arg2: tensor, %arg3: tensor): + %1 = stablehlo.add %arg2, %arg3 : tensor + "stablehlo.return"(%1) : (tensor) -> () + }) {replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>, + scatter_dimension = 1 : i64, + channel_handle = #stablehlo.channel_handle, + use_global_device_ids} : (tensor<4x16xf32>) -> tensor<4x4xf64> + func.return %0 : tensor<4x4xf64> +} + + +// CHECK-LABEL: "op_reduce_window" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_reduce_window(%arg0: tensor<2x17x31x7xf32>, %arg1: tensor) -> tensor<2x9x16x7xf32> { + // CHECK: "vhlo.reduce_window_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: base_dilations = #vhlo.tensor_v1 : tensor<4xi64>>, + // CHECK-SAME{LITERAL}: padding = #vhlo.tensor_v1 : tensor<4x2xi64>>, + // CHECK-SAME: window_dilations = #vhlo.tensor_v1 : tensor<4xi64>>, + // CHECK-SAME: window_dimensions = #vhlo.tensor_v1 : tensor<4xi64>>, + // CHECK-SAME: window_strides = #vhlo.tensor_v1 : tensor<4xi64>> + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG2:arg.*]]: !vhlo.tensor_v1, %[[ARG3:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.maximum_v1"(%[[ARG2]], %[[ARG3]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1<2x17x31x7x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<2x9x16x7x!vhlo.f32_v1> + %0 = "stablehlo.reduce_window"(%arg0, %arg1) ({ + ^bb0(%arg2: tensor, %arg3: tensor): + %1 = "stablehlo.maximum"(%arg2, %arg3) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + window_dimensions = array, + window_strides = array, + base_dilations = array, + window_dilations = array, + padding = dense<[[0, 0], [2, 0], [0, 2], [0, 0]]> : tensor<4x2xi64> + } : (tensor<2x17x31x7xf32>, tensor) -> tensor<2x9x16x7xf32> + func.return %0 : tensor<2x9x16x7xf32> +} + +// CHECK-LABEL: "op_reduce_window_with_promotable_types" +func.func @op_reduce_window_with_promotable_types(%arg0: tensor<4x2xf32>, + %arg1: tensor<4x2xf32>, %init0: tensor, %init1: tensor) -> + (tensor<2x2xf64>, tensor<2x2xf32>) { + // CHECK: "vhlo.reduce_window_v1"(%[[ARG0:.*]], %[[ARG1:.*]], %[[ARG2:.*]], %[[ARG3:.*]]) + // CHECK: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1, %[[ARG3:arg.*]]: !vhlo.tensor_v1, %[[ARG4:arg.*]]: !vhlo.tensor_v1): + // CHECK: "vhlo.return_v1"(%[[VAL1:.*]], %[[VAL2:.*]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> () + // CHECK: }) : (!vhlo.tensor_v1<4x2x!vhlo.f32_v1>, !vhlo.tensor_v1<4x2x!vhlo.f32_v1>, !vhlo.tensor_v1, !vhlo.tensor_v1) -> (!vhlo.tensor_v1<2x2x!vhlo.f64_v1>, !vhlo.tensor_v1<2x2x!vhlo.f32_v1>) + %0:2 = "stablehlo.reduce_window"(%arg0, %arg1, %init0, %init1) ({ + ^bb0(%a0: tensor, %a1: tensor, %b0: tensor, + %b1: tensor): + %2 = stablehlo.add %a0, %b0 : tensor + %3 = stablehlo.add %a1, %b1 : tensor + "stablehlo.return"(%2,%3) : (tensor, tensor) -> () + }) + { padding = dense<[[2, 2], [0, 0]]> : tensor<2x2xi64>, + window_dimensions = array, + window_strides = array } + : (tensor<4x2xf32>, tensor<4x2xf32>, tensor, tensor) -> + (tensor<2x2xf64>, tensor<2x2xf32>) + func.return %0#0, %0#1 : tensor<2x2xf64>, tensor<2x2xf32> +} + +// CHECK-LABEL: "op_remainder" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_remainder(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.remainder_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.remainder"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_replica_id" +func.func @op_replica_id() -> tensor { + // CHECK: "vhlo.replica_id_v1"() : () -> !vhlo.tensor_v1 + %0 = "stablehlo.replica_id"() : () -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_partition_id" +func.func @op_partition_id() -> tensor { + // CHECK: "vhlo.partition_id_v1"() : () -> !vhlo.tensor_v1 + %0 = "stablehlo.partition_id"() : () -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_reshape" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_reshape(%arg0: tensor<16xf32>) -> tensor<4x4xf32> { + // CHECK: "vhlo.reshape_v1"(%[[ARG0]]) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<4x4x!vhlo.f32_v1> + %0 = "stablehlo.reshape"(%arg0) : (tensor<16xf32>) -> tensor<4x4xf32> + func.return %0 : tensor<4x4xf32> +} + +// CHECK-LABEL: "op_return" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_return(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.case_v1"(%[[ARG0]]) ({ + // CHECK-NEXT: "vhlo.return_v1"(%[[ARG1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.case"(%arg0) ({ + "stablehlo.return"(%arg1) : (tensor) -> () + }) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_reverse" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_reverse(%arg0: tensor<16xf32>) -> tensor<16xf32> { + // CHECK: "vhlo.reverse_v1"(%[[ARG0]]) <{ + // CHECK-SAME: dimensions = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> + %0 = "stablehlo.reverse"(%arg0) { + dimensions = array + } : (tensor<16xf32>) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} + +// CHECK-LABEL: "op_rng_bit_generator" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_rng_bit_generator(%arg0: tensor) -> (tensor, tensor) { + // CHECK: "vhlo.rng_bit_generator_v1"(%[[ARG0]]) <{ + // CHECK-SAME: rng_algorithm = #vhlo + // CHECK-SAME: }> : (!vhlo.tensor_v1) -> (!vhlo.tensor_v1, !vhlo.tensor_v1) + %0:2 = "stablehlo.rng_bit_generator"(%arg0) { + rng_algorithm = #stablehlo + } : (tensor) -> (tensor, tensor) + func.return %0#0, %0#1 : tensor, tensor +} + +// CHECK-LABEL: "op_rng" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) +func.func @op_rng(%arg0: tensor, %arg1: tensor, %arg2: tensor<0xindex>) -> tensor { + // CHECK: "vhlo.rng_v1"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) <{ + // CHECK-SAME: rng_distribution = #vhlo + // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1, !vhlo.tensor_v1<0x!vhlo.index_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.rng"(%arg0, %arg1, %arg2) { + rng_distribution = #stablehlo + } : (tensor, tensor, tensor<0xindex>) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_round_nearest_afz" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_round_nearest_afz(%arg0: tensor) -> tensor { + // CHECK: "vhlo.round_nearest_afz_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.round_nearest_afz"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_round_nearest_even" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_round_nearest_even(%arg0: tensor) -> tensor { + // CHECK: "vhlo.round_nearest_even_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.round_nearest_even"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_rsqrt" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_rsqrt(%arg0: tensor) -> tensor { + // CHECK: "vhlo.rsqrt_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.rsqrt"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_scatter" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) +func.func @op_scatter(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>, %arg2: tensor<10x300xf32>) -> tensor<200x100x300xf32> { + // CHECK: "vhlo.scatter_v2"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) <{ + // CHECK-SAME: index_vector_dim = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: indices_are_sorted = #vhlo.bool_v1, + // CHECK-SAME: input_batching_dims = #vhlo.tensor_v1 : tensor<0xi64>>, + // CHECK-SAME: inserted_window_dims = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: scatter_dims_to_operand_dims = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: scatter_indices_batching_dims = #vhlo.tensor_v1 : tensor<0xi64>>, + // CHECK-SAME: unique_indices = #vhlo.bool_v1, + // CHECK-SAME: update_window_dims = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG3:arg.*]]: !vhlo.tensor_v1, %[[ARG4:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.add_v1"(%[[ARG3]], %[[ARG4]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1<200x100x300x!vhlo.f32_v1>, !vhlo.tensor_v1<10x2x!vhlo.i32_v1>, !vhlo.tensor_v1<10x300x!vhlo.f32_v1>) -> !vhlo.tensor_v1<200x100x300x!vhlo.f32_v1> + %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = "stablehlo.add"(%arg3, %arg4) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + scatter_dimension_numbers = #stablehlo.scatter< + update_window_dims = [1], + inserted_window_dims = [0, 1], + scatter_dims_to_operand_dims = [0, 1], + index_vector_dim = 1 + >, + indices_are_sorted = true, + unique_indices = true + } : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<10x300xf32>) -> tensor<200x100x300xf32> + func.return %0 : tensor<200x100x300xf32> +} + +// CHECK-LABEL: "op_scatter_with_batching_dims" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) +func.func @op_scatter_with_batching_dims(%arg0: tensor<10x200x100x300xf32>, %arg1: tensor<10x2xi32>, %arg2: tensor<10x300xf32>) -> tensor<10x200x100x300xf32> { + // CHECK: "vhlo.scatter_v2"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) <{ + // CHECK-SAME: index_vector_dim = #vhlo.integer_v1<1 : i64>, + // CHECK-SAME: indices_are_sorted = #vhlo.bool_v1, + // CHECK-SAME: input_batching_dims = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: inserted_window_dims = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: scatter_dims_to_operand_dims = #vhlo.tensor_v1 : tensor<2xi64>>, + // CHECK-SAME: scatter_indices_batching_dims = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: unique_indices = #vhlo.bool_v1, + // CHECK-SAME: update_window_dims = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG3:arg.*]]: !vhlo.tensor_v1, %[[ARG4:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.add_v1"(%[[ARG3]], %[[ARG4]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1<10x200x100x300x!vhlo.f32_v1>, !vhlo.tensor_v1<10x2x!vhlo.i32_v1>, !vhlo.tensor_v1<10x300x!vhlo.f32_v1>) -> !vhlo.tensor_v1<10x200x100x300x!vhlo.f32_v1> + %0 = "stablehlo.scatter"(%arg0, %arg1, %arg2) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = "stablehlo.add"(%arg3, %arg4) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + scatter_dimension_numbers = #stablehlo.scatter< + update_window_dims = [1], + inserted_window_dims = [1, 2], + input_batching_dims = [0], + scatter_dims_to_operand_dims = [1, 2], + scatter_indices_batching_dims = [0], + index_vector_dim = 1 + >, + indices_are_sorted = true, + unique_indices = true + } : (tensor<10x200x100x300xf32>, tensor<10x2xi32>, tensor<10x300xf32>) -> tensor<10x200x100x300xf32> + func.return %0 : tensor<10x200x100x300xf32> +} + +// CHECK_lABEL: "op_scatter_with_promotable_types" +func.func @op_scatter_with_promotable_types(%input_tensor: tensor<200x100x300xf32>, + %scatter_indices: tensor<10x2xi32>, %updates: tensor<10x300xf32>) -> + tensor<200x100x300xf64> { + // CHECK: "vhlo.scatter_v2"(%[[ARG0:.*]], %[[ARG1:.*]], %[[ARG2:.*]]) + // CHECK: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): + // CHECK: "vhlo.return_v1"(%[[VAL1:.*]]) : (!vhlo.tensor_v1) -> () + // CHECK: }) : (!vhlo.tensor_v1<200x100x300x!vhlo.f32_v1>, !vhlo.tensor_v1<10x2x!vhlo.i32_v1>, !vhlo.tensor_v1<10x300x!vhlo.f32_v1>) -> !vhlo.tensor_v1<200x100x300x!vhlo.f64_v1> + %0 = "stablehlo.scatter" (%input_tensor, %scatter_indices, %updates) ({ + ^bb0(%lhs: tensor, %rhs: tensor): + %add = stablehlo.add %lhs, %rhs : tensor + "stablehlo.return"(%add) : (tensor) -> () + }) { + scatter_dimension_numbers = #stablehlo.scatter< + update_window_dims = [1], + inserted_window_dims = [0, 1], + scatter_dims_to_operand_dims = [0, 1], + index_vector_dim = 1 + >, + indices_are_sorted = true, + unique_indices = true + } : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<10x300xf32>) -> + tensor<200x100x300xf64> + func.return %0 : tensor<200x100x300xf64> +} + +// CHECK-LABEL: "op_select_and_scatter" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) +func.func @op_select_and_scatter(%arg0: tensor<10x24x24x64xf32>, %arg1: tensor<12x13x13x66xf32>, %arg2: tensor) -> tensor<10x24x24x64xf32> { + // CHECK: "vhlo.select_and_scatter_v1"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) <{ + // CHECK-SAME: padding = #vhlo.tensor_v1 : tensor<4x2xi64>>, + // CHECK-SAME: window_dimensions = #vhlo.tensor_v1 : tensor<4xi64>>, + // CHECK-SAME: window_strides = #vhlo.tensor_v1 : tensor<4xi64>> + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG31:arg.*]]: !vhlo.tensor_v1, %[[ARG41:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL11:.*]] = "vhlo.compare_v1"(%[[ARG31]], %[[ARG41]]) <{compare_type = #vhlo, comparison_direction = #vhlo}> : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL11]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }, { + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG32:arg.*]]: !vhlo.tensor_v1, %[[ARG42:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL12:.*]] = "vhlo.add_v1"(%[[ARG32]], %[[ARG42]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL12]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1<10x24x24x64x!vhlo.f32_v1>, !vhlo.tensor_v1<12x13x13x66x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<10x24x24x64x!vhlo.f32_v1> + %0 = "stablehlo.select_and_scatter"(%arg0, %arg1, %arg2) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = "stablehlo.compare"(%arg3, %arg4) {compare_type = #stablehlo, comparison_direction = #stablehlo} : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }, { + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = "stablehlo.add"(%arg3, %arg4) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + window_dimensions = array, + window_strides = array, + padding = dense<1> : tensor<4x2xi64> + } : (tensor<10x24x24x64xf32>, tensor<12x13x13x66xf32>, tensor) -> tensor<10x24x24x64xf32> + func.return %0 : tensor<10x24x24x64xf32> +} + +// CHECK-LABEL: "op_select_and_scatter_with_promotable_types" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) +func.func @op_select_and_scatter_with_promotable_types(%arg0: tensor<10x24x24x64xf32>, %arg1: tensor<12x13x13x66xf32>, %arg2: tensor) -> tensor<10x24x24x64xf64> { + // CHECK: "vhlo.select_and_scatter_v1"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) + // CHECK: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): + // CHECK: %[[VAL:.*]] = "vhlo.add_v1"(%[[ARG1]], %[[ARG2]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK: "vhlo.return_v1"(%[[VAL]]) : (!vhlo.tensor_v1) -> () + // CHECK: }) : (!vhlo.tensor_v1<10x24x24x64x!vhlo.f32_v1>, !vhlo.tensor_v1<12x13x13x66x!vhlo.f32_v1>, !vhlo.tensor_v1) -> !vhlo.tensor_v1<10x24x24x64x!vhlo.f64_v1> + %0 = "stablehlo.select_and_scatter"(%arg0, %arg1, %arg2) ({ + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = "stablehlo.compare"(%arg3, %arg4) {compare_type = #stablehlo, comparison_direction = #stablehlo} : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }, { + ^bb0(%arg3: tensor, %arg4: tensor): + %1 = "stablehlo.add"(%arg3, %arg4) : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + window_dimensions = array, + window_strides = array, + padding = dense<1> : tensor<4x2xi64> + } : (tensor<10x24x24x64xf32>, tensor<12x13x13x66xf32>, tensor) -> tensor<10x24x24x64xf64> + func.return %0 : tensor<10x24x24x64xf64> +} + +// CHECK-LABEL: "op_select" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}, %[[ARG2:.*]]: {{.*}}) +func.func @op_select(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor { + // CHECK: "vhlo.select_v1"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.select"(%arg0, %arg1, %arg2) : (tensor, tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_send" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_send(%arg0: tensor, %arg1: !stablehlo.token) -> !stablehlo.token { + // CHECK: "vhlo.send_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: channel_id = #vhlo.integer_v1<0 : i64>, + // CHECK-SAME: channel_type = #vhlo.integer_v1<2 : i64>, + // CHECK-SAME: is_host_transfer = #vhlo.bool_v1 + // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.token_v1) -> !vhlo.token_v1 + %0 = "stablehlo.send"(%arg0, %arg1) { + channel_handle = #stablehlo.channel_handle, + is_host_transfer = true + } : (tensor, !stablehlo.token) -> !stablehlo.token + func.return %0 : !stablehlo.token +} + +// CHECK-LABEL: "op_set_dimension_size" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_set_dimension_size(%arg0: tensor, %arg1: tensor) -> tensor<16xf32> { + // CHECK: "vhlo.set_dimension_size_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: dimension = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: }> : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> + %0 = "stablehlo.set_dimension_size"(%arg0, %arg1) { + dimension = 0 : i64 + } : (tensor, tensor) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} + +// CHECK-LABEL: "op_shift_left" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_shift_left(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.shift_left_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.shift_left"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_shift_right_arithmetic" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_shift_right_arithmetic(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.shift_right_arithmetic_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.shift_right_arithmetic"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_shift_right_logical" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_shift_right_logical(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.shift_right_logical_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.shift_right_logical"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_sign" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_sign(%arg0: tensor) -> tensor { + // CHECK: "vhlo.sign_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.sign"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_sine" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_sine(%arg0: tensor) -> tensor { + // CHECK: "vhlo.sine_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.sine"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_slice" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_slice(%arg0: tensor<16xf32>) -> tensor<4xf32> { + // CHECK: "vhlo.slice_v1"(%[[ARG0]]) <{ + // CHECK-SAME: limit_indices = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: start_indices = #vhlo.tensor_v1 : tensor<1xi64>>, + // CHECK-SAME: strides = #vhlo.tensor_v1 : tensor<1xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<4x!vhlo.f32_v1> + %0 = "stablehlo.slice"(%arg0) { + start_indices = array, + limit_indices = array, + strides = array + } : (tensor<16xf32>) -> tensor<4xf32> + func.return %0 : tensor<4xf32> +} + +// CHECK-LABEL: "op_sort" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_sort(%arg0: tensor<16xf32>) -> tensor<16xf32> { + // CHECK: "vhlo.sort_v1"(%[[ARG0]]) <{ + // CHECK-SAME: dimension = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: is_stable = #vhlo.bool_v1 + // CHECK-SAME: }> ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1, %[[ARG2:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: %[[VAL1:.*]] = "vhlo.compare_v1"(%[[ARG1]], %[[ARG2]]) <{compare_type = #vhlo, comparison_direction = #vhlo}> + // CHECK-NEXT: "vhlo.return_v1"(%[[VAL1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1<16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x!vhlo.f32_v1> + %0 = "stablehlo.sort"(%arg0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %1 = "stablehlo.compare"(%arg1, %arg2) {compare_type = #stablehlo, comparison_direction = #stablehlo} : (tensor, tensor) -> tensor + "stablehlo.return"(%1) : (tensor) -> () + }) { + dimension = 0 : i64, + is_stable = true + } : (tensor<16xf32>) -> tensor<16xf32> + func.return %0 : tensor<16xf32> +} + +// CHECK-LABEL: "op_sqrt" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_sqrt(%arg0: tensor) -> tensor { + // CHECK: "vhlo.sqrt_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.sqrt"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_subtract" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_subtract(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.subtract_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.subtract"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_tan" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_tan(%arg0: tensor) -> tensor { + // CHECK: "vhlo.tan_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.tan"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_tanh" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_tanh(%arg0: tensor) -> tensor { + // CHECK: "vhlo.tanh_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.tanh"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_torch_index_select" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_torch_index_select(%arg0: tensor<5x1x5xf32>, %arg1: tensor<2xi32>) -> tensor<2x1x5xf32> { + // CHECK: "vhlo.torch_index_select_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: batch_dims = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: dim = #vhlo.integer_v1<0 : i64> + // CHECK-SAME: }> : (!vhlo.tensor_v1<5x1x5x!vhlo.f32_v1>, !vhlo.tensor_v1<2x!vhlo.i32_v1>) -> !vhlo.tensor_v1<2x1x5x!vhlo.f32_v1> + %0 = "stablehlo.torch_index_select"(%arg0, %arg1) { + dim = 0 : i64, + batch_dims = 0 : i64 + } : (tensor<5x1x5xf32>, tensor<2xi32>) -> tensor<2x1x5xf32> + func.return %0 : tensor<2x1x5xf32> +} + +// CHECK-LABEL: "op_transpose" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_transpose(%arg0: tensor<16x8xf32>) -> tensor<8x16xf32> { + // CHECK: "vhlo.transpose_v1"(%[[ARG0]]) <{ + // CHECK-SAME: permutation = #vhlo.tensor_v1 : tensor<2xi64>> + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x8x!vhlo.f32_v1>) -> !vhlo.tensor_v1<8x16x!vhlo.f32_v1> + %0 = "stablehlo.transpose"(%arg0) { + permutation = array + } : (tensor<16x8xf32>) -> tensor<8x16xf32> + func.return %0 : tensor<8x16xf32> +} + +// CHECK-LABEL: "op_triangular_solve" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_triangular_solve(%arg0: tensor<16x16xf32>, %arg1: tensor<16x16xf32>) -> tensor<16x16xf32> { + // CHECK: "vhlo.triangular_solve_v1"(%[[ARG0]], %[[ARG1]]) <{ + // CHECK-SAME: left_side = #vhlo.bool_v1, + // CHECK-SAME: lower = #vhlo.bool_v1, + // CHECK-SAME: transpose_a = #vhlo, + // CHECK-SAME: unit_diagonal = #vhlo.bool_v1 + // CHECK-SAME: }> : (!vhlo.tensor_v1<16x16x!vhlo.f32_v1>, !vhlo.tensor_v1<16x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<16x16x!vhlo.f32_v1> + %0 = "stablehlo.triangular_solve"(%arg0, %arg1) { + left_side = true, + lower = true, + unit_diagonal = true, + transpose_a = #stablehlo + } : (tensor<16x16xf32>, tensor<16x16xf32>) -> tensor<16x16xf32> + func.return %0 : tensor<16x16xf32> +} + +// CHECK-LABEL: "op_tuple" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_tuple(%arg0: tensor) -> tuple> { + // CHECK: "vhlo.tuple_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tuple_v1> + %0 = "stablehlo.tuple"(%arg0) : (tensor) -> tuple> + func.return %0 : tuple> +} + +// CHECK-LABEL: "op_unary_einsum" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_unary_einsum(%arg0: tensor<8x16xf32>) -> tensor<8xf32> { + // CHECK: "vhlo.unary_einsum_v1"(%[[ARG0]]) <{ + // CHECK-SAME: einsum_config = #vhlo.string_v1<"ab->a"> + // CHECK-SAME: }> : (!vhlo.tensor_v1<8x16x!vhlo.f32_v1>) -> !vhlo.tensor_v1<8x!vhlo.f32_v1> + %0 = "stablehlo.unary_einsum"(%arg0) { + einsum_config = "ab->a" + } : (tensor<8x16xf32>) -> tensor<8xf32> + func.return %0 : tensor<8xf32> +} + +// CHECK-LABEL: "op_uniform_dequantize" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_uniform_dequantize(%arg0: tensor>) -> tensor { + // CHECK: "vhlo.uniform_dequantize_v1"(%[[ARG0]]) : (!vhlo.tensor_v1>) -> !vhlo.tensor_v1 + %0 = "stablehlo.uniform_dequantize"(%arg0) : (tensor>) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "op_uniform_quantize" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_uniform_quantize(%arg0: tensor) -> tensor> { + // CHECK: "vhlo.uniform_quantize_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1> + %0 = "stablehlo.uniform_quantize"(%arg0) : (tensor) -> tensor> + func.return %0 : tensor> +} + +// CHECK-LABEL: "op_while" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @op_while(%arg0: tensor) -> tensor { + // CHECK: "vhlo.while_v1"(%[[ARG0]]) ({ + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1): + // CHECK-NEXT: "vhlo.return_v1"(%[[ARG1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }, { + // CHECK-NEXT: ^[[BB:bb.*]](%[[ARG1:arg.*]]: !vhlo.tensor_v1) + // CHECK-NEXT: "vhlo.return_v1"(%[[ARG1]]) : (!vhlo.tensor_v1) -> () + // CHECK-NEXT: }) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.while"(%arg0) ({ + ^bb0(%arg1: tensor): + "stablehlo.return"(%arg1) : (tensor) -> () + }, { + ^bb0(%arg1: tensor): + "stablehlo.return"(%arg1) : (tensor) -> () + }) : (tensor) -> tensor + func.return %0: tensor +} + +// CHECK-LABEL: "op_xor" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @op_xor(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.xor_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.xor"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// ============ TYPES ============ + +// CHECK-LABEL: "type_i1" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_i1(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.and_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.and"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_i2" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_i2(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_i4" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_i4(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_i8" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_i8(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_i16" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_i16(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_i32" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_i32(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_i64" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_i64(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_ui2" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_ui2(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_ui4" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_ui4(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_ui8" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_ui8(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_ui16" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_ui16(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_ui32" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_ui32(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_ui64" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_ui64(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f4E2M1FN" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_f4E2M1FN(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f6E2M3FN" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_f6E2M3FN(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f6E3M2FN" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_f6E3M2FN(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f8E3M4" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_f8E3M4(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f8E4M3" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_f8E4M3(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f8E4M3FN" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_f8E4M3FN(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f8E5M2" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_f8E5M2(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f8E4M3FNUZ" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_f8E4M3FNUZ(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f8E4M3B11FNUZ" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_f8E4M3B11FNUZ(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f8E5M2FNUZ" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_f8E5M2FNUZ(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f8E8M0FNU" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_f8E8M0FNU(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_bf16" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_bf16(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f16" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_f16(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f32" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_f32(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_f64" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_f64(%arg0: tensor, %arg1: tensor) -> tensor { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1, !vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor, tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_complex_f32" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_complex_f32(%arg0: tensor>, %arg1: tensor>) -> tensor> { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1>, !vhlo.tensor_v1>) -> !vhlo.tensor_v1> + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor>, tensor>) -> tensor> + func.return %0 : tensor> +} + +// CHECK-LABEL: "type_complex_f64" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_complex_f64(%arg0: tensor>, %arg1: tensor>) -> tensor> { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1>, !vhlo.tensor_v1>) -> !vhlo.tensor_v1> + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor>, tensor>) -> tensor> + func.return %0 : tensor> +} + +// CHECK-LABEL: "type_tf32" +// CHECK: #vhlo.type_v1 +func.func @type_tf32() attributes {stablehlo.attr = tf32 } { + return +} + +// CHECK-LABEL: "type_none" +// CHECK: #vhlo.type_v1 +func.func @type_none() attributes {stablehlo.attr = none } { + return +} + +// CHECK-LABEL: "type_dynamism_ranked" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @type_dynamism_ranked(%arg0: tensor) -> tensor { + // CHECK: "vhlo.abs_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + %0 = "stablehlo.abs"(%arg0) : (tensor) -> tensor + func.return %0 : tensor +} + +// CHECK-LABEL: "type_per_tensor_quantization" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}, %[[ARG1:.*]]: {{.*}}) +func.func @type_per_tensor_quantization(%arg0: tensor>, %arg1: tensor>) -> tensor> { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG1]]) : (!vhlo.tensor_v1>, !vhlo.tensor_v1>) -> !vhlo.tensor_v1> + %0 = "stablehlo.add"(%arg0, %arg1) : (tensor>, tensor>) -> tensor> + func.return %0 : tensor> +} + +// CHECK-LABEL: "type_per_axis_quantization" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @type_per_axis_quantization(%arg0: tensor<2x!quant.uniform>) -> tensor<2x!quant.uniform> { + // CHECK: "vhlo.add_v1"(%[[ARG0]], %[[ARG0]]) : (!vhlo.tensor_v1<2x!vhlo.quant_per_axis_v1>, !vhlo.tensor_v1<2x!vhlo.quant_per_axis_v1>) -> !vhlo.tensor_v1<2x!vhlo.quant_per_axis_v1> + %0 = stablehlo.add %arg0, %arg0 : tensor<2x!quant.uniform> + func.return %0 : tensor<2x!quant.uniform> +} + +// CHECK: function_type = #vhlo.type_v1 !vhlo.token_v1>> +// CHECK-LABEL: "type_token_callee" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @type_token_callee(%arg0: !stablehlo.token) -> !stablehlo.token { + // CHECK: "vhlo.return_v1"(%[[ARG0]]) : (!vhlo.token_v1) -> () + return %arg0 : !stablehlo.token +} + +// CHECK: function_type = #vhlo.type_v1 !vhlo.token_v1>> +// CHECK-LABEL: "type_token_caller" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @type_token_caller(%arg0: !stablehlo.token) -> !stablehlo.token { + // CHECK: "vhlo.call_v1"(%[[ARG0]]) <{callee = #vhlo.string_v1<"type_token_callee">} + // CHECK-SAME: (!vhlo.token_v1) -> !vhlo.token_v1 + %0 = func.call @type_token_callee(%arg0) : (!stablehlo.token) -> !stablehlo.token + return %0 : !stablehlo.token +} + +// CHECK-LABEL: "type_tuple" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) +func.func @type_tuple(%arg0: tuple>) -> tuple { + %0 = "stablehlo.custom_call"(%arg0) { + call_target_name = "foo" + // CHECK: (!vhlo.tuple_v1>) -> !vhlo.tuple_v1 + } : (tuple>) -> tuple + return %0 : tuple +} + +// ============ DEPENDENCIES ============ + +func.func @composite_target(%arg0: tensor) -> tensor { + return %arg0: tensor +} diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.1_9_0.mlir.bc b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.1_9_0.mlir.bc new file mode 100644 index 00000000000..2f3f073e1c5 Binary files /dev/null and b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.1_9_0.mlir.bc differ diff --git a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.mlir b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.mlir index 37e378e47dc..8e916d995e2 100644 --- a/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.mlir +++ b/stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.mlir @@ -250,6 +250,36 @@ func.func @attr_fft_type_irfft(%arg0: tensor<9xcomplex>) -> tensor<16xf32> func.return %0 : tensor<16xf32> } +// CHECK-LABEL: "attr_result_accuracy_HIGHEST" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}} +func.func @attr_result_accuracy_HIGHEST(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { + %0 = "stablehlo.exponential"(%arg0) { + // CHECK: result_accuracy = #vhlo.result_accuracy_v1> + result_accuracy = #stablehlo.result_accuracy> + } : (tensor<8x16xf32>) -> tensor<8x16xf32> + func.return %0 : tensor<8x16xf32> +} + +// CHECK-LABEL: "attr_result_accuracy_TOLERANCE" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}} +func.func @attr_result_accuracy_TOLERANCE(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { + %0 = "stablehlo.exponential"(%arg0) { + // CHECK: result_accuracy = #vhlo.result_accuracy_v1> + result_accuracy = #stablehlo.result_accuracy> + } : (tensor<8x16xf32>) -> tensor<8x16xf32> + func.return %0 : tensor<8x16xf32> +} + +// CHECK-LABEL: "attr_result_accuracy_DEFAULT" +// CHECK-NEXT: (%[[ARG0:.*]]: {{.*}} +func.func @attr_result_accuracy_DEFAULT(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { + %0 = "stablehlo.exponential"(%arg0) { + // CHECK: result_accuracy = #vhlo.result_accuracy_v1> + result_accuracy = #stablehlo.result_accuracy> + } : (tensor<8x16xf32>) -> tensor<8x16xf32> + func.return %0 : tensor<8x16xf32> +} + // GatherDimensionNumbers aka #stablehlo.gather is covered below. // CHECK-LABEL: "attr_precision_config_default" @@ -1621,7 +1651,7 @@ func.func @op_exponential_minus_one(%arg0: tensor) -> tensor { // CHECK-LABEL: "op_exponential" // CHECK-NEXT: (%[[ARG0:.*]]: {{.*}}) func.func @op_exponential(%arg0: tensor) -> tensor { - // CHECK: "vhlo.exponential_v1"(%[[ARG0]]) : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 + // CHECK: "vhlo.exponential_v2"(%[[ARG0]]) <{result_accuracy = #vhlo.result_accuracy_v1>}> : (!vhlo.tensor_v1) -> !vhlo.tensor_v1 %0 = "stablehlo.exponential"(%arg0) : (tensor) -> tensor func.return %0 : tensor } diff --git a/stablehlo/tests/vhlo/vhlo_attributes_invalid.mlir b/stablehlo/tests/vhlo/vhlo_attributes_invalid.mlir new file mode 100644 index 00000000000..b73d1005568 --- /dev/null +++ b/stablehlo/tests/vhlo/vhlo_attributes_invalid.mlir @@ -0,0 +1,26 @@ +// RUN: stablehlo-opt --vhlo-to-version=target=1.9.0 -verify-diagnostics --split-input-file %s + +func.func @invalid_array_element() -> () attributes { + // expected-error @+1 {{expected array of VHLO attriutes}} + vhlo.attr = #vhlo.array_v1<[#stablehlo]> +} { + return +} + +// ----- + +func.func @invalid_dict_element_value() -> () attributes { + // expected-error @+1 {{expected VHLO attribute}} + vhlo.attr = #vhlo.dict_v1<{#vhlo.string_v1<"attr1"> = 3 : i32}> +} { + return +} + +// ----- + +func.func @invalid_result_accuracy() -> () attributes { + // expected-error @+1 {{expected VHLO result accuracy mode}} + vhlo.attr = #vhlo.result_accuracy_v1> +} { + return +} diff --git a/stablehlo/tests/vhlo/vhlo_to_version_downgrade.1_8_0.mlir b/stablehlo/tests/vhlo/vhlo_to_version_downgrade.1_8_0.mlir new file mode 100644 index 00000000000..09623ac8b32 --- /dev/null +++ b/stablehlo/tests/vhlo/vhlo_to_version_downgrade.1_8_0.mlir @@ -0,0 +1,24 @@ +// RUN: stablehlo-opt --stablehlo-legalize-to-vhlo --vhlo-to-version='target=1.8.0' %s | FileCheck %s + +// ExpOp was changed in v1.9.0 to have +// result_accuracy attribute. Ensure that serializing for 1.8.0 is valid and targets the +// v1.8.0 opset. +// +// This will catch issues in op `isLegal` checks: +// op.minVersion() <= target <= op.maxVersion() + +// CHECK-LABEL: vhlo.func_v1 @exp_op +func.func public @exp_op(%arg0: tensor) -> tensor { + // CHECK: vhlo.exponential_v1 + %0 = "stablehlo.exponential"(%arg0) : (tensor) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: vhlo.func_v1 @exp_op_default +func.func @exp_op_default(%arg0: tensor) -> tensor { + %0 = "stablehlo.exponential"(%arg0) { + // CHECK: vhlo.exponential_v1 + result_accuracy = #stablehlo.result_accuracy> + } : (tensor) -> tensor + func.return %0 : tensor +} diff --git a/stablehlo/tests/vhlo/vhlo_to_version_downgrade_invalid.1_8_0.mlir b/stablehlo/tests/vhlo/vhlo_to_version_downgrade_invalid.1_8_0.mlir new file mode 100644 index 00000000000..ca81b9dc1ff --- /dev/null +++ b/stablehlo/tests/vhlo/vhlo_to_version_downgrade_invalid.1_8_0.mlir @@ -0,0 +1,22 @@ +// RUN: stablehlo-opt --stablehlo-legalize-to-vhlo --vhlo-to-version='target=1.8.0' --verify-diagnostics --split-input-file %s + + +func.func @attr_result_accuracy_default(%arg0: tensor) -> tensor { + %0 = "stablehlo.exponential"(%arg0) { + // CHECK: vhlo.exponential_v1 + result_accuracy = #stablehlo.result_accuracy> + } : (tensor) -> tensor + func.return %0 : tensor +} + +// ----- + +// expected-error @-3 {{failed to convert VHLO to v1.8.0}} +func.func @attr_result_accuracy_highest(%arg0: tensor) -> tensor { + // expected-error @+1 {{failed to legalize operation 'vhlo.exponential_v2' that was explicitly marked illegal}} + %0 = "stablehlo.exponential"(%arg0) { + result_accuracy = #stablehlo.result_accuracy> + } : (tensor) -> tensor + func.return %0 : tensor +} + diff --git a/stablehlo/transforms/MapStablehloToVhlo.h b/stablehlo/transforms/MapStablehloToVhlo.h index 7ac0e2fabfd..97d9fcacd60 100644 --- a/stablehlo/transforms/MapStablehloToVhlo.h +++ b/stablehlo/transforms/MapStablehloToVhlo.h @@ -94,7 +94,7 @@ MAP_STABLEHLO_TO_VHLO(DynamicSliceOp, V1) MAP_STABLEHLO_TO_VHLO(DynamicUpdateSliceOp, V1) MAP_STABLEHLO_TO_VHLO(EinsumOp, V1) MAP_STABLEHLO_TO_VHLO(Expm1Op, V1) -MAP_STABLEHLO_TO_VHLO(ExpOp, V1) +MAP_STABLEHLO_TO_VHLO(ExpOp, V2) MAP_STABLEHLO_TO_VHLO(FftOp, V1) MAP_STABLEHLO_TO_VHLO(FloorOp, V1) MAP_STABLEHLO_TO_VHLO(GatherOp, V2) diff --git a/stablehlo/transforms/PassUtils.h b/stablehlo/transforms/PassUtils.h index ea7b9f9c73a..9bf6c62d763 100644 --- a/stablehlo/transforms/PassUtils.h +++ b/stablehlo/transforms/PassUtils.h @@ -10,8 +10,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_STABLEHLO_STABLEHLO_TRANSFORMS_PASS_UTILS_H_ -#define THIRD_PARTY_STABLEHLO_STABLEHLO_TRANSFORMS_PASS_UTILS_H_ +#ifndef STABLEHLO_TRANSFORMS_PASS_UTILS_H_ +#define STABLEHLO_TRANSFORMS_PASS_UTILS_H_ #include "llvm/Support/ErrorHandling.h" #include "mlir/Dialect/Complex/IR/Complex.h" @@ -72,4 +72,4 @@ bool isAnyQuantizedTypes(TypeRange types); } // namespace stablehlo } // namespace mlir -#endif // THIRD_PARTY_STABLEHLO_STABLEHLO_TRANSFORMS_PASS_UTILS_H_ +#endif // STABLEHLO_TRANSFORMS_PASS_UTILS_H_ diff --git a/stablehlo/transforms/StablehloComplexMathExpanderPatterns.td b/stablehlo/transforms/StablehloComplexMathExpanderPatterns.td index 0ab485d5f59..6596773dac4 100644 --- a/stablehlo/transforms/StablehloComplexMathExpanderPatterns.td +++ b/stablehlo/transforms/StablehloComplexMathExpanderPatterns.td @@ -683,12 +683,15 @@ def LogOp_ComplexElementType_ComplexMathExpander: Pat<(StableHLO_LogOp ComplexEl // Notice that for `y != 0`, neither `cos(y)` nor `sin(y)` is never // zero on the set of floating point numbers. // -def ExpOp_ComplexElementType_ComplexMathExpander: Pat<(StableHLO_ExpOp ComplexElementType:$z), +def ConstDefaultResultAccuracyAttr : + ConstantAttr; + +def ExpOp_ComplexElementType_ComplexMathExpander: Pat<(StableHLO_ExpOp ComplexElementType:$z, ConstDefaultResultAccuracyAttr), (StableHLO_ComplexOp (StableHLO_SelectOp (StableHLO_CompareOp:$eq_e_constant_posinf (StableHLO_ExpOp:$e - (StableHLO_RealOp:$x $z)), + (StableHLO_RealOp:$x $z), ConstDefaultResultAccuracyAttr), (StableHLO_ConstantLikePosInfValue $x), StableHLO_ComparisonDirectionValue<"EQ">, (STABLEHLO_DEFAULT_COMPARISON_TYPE)), @@ -697,7 +700,7 @@ def ExpOp_ComplexElementType_ComplexMathExpander: Pat<(StableHLO_ExpOp ComplexEl (StableHLO_ExpOp:$e2 (StableHLO_MulOp $x, - (StableHLO_ConstantLike<"0.5"> $x))), + (StableHLO_ConstantLike<"0.5"> $x)), ConstDefaultResultAccuracyAttr), (StableHLO_CosineOp:$cs (StableHLO_ImagOp:$y $z))), $e2), diff --git a/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stablehlo/transforms/StablehloLegalizeToVhlo.cpp index 53c09c409ae..0a1bbd26e59 100644 --- a/stablehlo/transforms/StablehloLegalizeToVhlo.cpp +++ b/stablehlo/transforms/StablehloLegalizeToVhlo.cpp @@ -19,6 +19,7 @@ limitations under the License. #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" @@ -130,6 +131,16 @@ Attribute convertGeneric(Attribute stablehloAttr, if (auto attr = dyn_cast(stablehloAttr)) { RETURN_CONVERTED_ENUM_ATTR(Transpose, V1); } + if (auto attr = dyn_cast(stablehloAttr)) { + RETURN_CONVERTED_ENUM_ATTR(ResultAccuracyMode, V1); + } + if (auto attr = dyn_cast(stablehloAttr)) { + auto modeAttr = convertGeneric(attr.getMode(), typeConverter); + if (!modeAttr) return {}; + return vhlo::ResultAccuracyV1Attr::get(attr.getContext(), attr.getAtol(), + attr.getRtol(), attr.getUlps(), + modeAttr); + } if (stablehloAttr.getDialect().getNamespace() == stablehlo::StablehloDialect::getDialectNamespace()) { // All StableHLO attributes must have counterparts in VHLO. @@ -815,6 +826,19 @@ LogicalResult addDefaults(const OpConversionPattern& pattern, } } } + if constexpr (std::is_same::value) { + if (!stablehloOp.getResultAccuracyAttr()) + addDefaultAttr("result_accuracy", + stablehlo::ResultAccuracyAttr::get( + pattern.getContext(), + /*atol=*/APFloat(0.0), + /*rtol=*/APFloat(0.0), + /*ulps=*/0, + /*mode=*/ + stablehlo::ResultAccuracyModeAttr::get( + pattern.getContext(), + stablehlo::ResultAccuracyMode::DEFAULT))); + } if constexpr (std::is_same::value) { if (!stablehloOp.getKnownExpandingDimensionsAttr()) diff --git a/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stablehlo/transforms/VhloLegalizeToStablehlo.cpp index 66c1b1672e7..6afa449e5a1 100644 --- a/stablehlo/transforms/VhloLegalizeToStablehlo.cpp +++ b/stablehlo/transforms/VhloLegalizeToStablehlo.cpp @@ -23,6 +23,7 @@ limitations under the License. #include "llvm/Support/AllocatorBase.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinAttributes.h" @@ -169,6 +170,17 @@ Attribute convertGeneric(Attribute vhloAttr, if (!builtinType) return {}; return TypeAttr::get(builtinType); } + if (auto attr = dyn_cast(vhloAttr)) { + RETURN_CONVERTED_ENUM_ATTR(ResultAccuracyMode, V1); + } + if (auto attr = dyn_cast(vhloAttr)) { + auto modeAttr = dyn_cast_or_null( + convertGeneric(attr.getMode(), typeConverter)); + if (!modeAttr) return {}; + return stablehlo::ResultAccuracyAttr::get(attr.getContext(), attr.getAtol(), + attr.getRtol(), attr.getUlps(), + modeAttr); + } // All VHLO Attributes must be converted by now. if (vhloAttr.getDialect().getNamespace() == @@ -737,6 +749,13 @@ bool isSplatArray(Attribute vhloAttr, Attribute splatValue) { }); } +bool isDefaultResultAccuracyAttribute(Attribute vhloAttr) { + auto attr = dyn_cast_or_null(vhloAttr); + return attr.getAtol().isZero() && attr.getRtol().isZero() && + attr.getUlps() == 0 && + dyn_cast(attr.getMode()).getValue() == + vhlo::ResultAccuracyModeV1::DEFAULT; +} template bool isSplatTensor(const ConversionPattern& pattern, Attribute vhloAttr, T splatValue) { @@ -898,6 +917,11 @@ LogicalResult removeDefaults(const OpConversionPattern& pattern, if (isBoolean(vhloOp.getIsStableAttr(), false)) eraseAttrs(vhloAttrs, "is_stable"); } + if constexpr (std::is_same::value) { + if (isDefaultResultAccuracyAttribute(vhloOp.getResultAccuracyAttr())) { + eraseAttrs(vhloAttrs, "result_accuracy"); + } + } return success(); } diff --git a/stablehlo/transforms/VhloToVersion.cpp b/stablehlo/transforms/VhloToVersion.cpp index e1d2225587e..f8b77b94d38 100644 --- a/stablehlo/transforms/VhloToVersion.cpp +++ b/stablehlo/transforms/VhloToVersion.cpp @@ -139,6 +139,8 @@ LogicalResult isLegalAttribute(const Attribute& attr, Version targetVersion) { return isLegalType(tensorAttr.getType(), targetVersion); if (auto typeAttr = dyn_cast(attr)) return isLegalType(typeAttr.getValue(), targetVersion); + if (auto resultAccuracyAttr = dyn_cast(attr)) + return isLegalAttribute(resultAccuracyAttr.getMode(), targetVersion); // Is VHLO and valid version, success. return success(); @@ -324,6 +326,22 @@ TensorV1Attr getDefaultConvPadding(OpBuilder& builder, Value lhs) { denseElements.getRawData()); } +bool isDefaultResultAccuracy(Attribute attr) { + auto resultAccuracy = dyn_cast(attr); + auto default_mode = ResultAccuracyModeV1Attr::get( + attr.getContext(), ResultAccuracyModeV1::DEFAULT); + return resultAccuracy.getAtol().isZero() && + resultAccuracy.getRtol().isZero() && resultAccuracy.getUlps() == 0 && + resultAccuracy.getMode() == default_mode; +} + +ResultAccuracyV1Attr getDefaultResultAccuracy(OpBuilder& builder) { + return ResultAccuracyV1Attr::get( + builder.getContext(), APFloat(0.0), APFloat(0.0), 0, + ResultAccuracyModeV1Attr::get(builder.getContext(), + ResultAccuracyModeV1::DEFAULT)); +} + // DRR has limited support for ops with regions struct ScatterOpV2ToV1 : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -393,6 +411,40 @@ struct AllReduceOpV2ToV1 : public OpRewritePattern { } }; +struct ExpOpV1ToV2 : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ExpOpV1 op, + PatternRewriter& rewriter) const override { + ResultAccuracyV1Attr defaultResultAccuracy = ResultAccuracyV1Attr::get( + rewriter.getContext(), APFloat(0.0), APFloat(0.0), 0, + ResultAccuracyModeV1Attr::get(rewriter.getContext(), + ResultAccuracyModeV1::DEFAULT)); + rewriter.replaceOpWithNewOp( + op, op->getResultTypes(), op.getOperand(), defaultResultAccuracy); + return success(); + } +}; + +struct ExpOpV2ToV1 : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ExpOpV2 op, + PatternRewriter& rewriter) const override { + auto defaultResultAccuracy = ResultAccuracyV1Attr::get( + rewriter.getContext(), APFloat(0.0), APFloat(0.0), 0, + ResultAccuracyModeV1Attr::get(rewriter.getContext(), + ResultAccuracyModeV1::DEFAULT)); + if (op.getResultAccuracy() != defaultResultAccuracy) { + return rewriter.notifyMatchFailure(op, + "non-default result accuracy attr"); + } + rewriter.replaceOpWithNewOp(op, op->getResultTypes(), + op.getOperand()); + return success(); + } +}; + #include "stablehlo/transforms/VhloToVersionPatterns.h.inc" } // namespace @@ -405,6 +457,7 @@ void populateVhloToVersionPatterns(RewritePatternSet* patterns, vhlo::populateWithGenerated(*patterns); patterns->add(context); patterns->add(context); + patterns->add(context); } } // namespace stablehlo diff --git a/stablehlo/transforms/VhloToVersionPatterns.td b/stablehlo/transforms/VhloToVersionPatterns.td index 68c9d7375fb..78f98c3aa4e 100644 --- a/stablehlo/transforms/VhloToVersionPatterns.td +++ b/stablehlo/transforms/VhloToVersionPatterns.td @@ -15,6 +15,9 @@ limitations under the License. include "mlir/IR/OpBase.td" include "stablehlo/dialect/VhloOps.td" +include "mlir/IR/CommonAttrConstraints.td" +include "stablehlo/dialect/VhloEnums.td" +include "stablehlo/dialect/VhloAttrs.td" def VHLO_GetEmptyDims : NativeCodeCall<"getEmptyI64Tensor($_builder)">; @@ -32,6 +35,11 @@ def VHLO_GetFirstOperand : NativeCodeCall<"$0.front()">; def VHLO_WrapInVector : NativeCodeCall<"{$0}">; +def VHLO_GetDefaultResultAccuracyAttr : NativeCodeCall<"getDefaultResultAccuracy($_builder)">; + + +def VHLO_DefaultResultAccuracy : AttrConstraint, "Default result accuracy">; + def DynamicConvUpgradeV1ToV2: Pat<(VHLO_DynamicConvOpV1 $lhs, $rhs, $d_padding, $window_strides, $padding, $lhs_dilation, $rhs_dilation, $window_reversal, $input_batch_dimension, $input_feature_dimension, $input_spatial_dimensions, $kernel_input_feature_dimension, $kernel_output_feature_dimension, $kernel_spatial_dimensions, $output_batch_dimension, $output_feature_dimension, $output_spatial_dimensions, $feature_group_count, $batch_group_count, $precision_config), (VHLO_DynamicConvOpV2 $lhs, $rhs, $d_padding, $window_strides, $lhs_dilation, $rhs_dilation, $window_reversal, $input_batch_dimension, $input_feature_dimension, $input_spatial_dimensions, $kernel_input_feature_dimension, $kernel_output_feature_dimension, $kernel_spatial_dimensions, $output_batch_dimension, $output_feature_dimension, $output_spatial_dimensions, $feature_group_count, $batch_group_count, $precision_config)>; @@ -83,3 +91,11 @@ def DotGeneralOpUpradeV1ToV2 : Pat<(VHLO_DotGeneralOpV1 $lhs, $rhs, $lhs_batching_dimensions, $rhs_batching_dimensions, $lhs_contracting_dimensions, $rhs_contracting_dimensions, $precision_config), (VHLO_DotGeneralOpV2 $lhs, $rhs, $lhs_batching_dimensions, $rhs_batching_dimensions, $lhs_contracting_dimensions, $rhs_contracting_dimensions, $precision_config, (VHLO_GetNoneType), (VHLO_GetNoneType), (VHLO_GetNoneType), (VHLO_GetNoneType), (VHLO_GetNoneType), (VHLO_GetNoneType), (VHLO_GetNoneType))>; + +def ExpOpDowngradeV2ToV1 : + Pat<(VHLO_ExpOpV2 $operand, VHLO_DefaultResultAccuracy:$result_accuracy), + (VHLO_ExpOpV1 $operand)>; + +def ExpOpUpgradeV1ToV2 : + Pat<(VHLO_ExpOpV1 $operand), + (VHLO_ExpOpV2 $operand, (VHLO_GetDefaultResultAccuracyAttr))>;