Skip to content

Commit b20e063

Browse files
[flang][acc] Generate acc.bounds operation from FIR shape (#136637)
This PR adds support for generating `acc.bounds` operations through `MappableType`'s `generateAccBounds` when there is no `fir.box` entity. This is especially useful because the FIR type does not capture size information for explicit-shape arrays, and the current implementation relied on finding the box entity. This scenario is possible because, during HLFIRtoFIR, `fir.array_coor` and `fir.box_addr` operations are often optimized to use the raw address. If one tries to map the SSA value that represents such a variable, the correct dimensions need to be extracted from the shape information held in the FIR declare operation.
1 parent 4dbf67d commit b20e063

File tree

2 files changed

+134
-9
lines changed

2 files changed

+134
-9
lines changed

flang/lib/Optimizer/OpenACC/FIROpenACCTypeInterfaces.cpp

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,77 @@ OpenACCMappableModel<fir::SequenceType>::generateAccBounds(
188188
mlir::acc::DataBoundsType>(
189189
firBuilder, loc, exv, info);
190190
}
191+
192+
if (mlir::isa<hlfir::DeclareOp, fir::DeclareOp>(varPtr.getDefiningOp())) {
193+
mlir::Value zero =
194+
firBuilder.createIntegerConstant(loc, builder.getIndexType(), 0);
195+
mlir::Value one =
196+
firBuilder.createIntegerConstant(loc, builder.getIndexType(), 1);
197+
198+
mlir::Value shape;
199+
if (auto declareOp =
200+
mlir::dyn_cast_if_present<fir::DeclareOp>(varPtr.getDefiningOp()))
201+
shape = declareOp.getShape();
202+
else if (auto declareOp = mlir::dyn_cast_if_present<hlfir::DeclareOp>(
203+
varPtr.getDefiningOp()))
204+
shape = declareOp.getShape();
205+
206+
const bool strideIncludeLowerExtent = true;
207+
208+
llvm::SmallVector<mlir::Value> accBounds;
209+
if (auto shapeOp =
210+
mlir::dyn_cast_if_present<fir::ShapeOp>(shape.getDefiningOp())) {
211+
mlir::Value cummulativeExtent = one;
212+
for (auto extent : shapeOp.getExtents()) {
213+
mlir::Value upperbound =
214+
builder.create<mlir::arith::SubIOp>(loc, extent, one);
215+
mlir::Value stride = one;
216+
if (strideIncludeLowerExtent) {
217+
stride = cummulativeExtent;
218+
cummulativeExtent = builder.create<mlir::arith::MulIOp>(
219+
loc, cummulativeExtent, extent);
220+
}
221+
auto accBound = builder.create<mlir::acc::DataBoundsOp>(
222+
loc, mlir::acc::DataBoundsType::get(builder.getContext()),
223+
/*lowerbound=*/zero, /*upperbound=*/upperbound,
224+
/*extent=*/extent, /*stride=*/stride, /*strideInBytes=*/false,
225+
/*startIdx=*/one);
226+
accBounds.push_back(accBound);
227+
}
228+
} else if (auto shapeShiftOp =
229+
mlir::dyn_cast_if_present<fir::ShapeShiftOp>(
230+
shape.getDefiningOp())) {
231+
mlir::Value lowerbound;
232+
mlir::Value cummulativeExtent = one;
233+
for (auto [idx, val] : llvm::enumerate(shapeShiftOp.getPairs())) {
234+
if (idx % 2 == 0) {
235+
lowerbound = val;
236+
} else {
237+
mlir::Value extent = val;
238+
mlir::Value upperbound =
239+
builder.create<mlir::arith::SubIOp>(loc, extent, one);
240+
upperbound = builder.create<mlir::arith::AddIOp>(loc, lowerbound,
241+
upperbound);
242+
mlir::Value stride = one;
243+
if (strideIncludeLowerExtent) {
244+
stride = cummulativeExtent;
245+
cummulativeExtent = builder.create<mlir::arith::MulIOp>(
246+
loc, cummulativeExtent, extent);
247+
}
248+
auto accBound = builder.create<mlir::acc::DataBoundsOp>(
249+
loc, mlir::acc::DataBoundsType::get(builder.getContext()),
250+
/*lowerbound=*/zero, /*upperbound=*/upperbound,
251+
/*extent=*/extent, /*stride=*/stride, /*strideInBytes=*/false,
252+
/*startIdx=*/lowerbound);
253+
accBounds.push_back(accBound);
254+
}
255+
}
256+
}
257+
258+
if (!accBounds.empty())
259+
return accBounds;
260+
}
261+
191262
assert(false && "array with unknown dimension expected to have descriptor");
192263
return {};
193264
}

