Skip to content

Commit 6038fd4

Browse files
authored
[flang][cuda] Lower atomiccas, atomicxor and atomicexch (#128242)
Lower atomiccas, atomicxor and atomicexch to corresponding llvm atomic operations.
1 parent 8ffdc3b commit 6038fd4

File tree

4 files changed

+192
-62
lines changed

4 files changed

+192
-62
lines changed

flang/include/flang/Optimizer/Builder/IntrinsicCall.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,12 +187,15 @@ struct IntrinsicLibrary {
187187
mlir::Value genAtanpi(mlir::Type, llvm::ArrayRef<mlir::Value>);
188188
mlir::Value genAtomicAdd(mlir::Type, llvm::ArrayRef<mlir::Value>);
189189
mlir::Value genAtomicAnd(mlir::Type, llvm::ArrayRef<mlir::Value>);
190-
mlir::Value genAtomicOr(mlir::Type, llvm::ArrayRef<mlir::Value>);
190+
mlir::Value genAtomicCas(mlir::Type, llvm::ArrayRef<mlir::Value>);
191191
mlir::Value genAtomicDec(mlir::Type, llvm::ArrayRef<mlir::Value>);
192+
mlir::Value genAtomicExch(mlir::Type, llvm::ArrayRef<mlir::Value>);
192193
mlir::Value genAtomicInc(mlir::Type, llvm::ArrayRef<mlir::Value>);
193194
mlir::Value genAtomicMax(mlir::Type, llvm::ArrayRef<mlir::Value>);
194195
mlir::Value genAtomicMin(mlir::Type, llvm::ArrayRef<mlir::Value>);
196+
mlir::Value genAtomicOr(mlir::Type, llvm::ArrayRef<mlir::Value>);
195197
mlir::Value genAtomicSub(mlir::Type, llvm::ArrayRef<mlir::Value>);
198+
mlir::Value genAtomicXor(mlir::Type, llvm::ArrayRef<mlir::Value>);
196199
fir::ExtendedValue
197200
genCommandArgumentCount(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
198201
mlir::Value genAsind(mlir::Type, llvm::ArrayRef<mlir::Value>);

flang/lib/Optimizer/Builder/IntrinsicCall.cpp

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,39 @@ static constexpr IntrinsicHandler handlers[]{
152152
{"atomicaddi", &I::genAtomicAdd, {{{"a", asAddr}, {"v", asValue}}}, false},
153153
{"atomicaddl", &I::genAtomicAdd, {{{"a", asAddr}, {"v", asValue}}}, false},
154154
{"atomicandi", &I::genAtomicAnd, {{{"a", asAddr}, {"v", asValue}}}, false},
155+
{"atomiccasd",
156+
&I::genAtomicCas,
157+
{{{"a", asAddr}, {"v1", asValue}, {"v2", asValue}}},
158+
false},
159+
{"atomiccasf",
160+
&I::genAtomicCas,
161+
{{{"a", asAddr}, {"v1", asValue}, {"v2", asValue}}},
162+
false},
163+
{"atomiccasi",
164+
&I::genAtomicCas,
165+
{{{"a", asAddr}, {"v1", asValue}, {"v2", asValue}}},
166+
false},
167+
{"atomiccasul",
168+
&I::genAtomicCas,
169+
{{{"a", asAddr}, {"v1", asValue}, {"v2", asValue}}},
170+
false},
155171
{"atomicdeci", &I::genAtomicDec, {{{"a", asAddr}, {"v", asValue}}}, false},
172+
{"atomicexchd",
173+
&I::genAtomicExch,
174+
{{{"a", asAddr}, {"v", asValue}}},
175+
false},
176+
{"atomicexchf",
177+
&I::genAtomicExch,
178+
{{{"a", asAddr}, {"v", asValue}}},
179+
false},
180+
{"atomicexchi",
181+
&I::genAtomicExch,
182+
{{{"a", asAddr}, {"v", asValue}}},
183+
false},
184+
{"atomicexchul",
185+
&I::genAtomicExch,
186+
{{{"a", asAddr}, {"v", asValue}}},
187+
false},
156188
{"atomicinci", &I::genAtomicInc, {{{"a", asAddr}, {"v", asValue}}}, false},
157189
{"atomicmaxd", &I::genAtomicMax, {{{"a", asAddr}, {"v", asValue}}}, false},
158190
{"atomicmaxf", &I::genAtomicMax, {{{"a", asAddr}, {"v", asValue}}}, false},
@@ -167,6 +199,7 @@ static constexpr IntrinsicHandler handlers[]{
167199
{"atomicsubf", &I::genAtomicSub, {{{"a", asAddr}, {"v", asValue}}}, false},
168200
{"atomicsubi", &I::genAtomicSub, {{{"a", asAddr}, {"v", asValue}}}, false},
169201
{"atomicsubl", &I::genAtomicSub, {{{"a", asAddr}, {"v", asValue}}}, false},
202+
{"atomicxori", &I::genAtomicXor, {{{"a", asAddr}, {"v", asValue}}}, false},
170203
{"bessel_jn",
171204
&I::genBesselJn,
172205
{{{"n1", asValue}, {"n2", asValue}, {"x", asValue}}},
@@ -2691,6 +2724,22 @@ mlir::Value IntrinsicLibrary::genAtomicOr(mlir::Type resultType,
26912724
return genAtomBinOp(builder, loc, binOp, args[0], args[1]);
26922725
}
26932726

2727+
// ATOMICCAS
2728+
mlir::Value IntrinsicLibrary::genAtomicCas(mlir::Type resultType,
2729+
llvm::ArrayRef<mlir::Value> args) {
2730+
assert(args.size() == 3);
2731+
assert(args[1].getType() == args[2].getType());
2732+
auto successOrdering = mlir::LLVM::AtomicOrdering::acq_rel;
2733+
auto failureOrdering = mlir::LLVM::AtomicOrdering::monotonic;
2734+
auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get(resultType.getContext());
2735+
auto address =
2736+
builder.create<mlir::UnrealizedConversionCastOp>(loc, llvmPtrTy, args[0])
2737+
.getResult(0);
2738+
auto cmpxchg = builder.create<mlir::LLVM::AtomicCmpXchgOp>(
2739+
loc, address, args[1], args[2], successOrdering, failureOrdering);
2740+
return builder.create<mlir::LLVM::ExtractValueOp>(loc, cmpxchg, 1);
2741+
}
2742+
26942743
mlir::Value IntrinsicLibrary::genAtomicDec(mlir::Type resultType,
26952744
llvm::ArrayRef<mlir::Value> args) {
26962745
assert(args.size() == 2);
@@ -2700,6 +2749,16 @@ mlir::Value IntrinsicLibrary::genAtomicDec(mlir::Type resultType,
27002749
return genAtomBinOp(builder, loc, binOp, args[0], args[1]);
27012750
}
27022751

2752+
// ATOMICEXCH
2753+
mlir::Value IntrinsicLibrary::genAtomicExch(mlir::Type resultType,
2754+
llvm::ArrayRef<mlir::Value> args) {
2755+
assert(args.size() == 2);
2756+
assert(mlir::isa<mlir::IntegerType>(args[1].getType()));
2757+
2758+
mlir::LLVM::AtomicBinOp binOp = mlir::LLVM::AtomicBinOp::xchg;
2759+
return genAtomBinOp(builder, loc, binOp, args[0], args[1]);
2760+
}
2761+
27032762
mlir::Value IntrinsicLibrary::genAtomicInc(mlir::Type resultType,
27042763
llvm::ArrayRef<mlir::Value> args) {
27052764
assert(args.size() == 2);
@@ -2731,6 +2790,16 @@ mlir::Value IntrinsicLibrary::genAtomicMin(mlir::Type resultType,
27312790
return genAtomBinOp(builder, loc, binOp, args[0], args[1]);
27322791
}
27332792

2793+
// ATOMICXOR
2794+
mlir::Value IntrinsicLibrary::genAtomicXor(mlir::Type resultType,
2795+
llvm::ArrayRef<mlir::Value> args) {
2796+
assert(args.size() == 2);
2797+
assert(mlir::isa<mlir::IntegerType>(args[1].getType()));
2798+
2799+
mlir::LLVM::AtomicBinOp binOp = mlir::LLVM::AtomicBinOp::_xor;
2800+
return genAtomBinOp(builder, loc, binOp, args[0], args[1]);
2801+
}
2802+
27342803
// ASSOCIATED
27352804
fir::ExtendedValue
27362805
IntrinsicLibrary::genAssociated(mlir::Type resultType,

flang/module/cudadevice.f90

Lines changed: 107 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -557,59 +557,117 @@ attributes(device) pure integer function atomicdeci(address, val)
557557
end function
558558
end interface
559559

560+
interface atomiccas
561+
attributes(device) pure integer function atomiccasi(address, val, val2)
562+
!dir$ ignore_tkr (rd) address, (d) val, (d) val2
563+
integer, intent(inout) :: address
564+
integer, value :: val, val2
565+
end function
566+
attributes(device) pure integer(8) function atomiccasul(address, val, val2)
567+
!dir$ ignore_tkr (rd) address, (dk) val, (dk) val2
568+
integer(8), intent(inout) :: address
569+
integer(8), value :: val, val2
570+
end function
571+
attributes(device) pure real function atomiccasf(address, val, val2)
572+
!dir$ ignore_tkr (rd) address, (d) val, (d) val2
573+
real, intent(inout) :: address
574+
real, value :: val, val2
575+
end function
576+
attributes(device) pure double precision function atomiccasd(address, val, val2)
577+
!dir$ ignore_tkr (rd) address, (d) val, (d) val2
578+
double precision, intent(inout) :: address
579+
double precision, value :: val, val2
580+
end function
581+
end interface
582+
583+
interface atomicexch
584+
attributes(device) pure integer function atomicexchi(address, val)
585+
!dir$ ignore_tkr (rd) address, (d) val
586+
integer, intent(inout) :: address
587+
integer, value :: val
588+
end function
589+
attributes(device) pure integer(8) function atomicexchul(address, val)
590+
!dir$ ignore_tkr (rd) address, (dk) val
591+
integer(8), intent(inout) :: address
592+
integer(8), value :: val
593+
end function
594+
attributes(device) pure real function atomicexchf(address, val)
595+
!dir$ ignore_tkr (rd) address, (d) val
596+
real, intent(inout) :: address
597+
real, value :: val
598+
end function
599+
attributes(device) pure double precision function atomicexchd(address, val)
600+
!dir$ ignore_tkr (rd) address, (d) val
601+
double precision, intent(inout) :: address
602+
double precision, value :: val
603+
end function
604+
end interface
605+
606+
interface atomicxor
607+
attributes(device) pure integer function atomicxori(address, val)
608+
!dir$ ignore_tkr (rd) address, (d) val
609+
integer, intent(inout) :: address
610+
integer, value :: val
611+
end function
612+
end interface
613+
614+
! Time function
615+
560616
interface
561617
attributes(device) integer(8) function clock64()
562618
end function
563619
end interface
564620

565-
interface match_all_sync
566-
attributes(device) integer function match_all_syncjj(mask, val, pred)
567-
!dir$ ignore_tkr(d) mask, (d) val, (d) pred
568-
integer(4), value :: mask
569-
integer(4), value :: val
570-
integer(4) :: pred
571-
end function
572-
attributes(device) integer function match_all_syncjx(mask, val, pred)
573-
!dir$ ignore_tkr(d) mask, (d) val, (d) pred
574-
integer(4), value :: mask
575-
integer(8), value :: val
576-
integer(4) :: pred
577-
end function
578-
attributes(device) integer function match_all_syncjf(mask, val, pred)
579-
!dir$ ignore_tkr(d) mask, (d) val, (d) pred
580-
integer(4), value :: mask
581-
real(4), value :: val
582-
integer(4) :: pred
583-
end function
584-
attributes(device) integer function match_all_syncjd(mask, val, pred)
585-
!dir$ ignore_tkr(d) mask, (d) val, (d) pred
586-
integer(4), value :: mask
587-
real(8), value :: val
588-
integer(4) :: pred
589-
end function
590-
end interface
591-
592-
interface match_any_sync
593-
attributes(device) integer function match_any_syncjj(mask, val)
594-
!dir$ ignore_tkr(d) mask, (d) val
595-
integer(4), value :: mask
596-
integer(4), value :: val
597-
end function
598-
attributes(device) integer function match_any_syncjx(mask, val)
599-
!dir$ ignore_tkr(d) mask, (d) val
600-
integer(4), value :: mask
601-
integer(8), value :: val
602-
end function
603-
attributes(device) integer function match_any_syncjf(mask, val)
604-
!dir$ ignore_tkr(d) mask, (d) val
605-
integer(4), value :: mask
606-
real(4), value :: val
607-
end function
608-
attributes(device) integer function match_any_syncjd(mask, val)
609-
!dir$ ignore_tkr(d) mask, (d) val
610-
integer(4), value :: mask
611-
real(8), value :: val
612-
end function
613-
end interface
621+
! Warp Match Functions
622+
623+
interface match_all_sync
624+
attributes(device) integer function match_all_syncjj(mask, val, pred)
625+
!dir$ ignore_tkr(d) mask, (d) val, (d) pred
626+
integer(4), value :: mask
627+
integer(4), value :: val
628+
integer(4) :: pred
629+
end function
630+
attributes(device) integer function match_all_syncjx(mask, val, pred)
631+
!dir$ ignore_tkr(d) mask, (d) val, (d) pred
632+
integer(4), value :: mask
633+
integer(8), value :: val
634+
integer(4) :: pred
635+
end function
636+
attributes(device) integer function match_all_syncjf(mask, val, pred)
637+
!dir$ ignore_tkr(d) mask, (d) val, (d) pred
638+
integer(4), value :: mask
639+
real(4), value :: val
640+
integer(4) :: pred
641+
end function
642+
attributes(device) integer function match_all_syncjd(mask, val, pred)
643+
!dir$ ignore_tkr(d) mask, (d) val, (d) pred
644+
integer(4), value :: mask
645+
real(8), value :: val
646+
integer(4) :: pred
647+
end function
648+
end interface
649+
650+
interface match_any_sync
651+
attributes(device) integer function match_any_syncjj(mask, val)
652+
!dir$ ignore_tkr(d) mask, (d) val
653+
integer(4), value :: mask
654+
integer(4), value :: val
655+
end function
656+
attributes(device) integer function match_any_syncjx(mask, val)
657+
!dir$ ignore_tkr(d) mask, (d) val
658+
integer(4), value :: mask
659+
integer(8), value :: val
660+
end function
661+
attributes(device) integer function match_any_syncjf(mask, val)
662+
!dir$ ignore_tkr(d) mask, (d) val
663+
integer(4), value :: mask
664+
real(4), value :: val
665+
end function
666+
attributes(device) integer function match_any_syncjd(mask, val)
667+
!dir$ ignore_tkr(d) mask, (d) val
668+
integer(4), value :: mask
669+
real(8), value :: val
670+
end function
671+
end interface
614672

615673
end module

flang/test/Lower/CUDA/cuda-device-proc.cuf

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -150,15 +150,15 @@ end subroutine
150150
! CHECK: fir.convert %{{.*}} : (f64) -> i64
151151
! CHECK: fir.call @llvm.nvvm.match.any.sync.i64p
152152

153-
! CHECK: func.func private @llvm.nvvm.barrier0()
154-
! CHECK: func.func private @llvm.nvvm.bar.warp.sync(i32)
155-
! CHECK: func.func private @llvm.nvvm.membar.gl()
156-
! CHECK: func.func private @llvm.nvvm.membar.cta()
157-
! CHECK: func.func private @llvm.nvvm.membar.sys()
158-
! CHECK: func.func private @llvm.nvvm.barrier0.and(i32) -> i32
159-
! CHECK: func.func private @llvm.nvvm.barrier0.popc(i32) -> i32
160-
! CHECK: func.func private @llvm.nvvm.barrier0.or(i32) -> i32
161-
! CHECK: func.func private @llvm.nvvm.match.all.sync.i32p(i32, i32) -> tuple<i32, i1>
162-
! CHECK: func.func private @llvm.nvvm.match.all.sync.i64p(i32, i64) -> tuple<i32, i1>
163-
! CHECK: func.func private @llvm.nvvm.match.any.sync.i32p(i32, i32) -> i32
164-
! CHECK: func.func private @llvm.nvvm.match.any.sync.i64p(i32, i64) -> i32
153+
attributes(device) subroutine testAtomic()
154+
integer :: a, istat, j
155+
istat = atomicexch(a,0)
156+
istat = atomicxor(a, j)
157+
istat = atomiccas(a, i, 14)
158+
end subroutine
159+
160+
! CHECK-LABEL: func.func @_QPtestatomic()
161+
! CHECK: llvm.atomicrmw xchg %{{.*}}, %c0{{.*}} seq_cst : !llvm.ptr, i32
162+
! CHECK: llvm.atomicrmw _xor %{{.*}}, %{{.*}} seq_cst : !llvm.ptr, i32
163+
! CHECK: %[[ADDR:.*]] = builtin.unrealized_conversion_cast %{{.*}}#1 : !fir.ref<i32> to !llvm.ptr
164+
! CHECK: llvm.cmpxchg %[[ADDR]], %{{.*}}, %c14{{.*}} acq_rel monotonic : !llvm.ptr, i32

0 commit comments

Comments
 (0)