Skip to content

Commit c617466

Browse files
authored
[flang][llvm][OpenMP] Add implicit casts to omp.atomic (#131603)
Currently, implicit casts in Fortran are handled by the OMPIRBuilder. This patch shifts that responsibility to FIR codegen.
1 parent 2bb2f8a commit c617466

File tree

5 files changed

+191
-48
lines changed

5 files changed

+191
-48
lines changed

flang/lib/Lower/OpenMP/OpenMP.cpp

+80-3
Original file line numberDiff line numberDiff line change
@@ -2889,9 +2889,82 @@ static void genAtomicRead(lower::AbstractConverter &converter,
28892889
fir::getBase(converter.genExprAddr(fromExpr, stmtCtx));
28902890
mlir::Value toAddress = fir::getBase(converter.genExprAddr(
28912891
*semantics::GetExpr(assignmentStmtVariable), stmtCtx));
2892-
genAtomicCaptureStatement(converter, fromAddress, toAddress,
2893-
leftHandClauseList, rightHandClauseList,
2894-
elementType, loc);
2892+
2893+
if (fromAddress.getType() != toAddress.getType()) {
2894+
// Emit an implicit cast. Different yet compatible types on
2895+
// omp.atomic.read constitute valid Fortran. The OMPIRBuilder will
2896+
// emit atomic instructions (on primitive types) and `__atomic_load`
2897+
// libcall (on complex type) without explicitly converting
2898+
// between such compatible types. The OMPIRBuilder relies on the
2899+
// frontend to resolve such inconsistencies between `omp.atomic.read`
2900+
// operand types. Similar inconsistencies between operand types in
2901+
// `omp.atomic.write` are resolved through implicit casting by use of typed
2902+
// assignment (i.e. `evaluate::Assignment`). However, use of typed
2903+
// assignment in `omp.atomic.read` (of form `v = x`) leads to an unsafe,
2904+
// non-atomic load of `x` into a temporary `alloca`, followed by an atomic
2905+
// read of form `v = alloca`. Hence, a custom implicit cast must be
2906+
// performed here.
2907+
2908+
// An atomic read of form `v = x` would (without implicit casting)
2909+
// lower to `omp.atomic.read %v = %x : !fir.ref<type1>, !fir.ref<type2>,
2910+
// type2`. This implicit casting will rather generate the following FIR:
2911+
//
2912+
// %alloca = fir.alloca type2
2913+
// omp.atomic.read %alloca = %x : !fir.ref<type2>, !fir.ref<type2>, type2
2914+
// %load = fir.load %alloca : !fir.ref<type2>
2915+
// %cvt = fir.convert %load : (type2) -> type1
2916+
// fir.store %cvt to %v : !fir.ref<type1>
2917+
2918+
// This sequence of operations is thread-safe since each thread allocates
2919+
// the `alloca` in its stack, and performs `%alloca = %x` atomically. Once
2920+
// safely read, each thread performs the implicit cast on the local
2921+
// `alloca`, and writes the final result to `%v`.
2922+
mlir::Type toType = fir::unwrapRefType(toAddress.getType());
2923+
mlir::Type fromType = fir::unwrapRefType(fromAddress.getType());
2924+
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
2925+
auto oldIP = builder.saveInsertionPoint();
2926+
builder.setInsertionPointToStart(builder.getAllocaBlock());
2927+
mlir::Value alloca = builder.create<fir::AllocaOp>(
2928+
loc, fromType); // Thread scope `alloca` to atomically read `%x`.
2929+
builder.restoreInsertionPoint(oldIP);
2930+
genAtomicCaptureStatement(converter, fromAddress, alloca,
2931+
leftHandClauseList, rightHandClauseList,
2932+
elementType, loc);
2933+
auto load = builder.create<fir::LoadOp>(loc, alloca);
2934+
if (fir::isa_complex(fromType) && !fir::isa_complex(toType)) {
2935+
// Emit an additional `ExtractValueOp` if `fromAddress` is of complex
2936+
// type, but `toAddress` is not.
2937+
auto extract = builder.create<fir::ExtractValueOp>(
2938+
loc, mlir::cast<mlir::ComplexType>(fromType).getElementType(), load,
2939+
builder.getArrayAttr(
2940+
builder.getIntegerAttr(builder.getIndexType(), 0)));
2941+
auto cvt = builder.create<fir::ConvertOp>(loc, toType, extract);
2942+
builder.create<fir::StoreOp>(loc, cvt, toAddress);
2943+
} else if (!fir::isa_complex(fromType) && fir::isa_complex(toType)) {
2944+
// Emit an additional `InsertValueOp` if `toAddress` is of complex
2945+
// type, but `fromAddress` is not.
2946+
mlir::Value undef = builder.create<fir::UndefOp>(loc, toType);
2947+
mlir::Type complexEleTy =
2948+
mlir::cast<mlir::ComplexType>(toType).getElementType();
2949+
mlir::Value cvt = builder.create<fir::ConvertOp>(loc, complexEleTy, load);
2950+
mlir::Value zero = builder.createRealZeroConstant(loc, complexEleTy);
2951+
mlir::Value idx0 = builder.create<fir::InsertValueOp>(
2952+
loc, toType, undef, cvt,
2953+
builder.getArrayAttr(
2954+
builder.getIntegerAttr(builder.getIndexType(), 0)));
2955+
mlir::Value idx1 = builder.create<fir::InsertValueOp>(
2956+
loc, toType, idx0, zero,
2957+
builder.getArrayAttr(
2958+
builder.getIntegerAttr(builder.getIndexType(), 1)));
2959+
builder.create<fir::StoreOp>(loc, idx1, toAddress);
2960+
} else {
2961+
auto cvt = builder.create<fir::ConvertOp>(loc, toType, load);
2962+
builder.create<fir::StoreOp>(loc, cvt, toAddress);
2963+
}
2964+
} else
2965+
genAtomicCaptureStatement(converter, fromAddress, toAddress,
2966+
leftHandClauseList, rightHandClauseList,
2967+
elementType, loc);
28952968
}
28962969

28972970
/// Processes an atomic construct with update clause.
@@ -2976,6 +3049,10 @@ static void genAtomicCapture(lower::AbstractConverter &converter,
29763049
mlir::Type stmt2VarType =
29773050
fir::getBase(converter.genExprValue(assign2.lhs, stmtCtx)).getType();
29783051

3052+
// Check if an implicit type cast is needed
3053+
if (stmt1VarType != stmt2VarType)
3054+
TODO(loc, "atomic capture requiring implicit type casts");
3055+
29793056
mlir::Operation *atomicCaptureOp = nullptr;
29803057
mlir::IntegerAttr hint = nullptr;
29813058
mlir::omp::ClauseMemoryOrderKindAttr memoryOrder = nullptr;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
!RUN: %not_todo_cmd %flang_fc1 -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s
2+
3+
!CHECK: not yet implemented: atomic capture requiring implicit type casts
4+
subroutine capture_with_convert_f32_to_i32()
5+
implicit none
6+
integer :: k, v, i
7+
8+
k = 1
9+
v = 0
10+
11+
!$omp atomic capture
12+
v = k
13+
k = (i + 1) * 3.14
14+
!$omp end atomic
15+
end subroutine
16+
17+
subroutine capture_with_convert_i32_to_f64()
18+
real(8) :: x
19+
integer :: v
20+
x = 1.0
21+
v = 0
22+
!$omp atomic capture
23+
v = x
24+
x = v
25+
!$omp end atomic
26+
end subroutine capture_with_convert_i32_to_f64
27+
28+
subroutine capture_with_convert_f64_to_i32()
29+
integer :: x
30+
real(8) :: v
31+
x = 1
32+
v = 0
33+
!$omp atomic capture
34+
x = v
35+
v = x
36+
!$omp end atomic
37+
end subroutine capture_with_convert_f64_to_i32
38+
39+
subroutine capture_with_convert_i32_to_f32()
40+
real(4) :: x
41+
integer :: v
42+
x = 1.0
43+
v = 0
44+
!$omp atomic capture
45+
v = x
46+
x = x + v
47+
!$omp end atomic
48+
end subroutine capture_with_convert_i32_to_f32
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
! REQUIRES : openmp_runtime
2+
3+
! RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s
4+
5+
! CHECK: func.func @_QPatomic_implicit_cast_read() {
6+
subroutine atomic_implicit_cast_read
7+
! CHECK: %[[ALLOCA3:.*]] = fir.alloca complex<f32>
8+
! CHECK: %[[ALLOCA2:.*]] = fir.alloca complex<f32>
9+
! CHECK: %[[ALLOCA1:.*]] = fir.alloca i32
10+
! CHECK: %[[ALLOCA0:.*]] = fir.alloca f32
11+
12+
! CHECK: %[[M:.*]] = fir.alloca complex<f64> {bindc_name = "m", uniq_name = "_QFatomic_implicit_cast_readEm"}
13+
! CHECK: %[[M_DECL:.*]]:2 = hlfir.declare %[[M]] {uniq_name = "_QFatomic_implicit_cast_readEm"} : (!fir.ref<complex<f64>>) -> (!fir.ref<complex<f64>>, !fir.ref<complex<f64>>)
14+
! CHECK: %[[W:.*]] = fir.alloca complex<f32> {bindc_name = "w", uniq_name = "_QFatomic_implicit_cast_readEw"}
15+
! CHECK: %[[W_DECL:.*]]:2 = hlfir.declare %[[W]] {uniq_name = "_QFatomic_implicit_cast_readEw"} : (!fir.ref<complex<f32>>) -> (!fir.ref<complex<f32>>, !fir.ref<complex<f32>>)
16+
! CHECK: %[[X:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFatomic_implicit_cast_readEx"}
17+
! CHECK: %[[X_DECL:.*]]:2 = hlfir.declare %[[X]] {uniq_name = "_QFatomic_implicit_cast_readEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
18+
! CHECK: %[[Y:.*]] = fir.alloca f32 {bindc_name = "y", uniq_name = "_QFatomic_implicit_cast_readEy"}
19+
! CHECK: %[[Y_DECL:.*]]:2 = hlfir.declare %[[Y]] {uniq_name = "_QFatomic_implicit_cast_readEy"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
20+
! CHECK: %[[Z:.*]] = fir.alloca f64 {bindc_name = "z", uniq_name = "_QFatomic_implicit_cast_readEz"}
21+
! CHECK: %[[Z_DECL:.*]]:2 = hlfir.declare %[[Z]] {uniq_name = "_QFatomic_implicit_cast_readEz"} : (!fir.ref<f64>) -> (!fir.ref<f64>, !fir.ref<f64>)
22+
integer :: x
23+
real :: y
24+
double precision :: z
25+
complex :: w
26+
complex(8) :: m
27+
28+
! CHECK: omp.atomic.read %[[ALLOCA0:.*]] = %[[Y_DECL]]#0 : !fir.ref<f32>, !fir.ref<f32>, f32
29+
! CHECK: %[[LOAD:.*]] = fir.load %[[ALLOCA0]] : !fir.ref<f32>
30+
! CHECK: %[[CVT:.*]] = fir.convert %[[LOAD]] : (f32) -> i32
31+
! CHECK: fir.store %[[CVT]] to %[[X_DECL]]#0 : !fir.ref<i32>
32+
!$omp atomic read
33+
x = y
34+
35+
! CHECK: omp.atomic.read %[[ALLOCA1:.*]] = %[[X_DECL]]#0 : !fir.ref<i32>, !fir.ref<i32>, i32
36+
! CHECK: %[[LOAD:.*]] = fir.load %[[ALLOCA1]] : !fir.ref<i32>
37+
! CHECK: %[[CVT:.*]] = fir.convert %[[LOAD]] : (i32) -> f64
38+
! CHECK: fir.store %[[CVT]] to %[[Z_DECL]]#0 : !fir.ref<f64>
39+
!$omp atomic read
40+
z = x
41+
42+
! CHECK: omp.atomic.read %[[ALLOCA2:.*]] = %[[W_DECL]]#0 : !fir.ref<complex<f32>>, !fir.ref<complex<f32>>, complex<f32>
43+
! CHECK: %[[LOAD:.*]] = fir.load %[[ALLOCA2]] : !fir.ref<complex<f32>>
44+
! CHECK: %[[EXTRACT:.*]] = fir.extract_value %[[LOAD]], [0 : index] : (complex<f32>) -> f32
45+
! CHECK: %[[CVT:.*]] = fir.convert %[[EXTRACT]] : (f32) -> i32
46+
! CHECK: fir.store %[[CVT]] to %[[X_DECL]]#0 : !fir.ref<i32>
47+
!$omp atomic read
48+
x = w
49+
50+
! CHECK: omp.atomic.read %[[ALLOCA3:.*]] = %[[W_DECL]]#0 : !fir.ref<complex<f32>>, !fir.ref<complex<f32>>, complex<f32>
51+
! CHECK: %[[LOAD:.*]] = fir.load %[[ALLOCA3]] : !fir.ref<complex<f32>>
52+
! CHECK: %[[CVT:.*]] = fir.convert %[[LOAD]] : (complex<f32>) -> complex<f64>
53+
! CHECK: fir.store %[[CVT]] to %[[M_DECL]]#0 : !fir.ref<complex<f64>>
54+
!$omp atomic read
55+
m = w
56+
end subroutine

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

-31
Original file line numberDiff line numberDiff line change
@@ -268,33 +268,6 @@ computeOpenMPScheduleType(ScheduleKind ClauseKind, bool HasChunks,
268268
return Result;
269269
}
270270

271-
/// Emit an implicit cast to convert \p XRead to type of variable \p V
272-
static llvm::Value *emitImplicitCast(IRBuilder<> &Builder, llvm::Value *XRead,
273-
llvm::Value *V) {
274-
// TODO: Add this functionality to the `AtomicInfo` interface
275-
llvm::Type *XReadType = XRead->getType();
276-
llvm::Type *VType = V->getType();
277-
if (llvm::AllocaInst *vAlloca = dyn_cast<llvm::AllocaInst>(V))
278-
VType = vAlloca->getAllocatedType();
279-
280-
if (XReadType->isStructTy() && VType->isStructTy())
281-
// No need to extract or convert. A direct
282-
// `store` will suffice.
283-
return XRead;
284-
285-
if (XReadType->isStructTy())
286-
XRead = Builder.CreateExtractValue(XRead, /*Idxs=*/0);
287-
if (VType->isIntegerTy() && XReadType->isFloatingPointTy())
288-
XRead = Builder.CreateFPToSI(XRead, VType);
289-
else if (VType->isFloatingPointTy() && XReadType->isIntegerTy())
290-
XRead = Builder.CreateSIToFP(XRead, VType);
291-
else if (VType->isIntegerTy() && XReadType->isIntegerTy())
292-
XRead = Builder.CreateIntCast(XRead, VType, true);
293-
else if (VType->isFloatingPointTy() && XReadType->isFloatingPointTy())
294-
XRead = Builder.CreateFPCast(XRead, VType);
295-
return XRead;
296-
}
297-
298271
/// Make \p Source branch to \p Target.
299272
///
300273
/// Handles two situations:
@@ -8685,8 +8658,6 @@ OpenMPIRBuilder::createAtomicRead(const LocationDescription &Loc,
86858658
}
86868659
}
86878660
checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Read);
8688-
if (XRead->getType() != V.Var->getType())
8689-
XRead = emitImplicitCast(Builder, XRead, V.Var);
86908661
Builder.CreateStore(XRead, V.Var, V.IsVolatile);
86918662
return Builder.saveIP();
86928663
}
@@ -8983,8 +8954,6 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createAtomicCapture(
89838954
return AtomicResult.takeError();
89848955
Value *CapturedVal =
89858956
(IsPostfixUpdate ? AtomicResult->first : AtomicResult->second);
8986-
if (CapturedVal->getType() != V.Var->getType())
8987-
CapturedVal = emitImplicitCast(Builder, CapturedVal, V.Var);
89888957
Builder.CreateStore(CapturedVal, V.Var, V.IsVolatile);
89898958

89908959
checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Capture);

mlir/test/Target/LLVMIR/openmp-llvm.mlir

+7-14
Original file line numberDiff line numberDiff line change
@@ -1396,42 +1396,35 @@ llvm.func @omp_atomic_read_implicit_cast () {
13961396

13971397
//CHECK: call void @__atomic_load(i64 8, ptr %[[X_ELEMENT]], ptr %[[ATOMIC_LOAD_TEMP]], i32 0)
13981398
//CHECK: %[[LOAD:.*]] = load { float, float }, ptr %[[ATOMIC_LOAD_TEMP]], align 8
1399-
//CHECK: %[[EXT:.*]] = extractvalue { float, float } %[[LOAD]], 0
1400-
//CHECK: store float %[[EXT]], ptr %[[Y]], align 4
1399+
//CHECK: store { float, float } %[[LOAD]], ptr %[[Y]], align 4
14011400
omp.atomic.read %3 = %17 : !llvm.ptr, !llvm.ptr, !llvm.struct<(f32, f32)>
14021401

14031402
//CHECK: %[[ATOMIC_LOAD_TEMP:.*]] = load atomic i32, ptr %[[Z]] monotonic, align 4
14041403
//CHECK: %[[CAST:.*]] = bitcast i32 %[[ATOMIC_LOAD_TEMP]] to float
1405-
//CHECK: %[[LOAD:.*]] = fpext float %[[CAST]] to double
1406-
//CHECK: store double %[[LOAD]], ptr %[[Y]], align 8
1404+
//CHECK: store float %[[CAST]], ptr %[[Y]], align 4
14071405
omp.atomic.read %3 = %1 : !llvm.ptr, !llvm.ptr, f32
14081406

14091407
//CHECK: %[[ATOMIC_LOAD_TEMP:.*]] = load atomic i32, ptr %[[W]] monotonic, align 4
1410-
//CHECK: %[[LOAD:.*]] = sitofp i32 %[[ATOMIC_LOAD_TEMP]] to double
1411-
//CHECK: store double %[[LOAD]], ptr %[[Y]], align 8
1408+
//CHECK: store i32 %[[ATOMIC_LOAD_TEMP]], ptr %[[Y]], align 4
14121409
omp.atomic.read %3 = %7 : !llvm.ptr, !llvm.ptr, i32
14131410

14141411
//CHECK: %[[ATOMIC_LOAD_TEMP:.*]] = load atomic i64, ptr %[[Y]] monotonic, align 4
14151412
//CHECK: %[[CAST:.*]] = bitcast i64 %[[ATOMIC_LOAD_TEMP]] to double
1416-
//CHECK: %[[LOAD:.*]] = fptrunc double %[[CAST]] to float
1417-
//CHECK: store float %[[LOAD]], ptr %[[Z]], align 4
1413+
//CHECK: store double %[[CAST]], ptr %[[Z]], align 8
14181414
omp.atomic.read %1 = %3 : !llvm.ptr, !llvm.ptr, f64
14191415

14201416
//CHECK: %[[ATOMIC_LOAD_TEMP:.*]] = load atomic i32, ptr %[[W]] monotonic, align 4
1421-
//CHECK: %[[LOAD:.*]] = sitofp i32 %[[ATOMIC_LOAD_TEMP]] to float
1422-
//CHECK: store float %[[LOAD]], ptr %[[Z]], align 4
1417+
//CHECK: store i32 %[[ATOMIC_LOAD_TEMP]], ptr %[[Z]], align 4
14231418
omp.atomic.read %1 = %7 : !llvm.ptr, !llvm.ptr, i32
14241419

14251420
//CHECK: %[[ATOMIC_LOAD_TEMP:.*]] = load atomic i64, ptr %[[Y]] monotonic, align 4
14261421
//CHECK: %[[CAST:.*]] = bitcast i64 %[[ATOMIC_LOAD_TEMP]] to double
1427-
//CHECK: %[[LOAD:.*]] = fptosi double %[[CAST]] to i32
1428-
//CHECK: store i32 %[[LOAD]], ptr %[[W]], align 4
1422+
//CHECK: store double %[[CAST]], ptr %[[W]], align 8
14291423
omp.atomic.read %7 = %3 : !llvm.ptr, !llvm.ptr, f64
14301424

14311425
//CHECK: %[[ATOMIC_LOAD_TEMP:.*]] = load atomic i32, ptr %[[Z]] monotonic, align 4
14321426
//CHECK: %[[CAST:.*]] = bitcast i32 %[[ATOMIC_LOAD_TEMP]] to float
1433-
//CHECK: %[[LOAD:.*]] = fptosi float %[[CAST]] to i32
1434-
//CHECK: store i32 %[[LOAD]], ptr %[[W]], align 4
1427+
//CHECK: store float %[[CAST]], ptr %[[W]], align 4
14351428
omp.atomic.read %7 = %1 : !llvm.ptr, !llvm.ptr, f32
14361429
llvm.return
14371430
}

0 commit comments

Comments
 (0)