-
Notifications
You must be signed in to change notification settings - Fork 13.4k
[mlir][gpu] Add pass for emulating unsupported types. #138087
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-mlir-arith @llvm/pr-subscribers-mlir-gpu Author: Md Abdullah Shahneous Bari (mshahneo) ChangesThis pass imitates (bitcast/reinterpret_cast) unsupported types with supported types of same bitwidth. The imitation is done by bitcasting the unspported types to the supported types of same bitwidth. Therefore, the source type and destination type must have the same bitwidth. The imitation is done by using the following operations: arith.bitcast. The imitation is often needed when the GPU target (dialect/IR) does not support a certain type but the underlying architecture does. Take SPIR-V for example, it does not support bf16, but an underlying architecture (e.g., intel pvc gpu) that uses SPIR-V for code-generation does. Therefore, bf16 is neither a valid data type to pass to gpu kernel, nor to be used inside the kernel. To use bf16 data type in a SPIR-V kernel (as a kernel parameter or inside the kernel), bf16 have to be bitcasted (similar to C++ reinterpret_cast) to a supported type (e.g., i16 for Intel GPUs). The SPIR-V kernel can then use the imitated type (i16) in the computation. However, i16 is not the same as bf16 (integer vs float), so the computation can not readily use the imitated type (i16). Therefore, this transformation pass is intended to be used in conjuction with other transformation passes such as Finally, usually, there are instructions available in the target (dialect/IR) that can take advantage of these generated patterns (bf16->i16->f32, f32->bf16->i16), and convert them to the supported types. Patch is 48.64 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/138087.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
index 6cd6f03253aea..0b7339a94b274 100644
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
@@ -16,6 +16,8 @@
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Utils/GPUUtils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include <optional>
@@ -87,6 +89,24 @@ void populateGpuLowerClusteredSubgroupReduceToDPPPatterns(
RewritePatternSet &patterns, unsigned subgroupSize, amdgpu::Chipset chipset,
PatternBenefit benefit = 1);
+/// Set up a type converter to convert unsupported source types to
+/// supported target types.
+void populateImitateUnsupportedTypesTypeConverter(TypeConverter &typeConverter,
+ ArrayRef<Type> sourceTypes,
+ ArrayRef<Type> targetTypes);
+
+/// Collect a set of pattern needed to imitate unsupported source types
+/// using supported target types.
+void populateImitateUnsupportedTypesConversionPatterns(
+ RewritePatternSet &patterns, TypeConverter &typeConverter,
+ ArrayRef<Type> sourceTypes, ArrayRef<Type> targetTypes,
+ DenseMap<StringAttr, FunctionType> &convertedFuncTypes);
+
+/// Set up a dialect conversion to reject operations on unsupported
+/// float types.
+void configureImitateUnsupportedTypesLegality(ConversionTarget &target,
+ TypeConverter &typeConverter);
+
/// Collect all patterns to rewrite ops within the GPU dialect.
inline void populateGpuRewritePatterns(RewritePatternSet &patterns) {
populateGpuAllReducePatterns(patterns);
diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td
index 3766eb16e9429..feb1b2820abd6 100644
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.td
@@ -258,4 +258,57 @@ def GpuSPIRVAttachTarget: Pass<"spirv-attach-target", ""> {
];
}
+def GpuImitateUnsupportedTypes : Pass<"imitate-unsupported-types", "::mlir::ModuleOp"> {
+ let summary = "Imitate unsupported types with supported types of same bitwidth.";
+ let description = [{
+ This pass imitates (bitcast/reinterpret_cast) unsupported types
+ with supported types of same bitwidth. The imitation is done
+ by bitcasting the unspported types to the supported types of same bitwidth.
+ Therefore, the source type and destination type must have the same bitwidth.
+ The imitation is done by using the following operations: arith.bitcast.
+
+ The imitation is often needed when the GPU target (dialect/IR) does not
+ support a certain type but the underlying architecture does. Take SPIR-V for
+ example, it does not support bf16, but an underlying architecture (e.g.,
+ intel pvc gpu) that uses SPIR-V for code-generation does.
+ Therefore, bf16 is neither a valid data type to pass to gpu kernel, nor to
+ be used inside the kernel. To use bf16 data type in a SPIR-V kernel (as a
+ kernel parameter or inside the kernel), bf16 have to be bitcasted (similar
+ to C++ reinterpret_cast) to a supported type (e.g., i16 for Intel GPUs). The
+ SPIR-V kernel can then use the imitated type (i16) in the computation.
+ However, i16 is not the same as bf16 (integer vs float), so the computation
+ can not readily use the imitated type (i16).
+
+ Therefore, this transformation pass is intended to be used in conjuction
+ with other transformation passes such as `EmulateUnsupportedFloats` and
+ `ExtendUnsupportedTypes` that extend the bitwidth of bf16 to f32 and
+ vice-versa.
+
+ Finally, usually, there are instructions available in the target
+ (dialect/IR) that can take advantage of these generated patterns
+ (bf16->i16->f32, f32->bf16->i16), and convert them to the supported
+ types.
+ For example, Intel provides SPIR-V extension ops that can
+ take imitated bf16 (i16) and convert them to f32 and vice-versa.
+ https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_bfloat16_conversion.asciidoc
+ https://mlir.llvm.org/docs/Dialects/SPIR-V/#spirvintelconvertbf16tof-spirvintelconvertbf16tofop
+ https://mlir.llvm.org/docs/Dialects/SPIR-V/#spirvintelconvertftobf16-spirvintelconvertftobf16op
+
+ }];
+
+ let options = [
+ ListOption<"sourceTypeStrs", "source-types", "std::string",
+ "MLIR types without type support on a given target">,
+ ListOption<"targetTypeStrs", "target-types", "std::string",
+ "MLIR types to convert the unsupported source types to">,
+ ];
+
+ let dependentDialects = [
+ "::mlir::gpu::GPUDialect",
+ "::mlir::arith::ArithDialect",
+ "::mlir::memref::MemRefDialect"
+ ];
+}
+
+
#endif // MLIR_DIALECT_GPU_PASSES
diff --git a/mlir/lib/Dialect/GPU/Transforms/ImitateUnsupportedTypes.cpp b/mlir/lib/Dialect/GPU/Transforms/ImitateUnsupportedTypes.cpp
new file mode 100644
index 0000000000000..c83e6bec568e0
--- /dev/null
+++ b/mlir/lib/Dialect/GPU/Transforms/ImitateUnsupportedTypes.cpp
@@ -0,0 +1,916 @@
+//===- ImitateUnsupportedTypes.cpp - Unsupported Type Imitation ----*- C++
+//-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//===----------------------------------------------------------------------===//
+//
+/// \file
+/// This pass imitates (bitcast/reinterpret_cast) unsupported types
+/// with supported types of same bitwidth. The imitation is done
+/// by bitcasting the unspported types to the supported types of same bitwidth.
+/// Therefore, the source type and destination type must have the same bitwidth.
+/// The imitation is done by using the following operations: arith.bitcast.
+///
+/// The imitation is often needed when the GPU target (dialect/IR) does not
+/// support a certain type but the underlying architecture does. Take SPIR-V for
+/// example, it does not support bf16, but an underlying architecture (e.g.,
+/// intel pvc gpu) that uses SPIR-V for code-generation does.
+/// Therefore, bf16 is neither a valid data type to pass to gpu kernel, nor to
+/// be used inside the kernel. To use bf16 data type in a SPIR-V kernel (as a
+/// kernel parameter or inside the kernel), bf16 have to be bitcasted (similar
+/// to C++ reinterpret_cast) to a supported type (e.g., i16 for Intel GPUs). The
+/// SPIR-V kernel can then use the imitated type (i16) in the computation.
+/// However, i16 is not the same as bf16 (integer vs float), so the computation
+/// can not readily use the imitated type (i16).
+///
+/// Therefore, this transformation pass is intended to be used in conjuction
+/// with other transformation passes such as `EmulateUnsupportedFloats` and
+/// `ExtendUnsupportedTypes` that extend the bitwidth of bf16 to f32 and
+/// vice-versa.
+///
+/// Finally, usually, there are instructions available in the target
+/// (dialect/IR) that can take advantage of these generated patterns
+/// (bf16->i16->f32, f32->bf16->i16), and convert them to the supported
+/// types.
+/// For example, Intel provides SPIR-V extension ops that can
+/// take imitated bf16 (i16) and convert them to f32 and vice-versa.
+/// https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_bfloat16_conversion.asciidoc
+/// https://mlir.llvm.org/docs/Dialects/SPIR-V/#spirvintelconvertbf16tof-spirvintelconvertbf16tofop
+/// https://mlir.llvm.org/docs/Dialects/SPIR-V/#spirvintelconvertftobf16-spirvintelconvertftobf16op
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/GPU/Transforms/Passes.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/APFloat.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/ErrorHandling.h"
+
+#include <optional>
+#include <type_traits>
+#include <variant>
+
+using namespace mlir;
+using namespace mlir::gpu;
+
+namespace mlir {
+#define GEN_PASS_DEF_GPUIMITATEUNSUPPORTEDTYPES
+#include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
+} // namespace mlir
+
+//===----------------------------------------------------------------------===//
+// Utility functions
+//===----------------------------------------------------------------------===//
+
+APFloat bitcastAPIntToAPFloat(const APInt &intValue,
+ const llvm::fltSemantics &semantics) {
+ // Get the bit width of the APInt.
+ unsigned intBitWidth = intValue.getBitWidth();
+ // Get the total bit size required for the APFloat based on the semantics.
+ unsigned floatBitWidth = APFloat::getSizeInBits(semantics);
+ // Ensure the bit widths match for a direct bitcast.
+ assert(intBitWidth == floatBitWidth &&
+ "Bitwidth of APInt and APFloat must match for bitcast");
+
+ // Get the raw bit representation of the APInt as a byte vector.
+ auto intWords = intValue.getRawData();
+ // Create an APFloat with the specified semantics and the raw integer bits.
+ APFloat floatValue(semantics, APInt(intBitWidth, *intWords));
+ return floatValue;
+}
+
+// Get FloatAttr from IntegerAttr.
+FloatAttr getFloatAttrFromIntegerAttr(IntegerAttr intAttr, Type dstType,
+ ConversionPatternRewriter &rewriter) {
+ APInt intVal = intAttr.getValue();
+ auto floatVal = bitcastAPIntToAPFloat(
+ intVal, cast<FloatType>(dstType).getFloatSemantics());
+ return rewriter.getFloatAttr(dstType, floatVal);
+}
+// Get IntegerAttr from FloatAttr.
+IntegerAttr getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType,
+ ConversionPatternRewriter &rewriter) {
+ APFloat floatVal = floatAttr.getValue();
+ APInt intVal = floatVal.bitcastToAPInt();
+ return rewriter.getIntegerAttr(dstType, intVal);
+}
+
+struct RawAllocator {
+ RawAllocator(OpBuilder &builder, Location loc) : builder(builder), loc(loc) {}
+
+ std::variant<Value, int64_t> computeTotalBytes(MemRefType srcType,
+ Value srcMemref) {
+ // Element size in bytes.
+ int64_t elemBitWidth = srcType.getElementTypeBitWidth();
+ int64_t elemByteWidth = (elemBitWidth + 7) / 8;
+
+ if (srcType.hasStaticShape()) {
+ // Static shape: compute total bytes statically.
+ int64_t numElements = 1;
+ for (int64_t dim : srcType.getShape()) {
+ numElements *= dim;
+ }
+ return numElements * elemByteWidth;
+ }
+
+ auto sizes = getSizes(srcType, srcMemref);
+ // Compute number of elements dynamically.
+ Value numElements = sizes.front();
+ for (auto size : llvm::drop_begin(sizes))
+ numElements = builder.create<arith::MulIOp>(loc, numElements, size);
+ Value elemSize = builder.create<arith::ConstantIndexOp>(loc, elemByteWidth);
+
+ return builder.create<arith::MulIOp>(loc, numElements, elemSize);
+ }
+
+ SmallVector<Value> getSizes(MemRefType type, Value memref) {
+ SmallVector<Value> sizes;
+ for (unsigned i = 0; i < type.getRank(); ++i) {
+ if (type.isDynamicDim(i)) {
+ sizes.push_back(builder.create<memref::DimOp>(loc, memref, i));
+ } else {
+ sizes.push_back(
+ builder.create<arith::ConstantIndexOp>(loc, type.getShape()[i]));
+ }
+ }
+ return sizes;
+ }
+
+ SmallVector<Value> getDynamicSizes(MemRefType type, Value memref) {
+ SmallVector<Value> sizes;
+ for (unsigned i = 0; i < type.getRank(); ++i) {
+ if (type.isDynamicDim(i)) {
+ sizes.push_back(builder.create<memref::DimOp>(loc, memref, i));
+ }
+ }
+ return sizes;
+ }
+
+ SmallVector<Value> getIdentityStrides(MemRefType type) {
+ SmallVector<Value> strides;
+ int64_t runningStride = 1;
+ for (int64_t dim : llvm::reverse(type.getShape())) {
+ strides.push_back(
+ builder.create<arith::ConstantIndexOp>(loc, runningStride));
+ if (dim != ShapedType::kDynamic)
+ runningStride *= dim;
+ else
+ runningStride = -1; // not handling dynamic strides.
+ }
+ std::reverse(strides.begin(), strides.end());
+ return strides;
+ }
+
+private:
+ OpBuilder &builder;
+ Location loc;
+};
+
+// Replace uses according to predicates automatically.
+template <typename OpTy>
+void replaceUsesWithPredicate(
+ OpTy originalValue,
+ ArrayRef<std::pair<std::function<bool(OpOperand &)>, Value>> replacements,
+ ConversionPatternRewriter &rewriter) {
+
+ for (OpOperand &use : llvm::make_early_inc_range(originalValue->getUses())) {
+ for (const auto &[predicate, newValue] : replacements) {
+ if (predicate(use)) {
+ use.set(newValue);
+ break;
+ }
+ }
+ }
+}
+
+//===----------------------------------------------------------------------===//
+// Convertion patterns
+//===----------------------------------------------------------------------===//
+namespace {
+
+//===----------------------------------------------------------------------===//
+// FunctionOp conversion pattern
+//===----------------------------------------------------------------------===//
+template <typename FuncLikeOp>
+struct ConvertFuncOp final : public OpConversionPattern<FuncLikeOp> {
+ ConvertFuncOp(MLIRContext *context, TypeConverter &typeConverter,
+ ArrayRef<Type> sourceTypes, ArrayRef<Type> targetTypes,
+ DenseMap<StringAttr, FunctionType> &convertedFuncTypes)
+ : OpConversionPattern<FuncLikeOp>(context),
+ typeConverter(typeConverter), // Store the reference
+ sourceTypes(sourceTypes), targetTypes(targetTypes),
+ convertedFuncTypes(convertedFuncTypes) {}
+ using OpConversionPattern<FuncLikeOp>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(FuncLikeOp op, typename FuncLikeOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Only handle functions a gpu.module
+ if (!op->template getParentOfType<gpu::GPUModuleOp>())
+ return failure();
+ FunctionType oldFuncType = op.getFunctionType();
+
+ // Convert function signature
+ TypeConverter::SignatureConversion signatureConverter(
+ oldFuncType.getNumInputs());
+ for (const auto &argType :
+ llvm::enumerate(op.getFunctionType().getInputs())) {
+ auto convertedType = typeConverter.convertType(argType.value());
+ if (!convertedType)
+ return failure();
+ signatureConverter.addInputs(argType.index(), convertedType);
+ }
+ SmallVector<Type, 4> newResultTypes;
+ for (const auto &resultType : llvm::enumerate(oldFuncType.getResults())) {
+ auto convertedType = typeConverter.convertType(resultType.value());
+ if (!convertedType)
+ return failure();
+ newResultTypes.push_back(convertedType);
+ }
+
+ // Convert function signature
+ FunctionType newFuncType = rewriter.getFunctionType(
+ signatureConverter.getConvertedTypes(), newResultTypes);
+
+ if (!newFuncType)
+ return rewriter.notifyMatchFailure(op, "could not convert function "
+ "type");
+
+ // Create new GPU function with converted type
+ auto newFuncOp =
+ rewriter.create<FuncLikeOp>(op.getLoc(), op.getName(), newFuncType);
+
+ newFuncOp.setVisibility(op.getVisibility());
+ // Copy attributes
+ for (auto attr : op->getAttrs()) {
+ // Skip the function_type attribute since it is already set by
+ // the newFuncType and we don't want to overwrite it.
+ if (attr.getName() != op.getFunctionTypeAttrName() &&
+ attr.getName() != SymbolTable::getSymbolAttrName())
+ newFuncOp->setAttr(attr.getName(), attr.getValue());
+ }
+
+ newFuncOp.getRegion().getBlocks().clear();
+ // Inline region approach
+ rewriter.inlineRegionBefore(op.getBody(), newFuncOp.getBody(),
+ newFuncOp.end());
+ // Convert block argument types using the type converter
+ if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter,
+ &signatureConverter))) {
+ return rewriter.notifyMatchFailure(op, "could not convert region "
+ "types");
+ }
+
+ if (!op.use_empty()) {
+ op.emitError("Cannot erase func: still has uses");
+ }
+ for (Operation *user : op->getUsers()) {
+ user->emitRemark() << "User of function " << op.getName();
+ }
+ rewriter.eraseOp(op);
+ // Add the converted function type to the map
+ newFuncOp.getNameAttr().getValue();
+ convertedFuncTypes[newFuncOp.getNameAttr()] = newFuncType;
+ return success();
+ }
+
+private:
+ TypeConverter &typeConverter; // Store a reference
+ ArrayRef<Type> sourceTypes;
+ ArrayRef<Type> targetTypes;
+ DenseMap<StringAttr, FunctionType> &convertedFuncTypes;
+};
+
+//===----------------------------------------------------------------------===//
+// CallOp conversion pattern
+//===----------------------------------------------------------------------===//
+struct ConvertCallOp : OpConversionPattern<func::CallOp> {
+ ConvertCallOp(MLIRContext *context, TypeConverter &typeConverter,
+ const DenseMap<StringAttr, FunctionType> &convertedFuncTypes)
+ : OpConversionPattern(context), convertedFuncTypes(convertedFuncTypes) {}
+
+ LogicalResult
+ matchAndRewrite(func::CallOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto callee = op.getCalleeAttr();
+
+ auto it = convertedFuncTypes.find(
+ StringAttr::get(callee.getContext(), callee.getValue()));
+ if (it == convertedFuncTypes.end())
+ return rewriter.notifyMatchFailure(
+ op, "Callee signature not converted. Perhaps the callee is not in "
+ "the same gpu module as the caller.");
+
+ auto newResultTypes = it->second.getResults();
+ rewriter.replaceOpWithNewOp<func::CallOp>(
+ op, callee.getValue(), newResultTypes, adaptor.getOperands());
+
+ return success();
+ }
+
+private:
+ const DenseMap<StringAttr, FunctionType> &convertedFuncTypes;
+};
+
+//===----------------------------------------------------------------------===//
+// GPULaunchFuncOp conversion pattern
+//===----------------------------------------------------------------------===//
+struct ConvertGPULaunchFuncOp : OpConversionPattern<gpu::LaunchFuncOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(gpu::LaunchFuncOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ std::optional<KernelDim3> clusterSizeOpernads =
+ op.hasClusterSize()
+ ? std::optional<gpu::KernelDim3>(op.getClusterSizeOperandValues())
+ : std::nullopt;
+
+ // Create the new launch_func.
+ auto newOp = rewriter.create<gpu::LaunchFuncOp>(
+ op.getLoc(), adaptor.getKernel(), op.getGridSizeOperandValues(),
+ op.getBlockSizeOperandValues(), op.getDynamicSharedMemorySize(),
+ adaptor.getKernelOperands(), op.getAsyncObject(), clusterSizeOpernads);
+
+ // Copy block size and grid size attributes
+ newOp->setAttrs(op->getAttrs());
+ rewriter.replaceOp(op, newOp.getResults());
+ return success();
+ }
+};
+
+//===----------------------------------------------------------------------===//
+// AllocOp conversion pattern
+//===----------------------------------------------------------------------===//
+template <typename Allo...
[truncated]
|
This pass imitates (bitcast/reinterpret_cast) unsupported types with supported types of same bitwidth. The imitation is done by bitcasting the unspported types to the supported types of same bitwidth. Therefore, the source type and destination type must have the same bitwidth. The imitation is done by using the following operations: arith.bitcast. The imitation is often needed when the GPU target (dialect/IR) does not support a certain type but the underlying architecture does. Take SPIR-V for example, it does not support bf16, but an underlying architecture (e.g., intel pvc gpu) that uses SPIR-V for code-generation does. Therefore, bf16 is neither a valid data type to pass to gpu kernel, nor to be used inside the kernel. To use bf16 data type in a SPIR-V kernel (as a kernel parameter or inside the kernel), bf16 have to be bitcasted (similar to C++ reinterpret_cast) to a supported type (e.g., i16 for Intel GPUs). The SPIR-V kernel can then use the imitated type (i16) in the computation. However, i16 is not the same as bf16 (integer vs float), so the computation can not readily use the imitated type (i16). Therefore, this transformation pass is intended to be used in conjuction with other transformation passes such as `EmulateUnsupportedFloats` and `ExtendUnsupportedTypes` that extend the bitwidth of bf16 to f32 and vice-versa. Finally, usually, there are instructions available in the target (dialect/IR) that can take advantage of these generated patterns (bf16->i16->f32, f32->bf16->i16), and convert them to the supported types. For example, Intel provides SPIR-V extension ops that can take imitated bf16 (i16) and convert them to f32 and vice-versa. https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_bfloat16_conversion.asciidoc https://mlir.llvm.org/docs/Dialects/SPIR-V/#spirvintelconvertbf16tof-spirvintelconvertbf16tofop https://mlir.llvm.org/docs/Dialects/SPIR-V/#spirvintelconvertftobf16-spirvintelconvertftobf16op
a31fd58
to
3966b5d
Compare
I believe the usual term is "emulate" (instead of "imitate"), unless I missed a nuance you're trying to make, can you update the description (and code) to reflect this? |
https://mlir.llvm.org/docs/Dialects/SPIR-V/#spirvintelconvertbf16tof-spirvintelconvertbf16tofop | ||
https://mlir.llvm.org/docs/Dialects/SPIR-V/#spirvintelconvertftobf16-spirvintelconvertftobf16op | ||
|
||
}]; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The implementation only touches specific ops, but from the pass description it's absolutely not clear to me what is the scope here, especially considering we have also EmulateUnsupportedFloats
|
||
// Set up conversion target and configure the legality of the conversion | ||
ConversionTarget target(*ctx); | ||
configureImitateUnsupportedTypesLegality(target, typeConverter); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I suspect all the initialization above could be done in the initialize()
phase of the pass instead of the run phase.
getElementTypeOrSelf(op.getResult(0).getType()) == targetType) { | ||
op->emitError("unresolved unrealized_conversion_cast left in IR " | ||
"after conversion"); | ||
hasUnresolvedCast = true; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can call signalPassFailure();
here and get rid of hasUnresolvedCast
entirely.
|
||
SmallVector<Type> sourceTypes; | ||
SmallVector<Type> targetTypes; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems like you need to validate that the list sizes are equal?
High-level question: why is this in GPU? Shouldn't the translation from bf16 to i16 either be a pass over on Arith or part of the SPIR-V lowering? See also, when going to LLVM, we replace all the 8-bit float types with i8 |
This pass imitates (bitcast/reinterpret_cast) unsupported types with supported types of same bitwidth. The imitation is done by bitcasting the unspported types to the supported types of same bitwidth. Therefore, the source type and destination type must have the same bitwidth. The imitation is done by using the following operations: arith.bitcast.
The imitation is often needed when the GPU target (dialect/IR) does not support a certain type but the underlying architecture does. Take SPIR-V for example, it does not support bf16, but an underlying architecture (e.g., intel pvc gpu) that uses SPIR-V for code-generation does. Therefore, bf16 is neither a valid data type to pass to gpu kernel, nor to be used inside the kernel. To use bf16 data type in a SPIR-V kernel (as a kernel parameter or inside the kernel), bf16 have to be bitcasted (similar to C++ reinterpret_cast) to a supported type (e.g., i16 for Intel GPUs). The SPIR-V kernel can then use the imitated type (i16) in the computation. However, i16 is not the same as bf16 (integer vs float), so the computation can not readily use the imitated type (i16).
Therefore, this transformation pass is intended to be used in conjuction with other transformation passes such as
EmulateUnsupportedFloats
andExtendUnsupportedTypes
that extend the bitwidth of bf16 to f32 and vice-versa.Finally, usually, there are instructions available in the target (dialect/IR) that can take advantage of these generated patterns (bf16->i16->f32, f32->bf16->i16), and convert them to the supported types.
For example, Intel provides SPIR-V extension ops that can take imitated bf16 (i16) and convert them to f32 and vice-versa. https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_bfloat16_conversion.asciidoc https://mlir.llvm.org/docs/Dialects/SPIR-V/#spirvintelconvertbf16tof-spirvintelconvertbf16tofop https://mlir.llvm.org/docs/Dialects/SPIR-V/#spirvintelconvertftobf16-spirvintelconvertftobf16op