Skip to content

[flang][llvm][OpenMP] Add implicit casts to omp.atomic #131603

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 80 additions & 3 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2889,9 +2889,82 @@ static void genAtomicRead(lower::AbstractConverter &converter,
fir::getBase(converter.genExprAddr(fromExpr, stmtCtx));
mlir::Value toAddress = fir::getBase(converter.genExprAddr(
*semantics::GetExpr(assignmentStmtVariable), stmtCtx));
genAtomicCaptureStatement(converter, fromAddress, toAddress,
leftHandClauseList, rightHandClauseList,
elementType, loc);

if (fromAddress.getType() != toAddress.getType()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add comments for:
-> Why can't we use the typedAssignment lowering, and why is custom lowering used here?
-> Why do these casts have to be added?
-> Why is it safe to do so?
-> Why can't we use the typedAssignment lowering?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. I have added an explanation for these

// Emit an implicit cast. Different yet compatible types on
// omp.atomic.read constitute valid Fortran. The OMPIRBuilder will
// emit atomic instructions (on primitive types) and `__atomic_load`
// libcall (on complex type) without explicitly converting
// between such compatible types. The OMPIRBuilder relies on the
// frontend to resolve such inconsistencies between `omp.atomic.read`
// operand types. Similar inconsistencies between operand types in
// `omp.atomic.write` are resolved through implicit casting by use of typed
// assignment (i.e. `evaluate::Assignment`). However, use of typed
// assignment in `omp.atomic.read` (of form `v = x`) leads to an unsafe,
// non-atomic load of `x` into a temporary `alloca`, followed by an atomic
// read of form `v = alloca`. Hence, it is needed to perform a custom
// implicit cast.

// An atomic read of form `v = x` would (without implicit casting)
// lower to `omp.atomic.read %v = %x : !fir.ref<type1>, !fir.ref<type2>,
// type2`. This implicit casting will rather generate the following FIR:
//
// %alloca = fir.alloca type2
// omp.atomic.read %alloca = %x : !fir.ref<type2>, !fir.ref<type2>, type2
// %load = fir.load %alloca : !fir.ref<type2>
// %cvt = fir.convert %load : (type2) -> type1
// fir.store %cvt to %v : !fir.ref<type1>

// This sequence of operations is thread-safe since each thread allocates
// the `alloca` in its stack, and performs `%alloca = %x` atomically. Once
// safely read, each thread performs the implicit cast on the local
// `alloca`, and writes the final result to `%v`.
mlir::Type toType = fir::unwrapRefType(toAddress.getType());
mlir::Type fromType = fir::unwrapRefType(fromAddress.getType());
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
auto oldIP = builder.saveInsertionPoint();
builder.setInsertionPointToStart(builder.getAllocaBlock());
mlir::Value alloca = builder.create<fir::AllocaOp>(
loc, fromType); // Thread scope `alloca` to atomically read `%x`.
builder.restoreInsertionPoint(oldIP);
genAtomicCaptureStatement(converter, fromAddress, alloca,
leftHandClauseList, rightHandClauseList,
elementType, loc);
auto load = builder.create<fir::LoadOp>(loc, alloca);
if (fir::isa_complex(fromType) && !fir::isa_complex(toType)) {
// Emit an additional `ExtractValueOp` if `fromAddress` is of complex
// type, but `toAddress` is not.
auto extract = builder.create<fir::ExtractValueOp>(
loc, mlir::cast<mlir::ComplexType>(fromType).getElementType(), load,
builder.getArrayAttr(
builder.getIntegerAttr(builder.getIndexType(), 0)));
auto cvt = builder.create<fir::ConvertOp>(loc, toType, extract);
builder.create<fir::StoreOp>(loc, cvt, toAddress);
} else if (!fir::isa_complex(fromType) && fir::isa_complex(toType)) {
// Emit an additional `InsertValueOp` if `toAddress` is of complex
// type, but `fromAddress` is not.
mlir::Value undef = builder.create<fir::UndefOp>(loc, toType);
mlir::Type complexEleTy =
mlir::cast<mlir::ComplexType>(toType).getElementType();
mlir::Value cvt = builder.create<fir::ConvertOp>(loc, complexEleTy, load);
mlir::Value zero = builder.createRealZeroConstant(loc, complexEleTy);
mlir::Value idx0 = builder.create<fir::InsertValueOp>(
loc, toType, undef, cvt,
builder.getArrayAttr(
builder.getIntegerAttr(builder.getIndexType(), 0)));
mlir::Value idx1 = builder.create<fir::InsertValueOp>(
loc, toType, idx0, zero,
builder.getArrayAttr(
builder.getIntegerAttr(builder.getIndexType(), 1)));
builder.create<fir::StoreOp>(loc, idx1, toAddress);
} else {
auto cvt = builder.create<fir::ConvertOp>(loc, toType, load);
builder.create<fir::StoreOp>(loc, cvt, toAddress);
}
} else
genAtomicCaptureStatement(converter, fromAddress, toAddress,
leftHandClauseList, rightHandClauseList,
elementType, loc);
}

