Skip to content

[CIR] Refactor floating point type constraints #138112

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td
Original file line number Diff line number Diff line change
Expand Up @@ -110,4 +110,35 @@ def CIR_AnyFundamentalSIntType
let cppFunctionName = "isFundamentalSIntType";
}

//===----------------------------------------------------------------------===//
// Float Type predicates
//===----------------------------------------------------------------------===//

def CIR_AnySingleType : CIR_TypeBase<"::cir::SingleType", "single float type">;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason for the explicit global scope?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe it used to make problems when core mlir dialects where included in out of tree projects/dialects.
Thats why entire mlir has explicit global scope, though I cannot remember the issue it was causing.

def CIR_AnyFP32Type : TypeAlias<CIR_AnySingleType>;

def CIR_AnyDoubleType : CIR_TypeBase<"::cir::DoubleType", "double float type">;
def CIR_AnyFP64Type : TypeAlias<CIR_AnyDoubleType>;

def CIR_AnyFP16Type : CIR_TypeBase<"::cir::FP16Type", "f16 type">;
def CIR_AnyBFloat16Type : CIR_TypeBase<"::cir::BF16Type", "bf16 type">;
def CIR_AnyFP80Type : CIR_TypeBase<"::cir::FP80Type", "f80 type">;
def CIR_AnyFP128Type : CIR_TypeBase<"::cir::FP128Type", "f128 type">;
def CIR_AnyLongDoubleType : CIR_TypeBase<"::cir::LongDoubleType",
"long double type">;

def CIR_AnyFloatType : AnyTypeOf<[
CIR_AnySingleType, CIR_AnyDoubleType, CIR_AnyFP16Type,
CIR_AnyBFloat16Type, CIR_AnyFP80Type, CIR_AnyFP128Type,
CIR_AnyLongDoubleType
]> {
let cppFunctionName = "isAnyFloatingPointType";
}

def CIR_AnyIntOrFloatType : AnyTypeOf<[CIR_AnyFloatType, CIR_AnyIntType],
"integer or floating point type"
> {
let cppFunctionName = "isAnyIntegerOrFloatingPointType";
}

#endif // CLANG_CIR_DIALECT_IR_CIRTYPECONSTRAINTS_TD
1 change: 0 additions & 1 deletion clang/include/clang/CIR/Dialect/IR/CIRTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ struct RecordTypeStorage;

bool isValidFundamentalIntWidth(unsigned width);

bool isAnyFloatingPointType(mlir::Type t);
bool isFPOrFPVectorTy(mlir::Type);

} // namespace cir
Expand Down
23 changes: 7 additions & 16 deletions clang/include/clang/CIR/Dialect/IR/CIRTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,10 @@ def CIR_IntType : CIR_Type<"Int", "int",
// FloatType
//===----------------------------------------------------------------------===//

class CIR_FloatType<string name, string mnemonic>
: CIR_Type<name, mnemonic,
[
DeclareTypeInterfaceMethods<DataLayoutTypeInterface>,
DeclareTypeInterfaceMethods<CIRFPTypeInterface>,
]> {}
class CIR_FloatType<string name, string mnemonic> : CIR_Type<name, mnemonic, [
DeclareTypeInterfaceMethods<DataLayoutTypeInterface>,
DeclareTypeInterfaceMethods<CIRFPTypeInterface>
]>;

def CIR_Single : CIR_FloatType<"Single", "float"> {
let summary = "CIR single-precision 32-bit float type";
Expand Down Expand Up @@ -155,21 +153,14 @@ def CIR_LongDouble : CIR_FloatType<"LongDouble", "long_double"> {
format are all in use.
}];

let parameters = (ins "mlir::Type":$underlying);
let parameters = (ins AnyTypeOf<[CIR_Double, CIR_FP80, CIR_FP128],
"expects !cir.double, !cir.fp80 or !cir.fp128">:$underlying);

let assemblyFormat = [{
`<` $underlying `>`
}];

let genVerifyDecl = 1;
}

// Constraints

def CIR_AnyFloat: AnyTypeOf<[CIR_Single, CIR_Double, CIR_FP80, CIR_FP128,
CIR_LongDouble, CIR_FP16, CIR_BFloat16]>;
def CIR_AnyIntOrFloat: AnyTypeOf<[CIR_AnyFloat, CIR_IntType]>;

//===----------------------------------------------------------------------===//
// PointerType
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -518,7 +509,7 @@ def CIRRecordType : Type<

def CIR_AnyType : AnyTypeOf<[
CIR_VoidType, CIR_BoolType, CIR_ArrayType, CIR_VectorType, CIR_IntType,
CIR_AnyFloat, CIR_PointerType, CIR_FuncType, CIR_RecordType
CIR_AnyFloatType, CIR_PointerType, CIR_FuncType, CIR_RecordType
]>;

#endif // MLIR_CIR_DIALECT_CIR_TYPES
20 changes: 0 additions & 20 deletions clang/lib/CIR/Dialect/IR/CIRTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -550,26 +550,6 @@ LongDoubleType::getABIAlignment(const mlir::DataLayout &dataLayout,
.getABIAlignment(dataLayout, params);
}

LogicalResult
LongDoubleType::verify(function_ref<InFlightDiagnostic()> emitError,
mlir::Type underlying) {
if (!mlir::isa<DoubleType, FP80Type, FP128Type>(underlying)) {
emitError() << "invalid underlying type for long double";
return failure();
}

return success();
}

//===----------------------------------------------------------------------===//
// Floating-point type helpers
//===----------------------------------------------------------------------===//

bool cir::isAnyFloatingPointType(mlir::Type t) {
return isa<cir::SingleType, cir::DoubleType, cir::LongDoubleType,
cir::FP80Type, cir::BF16Type, cir::FP16Type, cir::FP128Type>(t);
}

//===----------------------------------------------------------------------===//
// Floating-point and Float-point Vector type helpers
//===----------------------------------------------------------------------===//
Expand Down
6 changes: 6 additions & 0 deletions clang/test/CIR/IR/invalid-long-double.cir
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
// RUN: cir-opt %s -verify-diagnostics -split-input-file

// expected-error@+1 {{failed to verify 'underlying': expects !cir.double, !cir.fp80 or !cir.fp128}}
cir.func @bad_long_double(%arg0 : !cir.long_double<!cir.float>) -> () {
cir.return
}
Loading