flang/test/Fir/OpenACC/openacc-mappable.fir

Lines changed: 63 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// RUN: fir-opt %s -pass-pipeline='builtin.module(test-fir-openacc-interfaces)' -split-input-file --mlir-disable-threading 2>&1 | FileCheck %s
33

44
module attributes {dlti.dl_spec = #dlti.dl_spec<f16 = dense<16> : vector<2xi64>, f128 = dense<128> : vector<2xi64>, !llvm.ptr<270> = dense<32> : vector<4xi64>, f64 = dense<64> : vector<2xi64>, !llvm.ptr<271> = dense<32> : vector<4xi64>, !llvm.ptr<272> = dense<64> : vector<4xi64>, i64 = dense<64> : vector<2xi64>, i128 = dense<128> : vector<2xi64>, f80 = dense<128> : vector<2xi64>, !llvm.ptr = dense<64> : vector<4xi64>, i1 = dense<8> : vector<2xi64>, i8 = dense<8> : vector<2xi64>, i16 = dense<16> : vector<2xi64>, i32 = dense<32> : vector<2xi64>, "dlti.endianness" = "little", "dlti.stack_alignment" = 128 : i64>, fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128"} {
5+
// This test exercises an explicit-shape local array of the form "arr(2:10)"
56
func.func @_QPsub() {
67
%c2 = arith.constant 2 : index
78
%c10 = arith.constant 10 : index
@@ -15,13 +16,66 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<f16 = dense<16> : vector<2xi64>,
1516
acc.enter_data dataOperands(%5, %6 : !fir.box<!fir.array<10xf32>>, !fir.ref<!fir.array<10xf32>>)
1617
return
1718
}
18-
}
1919