/// Processes an atomic construct with update clause.
Expand Down Expand Up @@ -2976,6 +3049,10 @@ static void genAtomicCapture(lower::AbstractConverter &converter,
mlir::Type stmt2VarType =
fir::getBase(converter.genExprValue(assign2.lhs, stmtCtx)).getType();

// Check if an implicit type cast is needed
if (stmt1VarType != stmt2VarType)
TODO(loc, "atomic capture requiring implicit type casts");

mlir::Operation *atomicCaptureOp = nullptr;
mlir::IntegerAttr hint = nullptr;
mlir::omp::ClauseMemoryOrderKindAttr memoryOrder = nullptr;
Expand Down
48 changes: 48 additions & 0 deletions flang/test/Lower/OpenMP/Todo/atomic-capture-implicit-cast.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
!RUN: %not_todo_cmd %flang_fc1 -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s

!CHECK: not yet implemented: atomic capture requiring implicit type casts
! Atomic capture of the integer `k` into the integer `v`, followed by an
! update of `k` whose right-hand side `(i + 1) * 3.14` is real-valued and is
! therefore converted on assignment to the integer `k`.
! NOTE(review): `i` is never assigned before use; presumably irrelevant
! because this file is only lowered, not executed - confirm.
subroutine capture_with_convert_f32_to_i32()
implicit none
integer :: k, v, i

k = 1
v = 0

!$omp atomic capture
v = k
k = (i + 1) * 3.14
!$omp end atomic
end subroutine

! Atomic capture with mismatched operand types: `x` is real(8) and `v` is
! integer. The capture statement `v = x` converts real(8) to integer and the
! write statement `x = v` converts integer to real(8).
! NOTE(review): the subroutine name presumably refers to the `x = v`
! (integer -> real(8)) direction - confirm.
subroutine capture_with_convert_i32_to_f64()
real(8) :: x
integer :: v
x = 1.0
v = 0
!$omp atomic capture
v = x
x = v
!$omp end atomic
end subroutine capture_with_convert_i32_to_f64

! Atomic capture with mismatched operand types: `x` is integer and `v` is
! real(8). The first statement `x = v` converts real(8) to integer and the
! second statement `v = x` converts integer to real(8).
! NOTE(review): the subroutine name presumably refers to the `x = v`
! (real(8) -> integer) direction - confirm.
subroutine capture_with_convert_f64_to_i32()
integer :: x
real(8) :: v
x = 1
v = 0
!$omp atomic capture
x = v
v = x
!$omp end atomic
end subroutine capture_with_convert_f64_to_i32

! Atomic capture with mismatched operand types: `x` is real(4) and `v` is
! integer. The capture statement `v = x` converts real(4) to integer, and the
! update statement `x = x + v` mixes the integer `v` into a real(4) update.
! NOTE(review): the subroutine name presumably refers to the integer operand
! `v` being used in the real(4) update of `x` - confirm.
subroutine capture_with_convert_i32_to_f32()
real(4) :: x
integer :: v
x = 1.0
v = 0
!$omp atomic capture
v = x
x = x + v
!$omp end atomic
end subroutine capture_with_convert_i32_to_f32
56 changes: 56 additions & 0 deletions flang/test/Lower/OpenMP/atomic-implicit-cast.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
! REQUIRES : openmp_runtime

! RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s

! CHECK: func.func @_QPatomic_implicit_cast_read() {
! Exercises `!$omp atomic read` statements whose source and destination
! variables have different types. Four cases are covered, in order:
!   x = y  (real             -> integer)
!   z = x  (integer          -> double precision)
!   x = w  (complex          -> integer, via extraction of element 0)
!   m = w  (complex          -> complex(8))
! For every case the FileCheck patterns below expect an `omp.atomic.read` into
! a temporary of the source type, followed by `fir.load`, a `fir.convert`
! (preceded by `fir.extract_value` for the complex -> integer case), and a
! `fir.store` into the destination variable.
subroutine atomic_implicit_cast_read
! CHECK: %[[ALLOCA3:.*]] = fir.alloca complex<f32>
! CHECK: %[[ALLOCA2:.*]] = fir.alloca complex<f32>
! CHECK: %[[ALLOCA1:.*]] = fir.alloca i32
! CHECK: %[[ALLOCA0:.*]] = fir.alloca f32

! CHECK: %[[M:.*]] = fir.alloca complex<f64> {bindc_name = "m", uniq_name = "_QFatomic_implicit_cast_readEm"}
! CHECK: %[[M_DECL:.*]]:2 = hlfir.declare %[[M]] {uniq_name = "_QFatomic_implicit_cast_readEm"} : (!fir.ref<complex<f64>>) -> (!fir.ref<complex<f64>>, !fir.ref<complex<f64>>)
! CHECK: %[[W:.*]] = fir.alloca complex<f32> {bindc_name = "w", uniq_name = "_QFatomic_implicit_cast_readEw"}
! CHECK: %[[W_DECL:.*]]:2 = hlfir.declare %[[W]] {uniq_name = "_QFatomic_implicit_cast_readEw"} : (!fir.ref<complex<f32>>) -> (!fir.ref<complex<f32>>, !fir.ref<complex<f32>>)
! CHECK: %[[X:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFatomic_implicit_cast_readEx"}
! CHECK: %[[X_DECL:.*]]:2 = hlfir.declare %[[X]] {uniq_name = "_QFatomic_implicit_cast_readEx"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
! CHECK: %[[Y:.*]] = fir.alloca f32 {bindc_name = "y", uniq_name = "_QFatomic_implicit_cast_readEy"}
! CHECK: %[[Y_DECL:.*]]:2 = hlfir.declare %[[Y]] {uniq_name = "_QFatomic_implicit_cast_readEy"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
! CHECK: %[[Z:.*]] = fir.alloca f64 {bindc_name = "z", uniq_name = "_QFatomic_implicit_cast_readEz"}
! CHECK: %[[Z_DECL:.*]]:2 = hlfir.declare %[[Z]] {uniq_name = "_QFatomic_implicit_cast_readEz"} : (!fir.ref<f64>) -> (!fir.ref<f64>, !fir.ref<f64>)
integer :: x
real :: y
double precision :: z
complex :: w
complex(8) :: m

! CHECK: omp.atomic.read %[[ALLOCA0:.*]] = %[[Y_DECL]]#0 : !fir.ref<f32>, !fir.ref<f32>, f32
! CHECK: %[[LOAD:.*]] = fir.load %[[ALLOCA0]] : !fir.ref<f32>
! CHECK: %[[CVT:.*]] = fir.convert %[[LOAD]] : (f32) -> i32
! CHECK: fir.store %[[CVT]] to %[[X_DECL]]#0 : !fir.ref<i32>
!$omp atomic read
x = y

! CHECK: omp.atomic.read %[[ALLOCA1:.*]] = %[[X_DECL]]#0 : !fir.ref<i32>, !fir.ref<i32>, i32
! CHECK: %[[LOAD:.*]] = fir.load %[[ALLOCA1]] : !fir.ref<i32>
! CHECK: %[[CVT:.*]] = fir.convert %[[LOAD]] : (i32) -> f64
! CHECK: fir.store %[[CVT]] to %[[Z_DECL]]#0 : !fir.ref<f64>
!$omp atomic read
z = x

! CHECK: omp.atomic.read %[[ALLOCA2:.*]] = %[[W_DECL]]#0 : !fir.ref<complex<f32>>, !fir.ref<complex<f32>>, complex<f32>
! CHECK: %[[LOAD:.*]] = fir.load %[[ALLOCA2]] : !fir.ref<complex<f32>>
! CHECK: %[[EXTRACT:.*]] = fir.extract_value %[[LOAD]], [0 : index] : (complex<f32>) -> f32
! CHECK: %[[CVT:.*]] = fir.convert %[[EXTRACT]] : (f32) -> i32
! CHECK: fir.store %[[CVT]] to %[[X_DECL]]#0 : !fir.ref<i32>
!$omp atomic read
x = w

! CHECK: omp.atomic.read %[[ALLOCA3:.*]] = %[[W_DECL]]#0 : !fir.ref<complex<f32>>, !fir.ref<complex<f32>>, complex<f32>
! CHECK: %[[LOAD:.*]] = fir.load %[[ALLOCA3]] : !fir.ref<complex<f32>>
! CHECK: %[[CVT:.*]] = fir.convert %[[LOAD]] : (complex<f32>) -> complex<f64>
! CHECK: fir.store %[[CVT]] to %[[M_DECL]]#0 : !fir.ref<complex<f64>>
!$omp atomic read
m = w
end subroutine
31 changes: 0 additions & 31 deletions llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -268,33 +268,6 @@ computeOpenMPScheduleType(ScheduleKind ClauseKind, bool HasChunks,
return Result;
}

/// Emit an implicit cast to convert \p XRead to the type of variable \p V.
///
/// \param Builder IR builder used to emit the extract/cast instructions.
/// \param XRead   Value to convert.
/// \param V       Destination variable. When it is an `alloca`, the allocated
///                type is the conversion target; otherwise the type of \p V
///                itself is used.
/// \returns \p XRead unchanged when both types are struct types or when no
///          integer/floating-point conversion applies; otherwise the
///          converted value.
static llvm::Value *emitImplicitCast(IRBuilder<> &Builder, llvm::Value *XRead,
                                     llvm::Value *V) {
  // TODO: Add this functionality to the `AtomicInfo` interface
  llvm::Type *XReadType = XRead->getType();
  llvm::Type *VType = V->getType();
  if (llvm::AllocaInst *vAlloca = dyn_cast<llvm::AllocaInst>(V))
    VType = vAlloca->getAllocatedType();

  if (XReadType->isStructTy() && VType->isStructTy())
    // No need to extract or convert. A direct
    // `store` will suffice.
    return XRead;

  if (XReadType->isStructTy()) {
    XRead = Builder.CreateExtractValue(XRead, /*Idxs=*/0);
    // Refresh the cached type: the scalar conversions below must be selected
    // from the extracted element's type, not from the original struct type
    // (which would match none of them and leave the element unconverted).
    XReadType = XRead->getType();
  }

  if (VType->isIntegerTy() && XReadType->isFloatingPointTy())
    XRead = Builder.CreateFPToSI(XRead, VType);
  else if (VType->isFloatingPointTy() && XReadType->isIntegerTy())
    XRead = Builder.CreateSIToFP(XRead, VType);
  else if (VType->isIntegerTy() && XReadType->isIntegerTy())
    XRead = Builder.CreateIntCast(XRead, VType, true);
  else if (VType->isFloatingPointTy() && XReadType->isFloatingPointTy())
    XRead = Builder.CreateFPCast(XRead, VType);
  return XRead;
}

/// Make \p Source branch to \p Target.
///
/// Handles two situations:
Expand Down Expand Up @@ -8685,8 +8658,6 @@ OpenMPIRBuilder::createAtomicRead(const LocationDescription &Loc,
}
}
checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Read);
if (XRead->getType() != V.Var->getType())
XRead = emitImplicitCast(Builder, XRead, V.Var);
Builder.CreateStore(XRead, V.Var, V.IsVolatile);
return Builder.saveIP();
}
Expand Down Expand Up @@ -8983,8 +8954,6 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createAtomicCapture(
return AtomicResult.takeError();
Value *CapturedVal =
(IsPostfixUpdate ? AtomicResult->first : AtomicResult->second);
if (CapturedVal->getType() != V.Var->getType())
CapturedVal = emitImplicitCast(Builder, CapturedVal, V.Var);
Builder.CreateStore(CapturedVal, V.Var, V.IsVolatile);

checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Capture);
Expand Down
21 changes: 7 additions & 14 deletions mlir/test/Target/LLVMIR/openmp-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1396,42 +1396,35 @@ llvm.func @omp_atomic_read_implicit_cast () {

//CHECK: call void @__atomic_load(i64 8, ptr %[[X_ELEMENT]], ptr %[[ATOMIC_LOAD_TEMP]], i32 0)
//CHECK: %[[LOAD:.*]] = load { float, float }, ptr %[[ATOMIC_LOAD_TEMP]], align 8
//CHECK: %[[EXT:.*]] = extractvalue { float, float } %[[LOAD]], 0
//CHECK: store float %[[EXT]], ptr %[[Y]], align 4
//CHECK: store { float, float } %[[LOAD]], ptr %[[Y]], align 4
omp.atomic.read %3 = %17 : !llvm.ptr, !llvm.ptr, !llvm.struct<(f32, f32)>

//CHECK: %[[ATOMIC_LOAD_TEMP:.*]] = load atomic i32, ptr %[[Z]] monotonic, align 4
//CHECK: %[[CAST:.*]] = bitcast i32 %[[ATOMIC_LOAD_TEMP]] to float
//CHECK: %[[LOAD:.*]] = fpext float %[[CAST]] to double
//CHECK: store double %[[LOAD]], ptr %[[Y]], align 8
//CHECK: store float %[[CAST]], ptr %[[Y]], align 4
omp.atomic.read %3 = %1 : !llvm.ptr, !llvm.ptr, f32

//CHECK: %[[ATOMIC_LOAD_TEMP:.*]] = load atomic i32, ptr %[[W]] monotonic, align 4
//CHECK: %[[LOAD:.*]] = sitofp i32 %[[ATOMIC_LOAD_TEMP]] to double
//CHECK: store double %[[LOAD]], ptr %[[Y]], align 8
//CHECK: store i32 %[[ATOMIC_LOAD_TEMP]], ptr %[[Y]], align 4
omp.atomic.read %3 = %7 : !llvm.ptr, !llvm.ptr, i32

//CHECK: %[[ATOMIC_LOAD_TEMP:.*]] = load atomic i64, ptr %[[Y]] monotonic, align 4
//CHECK: %[[CAST:.*]] = bitcast i64 %[[ATOMIC_LOAD_TEMP]] to double
//CHECK: %[[LOAD:.*]] = fptrunc double %[[CAST]] to float
//CHECK: store float %[[LOAD]], ptr %[[Z]], align 4
//CHECK: store double %[[CAST]], ptr %[[Z]], align 8
omp.atomic.read %1 = %3 : !llvm.ptr, !llvm.ptr, f64

//CHECK: %[[ATOMIC_LOAD_TEMP:.*]] = load atomic i32, ptr %[[W]] monotonic, align 4
//CHECK: %[[LOAD:.*]] = sitofp i32 %[[ATOMIC_LOAD_TEMP]] to float
//CHECK: store float %[[LOAD]], ptr %[[Z]], align 4
//CHECK: store i32 %[[ATOMIC_LOAD_TEMP]], ptr %[[Z]], align 4
omp.atomic.read %1 = %7 : !llvm.ptr, !llvm.ptr, i32

//CHECK: %[[ATOMIC_LOAD_TEMP:.*]] = load atomic i64, ptr %[[Y]] monotonic, align 4
//CHECK: %[[CAST:.*]] = bitcast i64 %[[ATOMIC_LOAD_TEMP]] to double
//CHECK: %[[LOAD:.*]] = fptosi double %[[CAST]] to i32
//CHECK: store i32 %[[LOAD]], ptr %[[W]], align 4
//CHECK: store double %[[CAST]], ptr %[[W]], align 8
omp.atomic.read %7 = %3 : !llvm.ptr, !llvm.ptr, f64

//CHECK: %[[ATOMIC_LOAD_TEMP:.*]] = load atomic i32, ptr %[[Z]] monotonic, align 4
//CHECK: %[[CAST:.*]] = bitcast i32 %[[ATOMIC_LOAD_TEMP]] to float
//CHECK: %[[LOAD:.*]] = fptosi float %[[CAST]] to i32
//CHECK: store i32 %[[LOAD]], ptr %[[W]], align 4
//CHECK: store float %[[CAST]], ptr %[[W]], align 4
omp.atomic.read %7 = %1 : !llvm.ptr, !llvm.ptr, f32
llvm.return
}
Expand Down