Skip to content

[flang][acc] Generate acc.copyout for the reduction clause on compute constructs #144623

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion flang/lib/Lower/OpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2676,7 +2676,8 @@ static Op createComputeOp(
llvm::SmallVector<mlir::Value> waitOperands, attachEntryOperands,
copyEntryOperands, copyinEntryOperands, copyoutEntryOperands,
createEntryOperands, nocreateEntryOperands, presentEntryOperands,
dataClauseOperands, numGangs, numWorkers, vectorLength, async;
reductionEntryOperands, dataClauseOperands, numGangs, numWorkers,
vectorLength, async;
llvm::SmallVector<mlir::Attribute> numGangsDeviceTypes, numWorkersDeviceTypes,
vectorLengthDeviceTypes, asyncDeviceTypes, asyncOnlyDeviceTypes,
waitOperandsDeviceTypes, waitOnlyDeviceTypes;
Expand Down Expand Up @@ -2912,9 +2913,12 @@ static Op createComputeOp(
// combined construct implies a copy clause so issue an implicit copy
// instead.
if (!combinedConstructs) {
auto crtDataStart = reductionOperands.size();
genReductions(reductionClause->v, converter, semanticsContext, stmtCtx,
reductionOperands, reductionRecipes, async,
asyncDeviceTypes, asyncOnlyDeviceTypes);
reductionEntryOperands.append(reductionOperands.begin() + crtDataStart,
reductionOperands.end());
} else {
auto crtDataStart = dataClauseOperands.size();
genDataOperandOperations<mlir::acc::CopyinOp>(
Expand Down Expand Up @@ -3024,6 +3028,8 @@ static Op createComputeOp(
builder.setInsertionPointAfter(computeOp);

// Create the exit operations after the region.
genDataExitOperations<mlir::acc::ReductionOp, mlir::acc::CopyoutOp>(
builder, reductionEntryOperands, /*structured=*/true);
genDataExitOperations<mlir::acc::CopyinOp, mlir::acc::CopyoutOp>(
builder, copyEntryOperands, /*structured=*/true);
genDataExitOperations<mlir::acc::CopyinOp, mlir::acc::DeleteOp>(
Expand Down
20 changes: 19 additions & 1 deletion flang/test/Lower/OpenACC/acc-reduction-unwrap-defaultbounds.f90
Original file line number Diff line number Diff line change
Expand Up @@ -1001,6 +1001,7 @@ subroutine acc_reduction_iand()
! CHECK-LABEL: func.func @_QPacc_reduction_iand()
! CHECK: %[[RED:.*]] = acc.reduction varPtr(%{{.*}} : !fir.ref<i32>) -> !fir.ref<i32> {name = "i"}
! CHECK: acc.parallel reduction(@reduction_iand_ref_i32 -> %[[RED]] : !fir.ref<i32>)
! CHECK: acc.copyout accPtr(%[[RED]] : !fir.ref<i32>) to varPtr(%{{.*}} : !fir.ref<i32>) {dataClause = #acc<data_clause acc_reduction>, name = "i"}

subroutine acc_reduction_ior()
integer :: i
Expand All @@ -1011,6 +1012,7 @@ subroutine acc_reduction_ior()
! CHECK-LABEL: func.func @_QPacc_reduction_ior()
! CHECK: %[[RED:.*]] = acc.reduction varPtr(%{{.*}} : !fir.ref<i32>) -> !fir.ref<i32> {name = "i"}
! CHECK: acc.parallel reduction(@reduction_ior_ref_i32 -> %[[RED]] : !fir.ref<i32>)
! CHECK: acc.copyout accPtr(%[[RED]] : !fir.ref<i32>) to varPtr(%{{.*}} : !fir.ref<i32>) {dataClause = #acc<data_clause acc_reduction>, name = "i"}

subroutine acc_reduction_ieor()
integer :: i
Expand All @@ -1021,6 +1023,7 @@ subroutine acc_reduction_ieor()
! CHECK-LABEL: func.func @_QPacc_reduction_ieor()
! CHECK: %[[RED:.*]] = acc.reduction varPtr(%{{.*}} : !fir.ref<i32>) -> !fir.ref<i32> {name = "i"}
! CHECK: acc.parallel reduction(@reduction_xor_ref_i32 -> %[[RED]] : !fir.ref<i32>)
! CHECK: acc.copyout accPtr(%[[RED]] : !fir.ref<i32>) to varPtr(%{{.*}} : !fir.ref<i32>) {dataClause = #acc<data_clause acc_reduction>, name = "i"}

subroutine acc_reduction_and()
logical :: l
Expand All @@ -1033,6 +1036,7 @@ subroutine acc_reduction_and()
! CHECK: %[[DECLL:.*]]:2 = hlfir.declare %[[L]]
! CHECK: %[[RED:.*]] = acc.reduction varPtr(%[[DECLL]]#0 : !fir.ref<!fir.logical<4>>) -> !fir.ref<!fir.logical<4>> {name = "l"}
! CHECK: acc.parallel reduction(@reduction_land_ref_l32 -> %[[RED]] : !fir.ref<!fir.logical<4>>)
! CHECK: acc.copyout accPtr(%[[RED]] : !fir.ref<!fir.logical<4>>) to varPtr(%{{.*}} : !fir.ref<!fir.logical<4>>) {dataClause = #acc<data_clause acc_reduction>, name = "l"}

subroutine acc_reduction_or()
logical :: l
Expand All @@ -1043,6 +1047,7 @@ subroutine acc_reduction_or()
! CHECK-LABEL: func.func @_QPacc_reduction_or()
! CHECK: %[[RED:.*]] = acc.reduction varPtr(%{{.*}} : !fir.ref<!fir.logical<4>>) -> !fir.ref<!fir.logical<4>> {name = "l"}
! CHECK: acc.parallel reduction(@reduction_lor_ref_l32 -> %[[RED]] : !fir.ref<!fir.logical<4>>)
! CHECK: acc.copyout accPtr(%[[RED]] : !fir.ref<!fir.logical<4>>) to varPtr(%{{.*}} : !fir.ref<!fir.logical<4>>) {dataClause = #acc<data_clause acc_reduction>, name = "l"}

subroutine acc_reduction_eqv()
logical :: l
Expand All @@ -1053,6 +1058,7 @@ subroutine acc_reduction_eqv()
! CHECK-LABEL: func.func @_QPacc_reduction_eqv()
! CHECK: %[[RED:.*]] = acc.reduction varPtr(%{{.*}} : !fir.ref<!fir.logical<4>>) -> !fir.ref<!fir.logical<4>> {name = "l"}
! CHECK: acc.parallel reduction(@reduction_eqv_ref_l32 -> %[[RED]] : !fir.ref<!fir.logical<4>>)
! CHECK: acc.copyout accPtr(%[[RED]] : !fir.ref<!fir.logical<4>>) to varPtr(%{{.*}} : !fir.ref<!fir.logical<4>>) {dataClause = #acc<data_clause acc_reduction>, name = "l"}

subroutine acc_reduction_neqv()
logical :: l
Expand All @@ -1063,6 +1069,7 @@ subroutine acc_reduction_neqv()
! CHECK-LABEL: func.func @_QPacc_reduction_neqv()
! CHECK: %[[RED:.*]] = acc.reduction varPtr(%{{.*}} : !fir.ref<!fir.logical<4>>) -> !fir.ref<!fir.logical<4>> {name = "l"}
! CHECK: acc.parallel reduction(@reduction_neqv_ref_l32 -> %[[RED]] : !fir.ref<!fir.logical<4>>)
! CHECK: acc.copyout accPtr(%[[RED]] : !fir.ref<!fir.logical<4>>) to varPtr(%{{.*}} : !fir.ref<!fir.logical<4>>) {dataClause = #acc<data_clause acc_reduction>, name = "l"}

subroutine acc_reduction_add_cmplx()
complex :: c
Expand All @@ -1073,6 +1080,7 @@ subroutine acc_reduction_add_cmplx()
! CHECK-LABEL: func.func @_QPacc_reduction_add_cmplx()
! CHECK: %[[RED:.*]] = acc.reduction varPtr(%{{.*}} : !fir.ref<complex<f32>>) -> !fir.ref<complex<f32>> {name = "c"}
! CHECK: acc.parallel reduction(@reduction_add_ref_z32 -> %[[RED]] : !fir.ref<complex<f32>>)
! CHECK: acc.copyout accPtr(%[[RED]] : !fir.ref<complex<f32>>) to varPtr(%{{.*}} : !fir.ref<complex<f32>>) {dataClause = #acc<data_clause acc_reduction>, name = "c"}

subroutine acc_reduction_mul_cmplx()
complex :: c
Expand All @@ -1083,6 +1091,7 @@ subroutine acc_reduction_mul_cmplx()
! CHECK-LABEL: func.func @_QPacc_reduction_mul_cmplx()
! CHECK: %[[RED:.*]] = acc.reduction varPtr(%{{.*}} : !fir.ref<complex<f32>>) -> !fir.ref<complex<f32>> {name = "c"}
! CHECK: acc.parallel reduction(@reduction_mul_ref_z32 -> %[[RED]] : !fir.ref<complex<f32>>)
! CHECK: acc.copyout accPtr(%[[RED]] : !fir.ref<complex<f32>>) to varPtr(%{{.*}} : !fir.ref<complex<f32>>) {dataClause = #acc<data_clause acc_reduction>, name = "c"}

subroutine acc_reduction_add_alloc()
integer, allocatable :: i
Expand All @@ -1098,6 +1107,7 @@ subroutine acc_reduction_add_alloc()
! CHECK: %[[BOX_ADDR:.*]] = fir.box_addr %[[LOAD]] : (!fir.box<!fir.heap<i32>>) -> !fir.heap<i32>
! CHECK: %[[RED:.*]] = acc.reduction varPtr(%[[BOX_ADDR]] : !fir.heap<i32>) -> !fir.heap<i32> {name = "i"}
! CHECK: acc.parallel reduction(@reduction_add_heap_i32 -> %[[RED]] : !fir.heap<i32>)
! CHECK: acc.copyout accPtr(%[[RED]] : !fir.heap<i32>) to varPtr(%[[BOX_ADDR]] : !fir.heap<i32>) {dataClause = #acc<data_clause acc_reduction>, name = "i"}

subroutine acc_reduction_add_pointer(i)
integer, pointer :: i
Expand All @@ -1112,6 +1122,7 @@ subroutine acc_reduction_add_pointer(i)
! CHECK: %[[BOX_ADDR:.*]] = fir.box_addr %[[LOAD]] : (!fir.box<!fir.ptr<i32>>) -> !fir.ptr<i32>
! CHECK: %[[RED:.*]] = acc.reduction varPtr(%[[BOX_ADDR]] : !fir.ptr<i32>) -> !fir.ptr<i32> {name = "i"}
! CHECK: acc.parallel reduction(@reduction_add_ptr_i32 -> %[[RED]] : !fir.ptr<i32>)
! CHECK: acc.copyout accPtr(%[[RED]] : !fir.ptr<i32>) to varPtr(%[[BOX_ADDR]] : !fir.ptr<i32>) {dataClause = #acc<data_clause acc_reduction>, name = "i"}

subroutine acc_reduction_add_static_slice(a)
integer :: a(100)
Expand All @@ -1129,6 +1140,7 @@ subroutine acc_reduction_add_static_slice(a)
! CHECK: %[[BOUND:.*]] = acc.bounds lowerbound(%[[LB]] : index) upperbound(%[[UB]] : index) extent(%[[C100]] : index) stride(%[[C1]] : index) startIdx(%[[C1]] : index)
! CHECK: %[[RED:.*]] = acc.reduction varPtr(%[[DECLARG0]]#0 : !fir.ref<!fir.array<100xi32>>) bounds(%[[BOUND]]) -> !fir.ref<!fir.array<100xi32>> {name = "a(11:20)"}
! CHECK: acc.parallel reduction(@reduction_add_section_lb10.ub19_ref_100xi32 -> %[[RED]] : !fir.ref<!fir.array<100xi32>>)
! CHECK: acc.copyout accPtr(%[[RED]] : !fir.ref<!fir.array<100xi32>>) bounds(%[[BOUND]]) to varPtr(%[[DECLARG0]]#0 : !fir.ref<!fir.array<100xi32>>) {dataClause = #acc<data_clause acc_reduction>, name = "a(11:20)"}

subroutine acc_reduction_add_dynamic_extent_add(a)
integer :: a(:)
Expand All @@ -1141,6 +1153,7 @@ subroutine acc_reduction_add_dynamic_extent_add(a)
! CHECK: %[[DECLARG0:.*]]:2 = hlfir.declare %[[ARG0]]
! CHECK: %[[RED:.*]] = acc.reduction varPtr(%{{.*}} : !fir.ref<!fir.array<?xi32>>) bounds(%{{.*}}) -> !fir.ref<!fir.array<?xi32>> {name = "a"}
! CHECK: acc.parallel reduction(@reduction_add_box_Uxi32 -> %[[RED:.*]] : !fir.ref<!fir.array<?xi32>>)
! CHECK: acc.copyout accPtr(%[[RED]] : !fir.ref<!fir.array<?xi32>>) bounds(%{{.*}}) to varPtr(%{{.*}} : !fir.ref<!fir.array<?xi32>>) {dataClause = #acc<data_clause acc_reduction>, name = "a"}

subroutine acc_reduction_add_assumed_shape_max(a)
real :: a(:)
Expand All @@ -1153,6 +1166,7 @@ subroutine acc_reduction_add_assumed_shape_max(a)
! CHECK: %[[DECLARG0:.*]]:2 = hlfir.declare %[[ARG0]]
! CHECK: %[[RED:.*]] = acc.reduction varPtr(%{{.*}} : !fir.ref<!fir.array<?xf32>>) bounds(%{{.*}}) -> !fir.ref<!fir.array<?xf32>> {name = "a"}
! CHECK: acc.parallel reduction(@reduction_max_box_Uxf32 -> %[[RED]] : !fir.ref<!fir.array<?xf32>>) {
! CHECK: acc.copyout accPtr(%[[RED]] : !fir.ref<!fir.array<?xf32>>) bounds(%{{.*}}) to varPtr(%{{.*}} : !fir.ref<!fir.array<?xf32>>) {dataClause = #acc<data_clause acc_reduction>, name = "a"}

subroutine acc_reduction_add_dynamic_extent_add_with_section(a)
integer :: a(:)
Expand All @@ -1167,6 +1181,7 @@ subroutine acc_reduction_add_dynamic_extent_add_with_section(a)
! CHECK: %[[BOX_ADDR:.*]] = fir.box_addr %[[DECL]]#0 : (!fir.box<!fir.array<?xi32>>) -> !fir.ref<!fir.array<?xi32>>
! CHECK: %[[RED:.*]] = acc.reduction varPtr(%[[BOX_ADDR]] : !fir.ref<!fir.array<?xi32>>) bounds(%[[BOUND]]) -> !fir.ref<!fir.array<?xi32>> {name = "a(2:4)"}
! CHECK: acc.parallel reduction(@reduction_add_section_lb1.ub3_box_Uxi32 -> %[[RED]] : !fir.ref<!fir.array<?xi32>>)
! CHECK: acc.copyout accPtr(%[[RED]] : !fir.ref<!fir.array<?xi32>>) bounds(%[[BOUND]]) to varPtr(%[[BOX_ADDR]] : !fir.ref<!fir.array<?xi32>>) {dataClause = #acc<data_clause acc_reduction>, name = "a(2:4)"}

subroutine acc_reduction_add_allocatable(a)
real, allocatable :: a(:)
Expand All @@ -1180,8 +1195,9 @@ subroutine acc_reduction_add_allocatable(a)
! CHECK: %[[BOX:.*]] = fir.load %[[DECL]]#0 : !fir.ref<!fir.box<!fir.heap<!fir.array<?xf32>>>>
! CHECK: %[[BOUND:.*]] = acc.bounds lowerbound(%c0{{.*}} : index) upperbound(%{{.*}} : index) extent(%{{.*}}#1 : index) stride(%{{.*}}#2 : index) startIdx(%{{.*}}#0 : index) {strideInBytes = true}
! CHECK: %[[BOX_ADDR:.*]] = fir.box_addr %[[BOX]] : (!fir.box<!fir.heap<!fir.array<?xf32>>>) -> !fir.heap<!fir.array<?xf32>>
! CHECK: %[[RED:.*]] = acc.reduction varPtr(%[[BOX_ADDR]] : !fir.heap<!fir.array<?xf32>>) bounds(%{{[0-9]+}}) -> !fir.heap<!fir.array<?xf32>> {name = "a"}
! CHECK: %[[RED:.*]] = acc.reduction varPtr(%[[BOX_ADDR]] : !fir.heap<!fir.array<?xf32>>) bounds(%[[BOUND]]) -> !fir.heap<!fir.array<?xf32>> {name = "a"}
! CHECK: acc.parallel reduction(@reduction_max_box_heap_Uxf32 -> %[[RED]] : !fir.heap<!fir.array<?xf32>>)
! CHECK: acc.copyout accPtr(%[[RED]] : !fir.heap<!fir.array<?xf32>>) bounds(%[[BOUND]]) to varPtr(%[[BOX_ADDR]] : !fir.heap<!fir.array<?xf32>>) {dataClause = #acc<data_clause acc_reduction>, name = "a"}

subroutine acc_reduction_add_pointer_array(a)
real, pointer :: a(:)
Expand All @@ -1197,6 +1213,7 @@ subroutine acc_reduction_add_pointer_array(a)
! CHECK: %[[BOX_ADDR:.*]] = fir.box_addr %[[BOX]] : (!fir.box<!fir.ptr<!fir.array<?xf32>>>) -> !fir.ptr<!fir.array<?xf32>>
! CHECK: %[[RED:.*]] = acc.reduction varPtr(%[[BOX_ADDR]] : !fir.ptr<!fir.array<?xf32>>) bounds(%[[BOUND]]) -> !fir.ptr<!fir.array<?xf32>> {name = "a"}
! CHECK: acc.parallel reduction(@reduction_max_box_ptr_Uxf32 -> %[[RED]] : !fir.ptr<!fir.array<?xf32>>)
! CHECK: acc.copyout accPtr(%[[RED]] : !fir.ptr<!fir.array<?xf32>>) bounds(%[[BOUND]]) to varPtr(%[[BOX_ADDR]] : !fir.ptr<!fir.array<?xf32>>) {dataClause = #acc<data_clause acc_reduction>, name = "a"}

subroutine acc_reduction_max_dynamic_extent_max(a, n)
integer :: n
Expand All @@ -1211,3 +1228,4 @@ subroutine acc_reduction_max_dynamic_extent_max(a, n)
! CHECK: %[[ADDR:.*]] = fir.box_addr %[[DECL_A]]#0 : (!fir.box<!fir.array<?x?xf32>>) -> !fir.ref<!fir.array<?x?xf32>>
! CHECK: %[[RED:.*]] = acc.reduction varPtr(%[[ADDR]] : !fir.ref<!fir.array<?x?xf32>>) bounds(%{{.*}}, %{{.*}}) -> !fir.ref<!fir.array<?x?xf32>> {name = "a"}
! CHECK: acc.parallel reduction(@reduction_max_box_UxUxf32 -> %[[RED]] : !fir.ref<!fir.array<?x?xf32>>)
! CHECK: acc.copyout accPtr(%[[RED]] : !fir.ref<!fir.array<?x?xf32>>) bounds(%{{.*}}) to varPtr(%{{.*}} : !fir.ref<!fir.array<?x?xf32>>) {dataClause = #acc<data_clause acc_reduction>, name = "a"}
Loading
Loading