Skip to content

Commit 3966b5d

Browse files
committed
[mlir][gpu] Add pass for imitating unsupported types.
This pass imitates (bitcast/reinterpret_cast) unsupported types with supported types of the same bitwidth. The imitation is done by bitcasting the unsupported types to the supported types of the same bitwidth. Therefore, the source type and destination type must have the same bitwidth. The imitation is done by using the following operations: arith.bitcast. The imitation is often needed when the GPU target (dialect/IR) does not support a certain type but the underlying architecture does. Take SPIR-V for example: it does not support bf16, but an underlying architecture (e.g., Intel PVC GPU) that uses SPIR-V for code generation does. Therefore, bf16 is neither a valid data type to pass to a GPU kernel, nor to be used inside the kernel. To use the bf16 data type in a SPIR-V kernel (as a kernel parameter or inside the kernel), bf16 has to be bitcast (similar to C++ reinterpret_cast) to a supported type (e.g., i16 for Intel GPUs). The SPIR-V kernel can then use the imitated type (i16) in the computation. However, i16 is not the same as bf16 (integer vs float), so the computation cannot readily use the imitated type (i16). Therefore, this transformation pass is intended to be used in conjunction with other transformation passes such as `EmulateUnsupportedFloats` and `ExtendUnsupportedTypes` that extend the bitwidth of bf16 to f32 and vice-versa. Finally, there are usually instructions available in the target (dialect/IR) that can take advantage of these generated patterns (bf16->i16->f32, f32->bf16->i16), and convert them to the supported types. For example, Intel provides SPIR-V extension ops that can take imitated bf16 (i16) and convert them to f32 and vice-versa. https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_bfloat16_conversion.asciidoc https://mlir.llvm.org/docs/Dialects/SPIR-V/#spirvintelconvertbf16tof-spirvintelconvertbf16tofop https://mlir.llvm.org/docs/Dialects/SPIR-V/#spirvintelconvertftobf16-spirvintelconvertftobf16op
1 parent c51be1b commit 3966b5d

File tree

7 files changed

+1163
-2
lines changed

7 files changed

+1163
-2
lines changed

mlir/include/mlir/Dialect/Arith/Utils/Utils.h

+6
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,12 @@ Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values,
146146
// Map strings to float types.
147147
std::optional<FloatType> parseFloatType(MLIRContext *ctx, StringRef name);
148148

149+
// Map strings to Int types.
150+
std::optional<IntegerType> parseIntType(MLIRContext *ctx, StringRef name);
151+
152+
// Map strings to int or float types.
153+
std::optional<Type> parseIntOrFloatType(MLIRContext *ctx, StringRef name);
154+
149155
} // namespace arith
150156
} // namespace mlir
151157

mlir/include/mlir/Dialect/GPU/Transforms/Passes.h

+20
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
1717
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
1818
#include "mlir/Dialect/GPU/Utils/GPUUtils.h"
19+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
20+
#include "mlir/IR/BuiltinTypes.h"
1921
#include "mlir/IR/PatternMatch.h"
2022
#include "mlir/Pass/Pass.h"
2123
#include <optional>
@@ -87,6 +89,24 @@ void populateGpuLowerClusteredSubgroupReduceToDPPPatterns(
8789
RewritePatternSet &patterns, unsigned subgroupSize, amdgpu::Chipset chipset,
8890
PatternBenefit benefit = 1);
8991

92+
/// Set up a type converter to convert unsupported source types to
93+
/// supported target types.
94+
void populateImitateUnsupportedTypesTypeConverter(TypeConverter &typeConverter,
95+
ArrayRef<Type> sourceTypes,
96+
ArrayRef<Type> targetTypes);
97+
98+
/// Collect a set of patterns needed to imitate unsupported source types
99+
/// using supported target types.
100+
void populateImitateUnsupportedTypesConversionPatterns(
101+
RewritePatternSet &patterns, TypeConverter &typeConverter,
102+
ArrayRef<Type> sourceTypes, ArrayRef<Type> targetTypes,
103+
DenseMap<StringAttr, FunctionType> &convertedFuncTypes);
104+
105+
/// Set up a dialect conversion to reject operations on unsupported
106+
/// source types.
107+
void configureImitateUnsupportedTypesLegality(ConversionTarget &target,
108+
TypeConverter &typeConverter);
109+
90110
/// Collect all patterns to rewrite ops within the GPU dialect.
91111
inline void populateGpuRewritePatterns(RewritePatternSet &patterns) {
92112
populateGpuAllReducePatterns(patterns);

mlir/include/mlir/Dialect/GPU/Transforms/Passes.td

+53
Original file line numberDiff line numberDiff line change
@@ -258,4 +258,57 @@ def GpuSPIRVAttachTarget: Pass<"spirv-attach-target", ""> {
258258
];
259259
}
260260

