Skip to content

Commit 447af1c

Browse files
authored
[flang][openacc][openmp] Update stride computation for bounds (#72168)
This patch updates the stride computation for the outer dimensions of multidimensional arrays whose bounds are generated from the descriptor. For the inner (first) dimension, the stride is read from the descriptor and is the element size in bytes. For each outer dimension n, the stride is the stride of dimension n-1 multiplied by the extent of dimension n-1.
1 parent a409002 commit 447af1c

File tree

2 files changed

+30
-3
lines changed

2 files changed

+30
-3
lines changed

flang/lib/Lower/DirectivesCommon.h

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -597,6 +597,7 @@ genBoundsOpsFromBox(fir::FirOpBuilder &builder, mlir::Location loc,
597597
mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
598598
assert(box.getType().isa<fir::BaseBoxType>() &&
599599
"expect fir.box or fir.class");
600+
mlir::Value byteStride;
600601
for (unsigned dim = 0; dim < dataExv.rank(); ++dim) {
601602
mlir::Value d = builder.createIntegerConstant(loc, idxTy, dim);
602603
mlir::Value baseLb =
@@ -606,9 +607,13 @@ genBoundsOpsFromBox(fir::FirOpBuilder &builder, mlir::Location loc,
606607
mlir::Value lb = builder.createIntegerConstant(loc, idxTy, 0);
607608
mlir::Value ub =
608609
builder.create<mlir::arith::SubIOp>(loc, dimInfo.getExtent(), one);
609-
mlir::Value bound =
610-
builder.create<BoundsOp>(loc, boundTy, lb, ub, mlir::Value(),
611-
dimInfo.getByteStride(), true, baseLb);
610+
if (dim == 0) // First stride is the element size.
611+
byteStride = dimInfo.getByteStride();
612+
mlir::Value bound = builder.create<BoundsOp>(
613+
loc, boundTy, lb, ub, mlir::Value(), byteStride, true, baseLb);
614+
// Compute the stride for the next dimension.
615+
byteStride = builder.create<mlir::arith::MulIOp>(loc, byteStride,
616+
dimInfo.getExtent());
612617
bounds.push_back(bound);
613618
}
614619
return bounds;

flang/test/Lower/OpenACC/acc-bounds.f90

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,4 +102,26 @@ subroutine acc_undefined_extent(a)
102102
! HLFIR: %[[PRESENT:.*]] = acc.present varPtr(%[[DECL_ARG0]]#1 : !fir.ref<!fir.array<?xf32>>) bounds(%[[BOUND]]) -> !fir.ref<!fir.array<?xf32>> {name = "a"}
103103
! CHECK: acc.kernels dataOperands(%[[PRESENT]] : !fir.ref<!fir.array<?xf32>>)
104104

105+
! Test input: a rank-3 assumed-shape dummy array used in a `present` clause;
! the checks that follow match the stride operand of each acc.bounds.
subroutine acc_multi_strides(a)
106+
real, dimension(:,:,:) :: a
107+
108+
!$acc kernels present(a)
109+
!$acc end kernels
110+
end subroutine
111+
112+
! CHECK-LABEL: func.func @_QMopenacc_boundsPacc_multi_strides(
113+
! CHECK-SAME: %[[ARG0:.*]]: !fir.box<!fir.array<?x?x?xf32>> {fir.bindc_name = "a"})
114+
! HLFIR: %[[DECL_ARG0:.*]]:2 = hlfir.declare %[[ARG0]] {uniq_name = "_QMopenacc_boundsFacc_multi_stridesEa"} : (!fir.box<!fir.array<?x?x?xf32>>) -> (!fir.box<!fir.array<?x?x?xf32>>, !fir.box<!fir.array<?x?x?xf32>>)
115+
! HLFIR: %[[BOX_DIMS0:.*]]:3 = fir.box_dims %[[DECL_ARG0]]#1, %c0{{.*}} : (!fir.box<!fir.array<?x?x?xf32>>, index) -> (index, index, index)
116+
! HLFIR: %[[BOUNDS0:.*]] = acc.bounds lowerbound(%{{.*}} : index) upperbound(%{{.*}} : index) stride(%[[BOX_DIMS0]]#2 : index) startIdx(%{{.*}} : index) {strideInBytes = true}
117+
! HLFIR: %[[STRIDE1:.*]] = arith.muli %[[BOX_DIMS0]]#2, %[[BOX_DIMS0]]#1 : index
118+
! HLFIR: %[[BOX_DIMS1:.*]]:3 = fir.box_dims %[[DECL_ARG0]]#1, %c1{{.*}} : (!fir.box<!fir.array<?x?x?xf32>>, index) -> (index, index, index)
119+
! HLFIR: %[[BOUNDS1:.*]] = acc.bounds lowerbound(%{{.*}} : index) upperbound(%{{.*}} : index) stride(%[[STRIDE1]] : index) startIdx(%{{.*}} : index) {strideInBytes = true}
120+
! HLFIR: %[[STRIDE2:.*]] = arith.muli %[[STRIDE1]], %[[BOX_DIMS1]]#1 : index
121+
! HLFIR: %[[BOX_DIMS2:.*]]:3 = fir.box_dims %[[DECL_ARG0]]#1, %c2{{.*}} : (!fir.box<!fir.array<?x?x?xf32>>, index) -> (index, index, index)
122+
! HLFIR: %[[BOUNDS2:.*]] = acc.bounds lowerbound(%{{.*}} : index) upperbound(%{{.*}} : index) stride(%[[STRIDE2]] : index) startIdx(%{{.*}} : index) {strideInBytes = true}
123+
! HLFIR: %[[BOX_ADDR:.*]] = fir.box_addr %[[DECL_ARG0]]#1 : (!fir.box<!fir.array<?x?x?xf32>>) -> !fir.ref<!fir.array<?x?x?xf32>>
124+
! HLFIR: %[[PRESENT:.*]] = acc.present varPtr(%[[BOX_ADDR]] : !fir.ref<!fir.array<?x?x?xf32>>) bounds(%[[BOUNDS0]], %[[BOUNDS1]], %[[BOUNDS2]]) -> !fir.ref<!fir.array<?x?x?xf32>> {name = "a"}
125+
! HLFIR: acc.kernels dataOperands(%[[PRESENT]] : !fir.ref<!fir.array<?x?x?xf32>>) {
126+
105127
end module

0 commit comments

Comments
 (0)