20-
// CHECK: Visiting: %{{.*}} = acc.copyin var(%{{.*}} : !fir.box<!fir.array<10xf32>>) -> !fir.box<!fir.array<10xf32>> {name = "arr", structured = false}
21-
// CHECK: Mappable: !fir.box<!fir.array<10xf32>>
22-
// CHECK: Type category: array
23-
// CHECK: Size: 40
24-
// CHECK: Visiting: %{{.*}} = acc.copyin varPtr(%{{.*}} : !fir.ref<!fir.array<10xf32>>) -> !fir.ref<!fir.array<10xf32>> {name = "arr", structured = false}
25-
// CHECK: Mappable: !fir.array<10xf32>
26-
// CHECK: Type category: array
27-
// CHECK: Size: 40
20+
// CHECK: Visiting: %{{.*}} = acc.copyin var(%{{.*}} : !fir.box<!fir.array<10xf32>>) -> !fir.box<!fir.array<10xf32>> {name = "arr", structured = false}
21+
// CHECK: Mappable: !fir.box<!fir.array<10xf32>>
22+
// CHECK: Type category: array
23+
// CHECK: Size: 40
24+
25+
// CHECK: Visiting: %{{.*}} = acc.copyin varPtr(%{{.*}} : !fir.ref<!fir.array<10xf32>>) -> !fir.ref<!fir.array<10xf32>> {name = "arr", structured = false}
26+
// CHECK: Mappable: !fir.array<10xf32>
27+
// CHECK: Type category: array
28+
// CHECK: Size: 40
29+
30+
// This second test exercises explicit-shape array arguments in the following forms:
31+
// `real :: arr1(nn), arr2(2:nn), arr3(10)`
32+
// It uses the reference instead of the box in the clauses to test that bounds
33+
// can be generated from the shape operations.
34+
func.func @_QPacc_explicit_shape(%arg0: !fir.ref<!fir.array<?xf32>> {fir.bindc_name = "arr1"}, %arg1: !fir.ref<!fir.array<?xf32>> {fir.bindc_name = "arr2"}, %arg2: !fir.ref<i32> {fir.bindc_name = "nn"}) {
35+
%c-1 = arith.constant -1 : index
36+
%c2 = arith.constant 2 : index
37+
%c0 = arith.constant 0 : index
38+
%c10 = arith.constant 10 : index
39+
%0 = fir.dummy_scope : !fir.dscope
40+
%1:2 = hlfir.declare %arg2 dummy_scope %0 {uniq_name = "_QFacc_explicit_shapeEnn"} : (!fir.ref<i32>, !fir.dscope) -> (!fir.ref<i32>, !fir.ref<i32>)
41+
%2 = fir.alloca !fir.array<10xf32> {bindc_name = "arr3", uniq_name = "_QFacc_explicit_shapeEarr3"}
42+
%3 = fir.shape %c10 : (index) -> !fir.shape<1>
43+
%4:2 = hlfir.declare %2(%3) {uniq_name = "_QFacc_explicit_shapeEarr3"} : (!fir.ref<!fir.array<10xf32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<10xf32>>, !fir.ref<!fir.array<10xf32>>)
44+
%5 = fir.load %1#0 : !fir.ref<i32>
45+
%6 = fir.convert %5 : (i32) -> index
46+
%7 = arith.cmpi sgt, %6, %c0 : index
47+
%8 = arith.select %7, %6, %c0 : index
48+
%9 = fir.shape %8 : (index) -> !fir.shape<1>
49+
%10:2 = hlfir.declare %arg0(%9) dummy_scope %0 {uniq_name = "_QFacc_explicit_shapeEarr1"} : (!fir.ref<!fir.array<?xf32>>, !fir.shape<1>, !fir.dscope) -> (!fir.box<!fir.array<?xf32>>, !fir.ref<!fir.array<?xf32>>)
50+
%11 = arith.addi %6, %c-1 : index
51+
%12 = arith.cmpi sgt, %11, %c0 : index
52+
%13 = arith.select %12, %11, %c0 : index
53+
%14 = fir.shape_shift %c2, %13 : (index, index) -> !fir.shapeshift<1>
54+
%15:2 = hlfir.declare %arg1(%14) dummy_scope %0 {uniq_name = "_QFacc_explicit_shapeEarr2"} : (!fir.ref<!fir.array<?xf32>>, !fir.shapeshift<1>, !fir.dscope) -> (!fir.box<!fir.array<?xf32>>, !fir.ref<!fir.array<?xf32>>)
55+
%16 = acc.copyin var(%10#1 : !fir.ref<!fir.array<?xf32>>) -> !fir.ref<!fir.array<?xf32>> {name = "arr1", structured = false}
56+
%17 = acc.copyin var(%15#1 : !fir.ref<!fir.array<?xf32>>) -> !fir.ref<!fir.array<?xf32>> {name = "arr2", structured = false}
57+
%18 = acc.copyin varPtr(%4#0 : !fir.ref<!fir.array<10xf32>>) -> !fir.ref<!fir.array<10xf32>> {name = "arr3", structured = false}
58+
acc.enter_data dataOperands(%16, %17, %18 : !fir.ref<!fir.array<?xf32>>, !fir.ref<!fir.array<?xf32>>, !fir.ref<!fir.array<10xf32>>)
59+
return
60+
}
61+
62+
// CHECK: Visiting: %{{.*}} = acc.copyin varPtr(%{{.*}} : !fir.ref<!fir.array<?xf32>>) -> !fir.ref<!fir.array<?xf32>> {name = "arr1", structured = false}
63+
// CHECK: Pointer-like: !fir.ref<!fir.array<?xf32>>
64+
// CHECK: Mappable: !fir.array<?xf32>
65+
// CHECK: Type category: array
66+
// CHECK: Bound[0]: %{{.*}} = acc.bounds lowerbound(%c0{{.*}} : index) upperbound(%{{.*}} : index) extent(%{{.*}} : index) stride(%c1{{.*}} : index) startIdx(%c1{{.*}} : index)
67+
68+
// CHECK: Visiting: %{{.*}} = acc.copyin varPtr(%{{.*}} : !fir.ref<!fir.array<?xf32>>) -> !fir.ref<!fir.array<?xf32>> {name = "arr2", structured = false}
69+
// CHECK: Pointer-like: !fir.ref<!fir.array<?xf32>>
70+
// CHECK: Mappable: !fir.array<?xf32>
71+
// CHECK: Type category: array
72+
// CHECK: Bound[0]: %{{.*}} = acc.bounds lowerbound(%c0{{.*}} : index) upperbound(%{{.*}} : index) extent(%{{.*}} : index) stride(%c1{{.*}} : index) startIdx(%c2{{.*}} : index)
73+
74+
// CHECK: Visiting: %{{.*}} = acc.copyin varPtr(%{{.*}} : !fir.ref<!fir.array<10xf32>>) -> !fir.ref<!fir.array<10xf32>> {name = "arr3", structured = false}
75+
// CHECK: Pointer-like: !fir.ref<!fir.array<10xf32>>
76+
// CHECK: Mappable: !fir.array<10xf32>
77+
// CHECK: Type category: array
78+
// CHECK: Size: 40
79+
// CHECK: Offset: 0
80+
// CHECK: Bound[0]: %{{.*}} = acc.bounds lowerbound(%c0{{.*}} : index) upperbound(%{{.*}} : index) extent(%c10{{.*}} : index) stride(%c1{{.*}} : index) startIdx(%c1{{.*}} : index)
81+
}

0 commit comments

Comments
 (0)