261+
def GpuImitateUnsupportedTypes : Pass<"imitate-unsupported-types", "::mlir::ModuleOp"> {
262+
let summary = "Imitate unsupported types with supported types of same bitwidth.";
263+
let description = [{
264+
This pass imitates (bitcast/reinterpret_cast) unsupported types
265+
with supported types of same bitwidth. The imitation is done
266+
by bitcasting the unsupported types to the supported types of same bitwidth.
267+
Therefore, the source type and destination type must have the same bitwidth.
268+
The imitation is done by using the following operations: arith.bitcast.
269+
270+
The imitation is often needed when the GPU target (dialect/IR) does not
271+
support a certain type but the underlying architecture does. Take SPIR-V for
272+
example, it does not support bf16, but an underlying architecture (e.g.,
273+
intel pvc gpu) that uses SPIR-V for code-generation does.
274+
Therefore, bf16 is neither a valid data type to pass to gpu kernel, nor to
275+
be used inside the kernel. To use bf16 data type in a SPIR-V kernel (as a
276+
kernel parameter or inside the kernel), bf16 has to be bitcast (similar
277+
to C++ reinterpret_cast) to a supported type (e.g., i16 for Intel GPUs). The
278+
SPIR-V kernel can then use the imitated type (i16) in the computation.
279+
However, i16 is not the same as bf16 (integer vs float), so the computation
280+
can not readily use the imitated type (i16).
281+
282+
Therefore, this transformation pass is intended to be used in conjunction
283+
with other transformation passes such as `EmulateUnsupportedFloats` and
284+
`ExtendUnsupportedTypes` that extend the bitwidth of bf16 to f32 and
285+
vice-versa.
286+
287+
Finally, usually, there are instructions available in the target
288+
(dialect/IR) that can take advantage of these generated patterns
289+
(bf16->i16->f32, f32->bf16->i16), and convert them to the supported
290+
types.
291+
For example, Intel provides SPIR-V extension ops that can
292+
take imitated bf16 (i16) and convert them to f32 and vice-versa.
293+
https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/INTEL/SPV_INTEL_bfloat16_conversion.asciidoc
294+
https://mlir.llvm.org/docs/Dialects/SPIR-V/#spirvintelconvertbf16tof-spirvintelconvertbf16tofop
295+
https://mlir.llvm.org/docs/Dialects/SPIR-V/#spirvintelconvertftobf16-spirvintelconvertftobf16op
296+
297+
}];
298+
299+
let options = [
300+
ListOption<"sourceTypeStrs", "source-types", "std::string",
301+
"MLIR types without type support on a given target">,
302+
ListOption<"targetTypeStrs", "target-types", "std::string",
303+
"MLIR types to convert the unsupported source types to">,
304+
];
305+
306+
let dependentDialects = [
307+
"::mlir::gpu::GPUDialect",
308+
"::mlir::arith::ArithDialect",
309+
"::mlir::memref::MemRefDialect"
310+
];
311+
}
312+
313+
261314
#endif // MLIR_DIALECT_GPU_PASSES

mlir/lib/Dialect/Arith/Utils/Utils.cpp

+25
Original file line numberDiff line numberDiff line change
@@ -380,4 +380,29 @@ std::optional<FloatType> parseFloatType(MLIRContext *ctx, StringRef name) {
380380
.Default(std::nullopt);
381381
}
382382

383+
/// Map a textual integer type name to the corresponding IntegerType.
///
/// `name` is matched exactly against a fixed list of spellings: "i1", "i2",
/// "i4", "i6", "i8", "i16", "i32", "i64", "i80" and "i128". On a match the
/// result is the integer type of that width obtained from a Builder created
/// on `ctx`; any other string yields std::nullopt.
384+
std::optional<IntegerType> parseIntType(MLIRContext *ctx, StringRef name) {
385+
Builder b(ctx);
386+
// Only the widths enumerated below are recognized; an arbitrary "iN"
// spelling is not parsed and falls through to the default.
return llvm::StringSwitch<std::optional<IntegerType>>(name)
387+
.Case("i1", b.getIntegerType(1))
388+
.Case("i2", b.getIntegerType(2))
389+
.Case("i4", b.getIntegerType(4))
390+
.Case("i6", b.getIntegerType(6))
391+
.Case("i8", b.getIntegerType(8))
392+
.Case("i16", b.getIntegerType(16))
393+
.Case("i32", b.getIntegerType(32))
394+
.Case("i64", b.getIntegerType(64))
395+
.Case("i80", b.getIntegerType(80))
396+
.Case("i128", b.getIntegerType(128))
397+
.Default(std::nullopt);
398+
}
399+
/// Map a textual type name to either a float type or an integer type.
///
/// The name is first offered to parseFloatType; if that yields no type it is
/// offered to parseIntType. The first successful result is returned as a
/// generic Type. Returns std::nullopt when neither parser recognizes `name`.
400+
std::optional<Type> parseIntOrFloatType(MLIRContext *ctx, StringRef name) {
401+
// Float spellings are tried before integer spellings.
if (auto floatTy = parseFloatType(ctx, name))
402+
return *floatTy;
403+
if (auto intTy = parseIntType(ctx, name))
404+
return *intTy;
405+
return std::nullopt;
406+
}
407+
383408
} // namespace mlir::arith

mlir/lib/Dialect/GPU/CMakeLists.txt

+3-2
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ add_mlir_dialect_library(MLIRGPUDialect
2323
MLIRMemRefDialect
2424
MLIRSideEffectInterfaces
2525
MLIRSupport
26-
)
26+
)
2727

2828
add_mlir_dialect_library(MLIRGPUTransforms
2929
Transforms/AllReduceLowering.cpp
@@ -42,6 +42,7 @@ add_mlir_dialect_library(MLIRGPUTransforms
4242
Transforms/SPIRVAttachTarget.cpp
4343
Transforms/SubgroupIdRewriter.cpp
4444
Transforms/SubgroupReduceLowering.cpp
45+
Transforms/ImitateUnsupportedTypes.cpp
4546

4647
OBJECT
4748

@@ -76,7 +77,7 @@ add_mlir_dialect_library(MLIRGPUTransforms
7677
MLIRROCDLTarget
7778
MLIRTransformUtils
7879
MLIRVectorDialect
79-
)
80+
)
8081

8182
add_subdirectory(TransformOps)
8283
add_subdirectory(Pipelines)

0 commit comments

Comments
 (0)