Skip to content

Commit e22cb93

Browse files
authored
[Flang] Any and All elemental lowering (#75776)
This is an extension of #75774, with Any and All lowering added alongside Count.
1 parent 7ce010f commit e22cb93

File tree

3 files changed

+314
-1
lines changed

3 files changed

+314
-1
lines changed

flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -729,7 +729,37 @@ class ReductionElementalConversion : public mlir::OpRewritePattern<Op> {
729729

730730
mlir::Value init;
731731
GenBodyFn genBodyFn;
732-
if constexpr (std::is_same_v<Op, hlfir::CountOp>) {
732+
if constexpr (std::is_same_v<Op, hlfir::AnyOp>) {
733+
init = builder.createIntegerConstant(loc, builder.getI1Type(), 0);
734+
genBodyFn = [elemental](fir::FirOpBuilder builder, mlir::Location loc,
735+
mlir::Value reduction,
736+
const llvm::SmallVectorImpl<mlir::Value> &indices)
737+
-> mlir::Value {
738+
// Inline the elemental and get the condition from it.
739+
auto yield = inlineElementalOp(loc, builder, elemental, indices);
740+
mlir::Value cond = builder.create<fir::ConvertOp>(
741+
loc, builder.getI1Type(), yield.getElementValue());
742+
yield->erase();
743+
744+
// Conditionally set the reduction variable.
745+
return builder.create<mlir::arith::OrIOp>(loc, reduction, cond);
746+
};
747+
} else if constexpr (std::is_same_v<Op, hlfir::AllOp>) {
748+
init = builder.createIntegerConstant(loc, builder.getI1Type(), 1);
749+
genBodyFn = [elemental](fir::FirOpBuilder builder, mlir::Location loc,
750+
mlir::Value reduction,
751+
const llvm::SmallVectorImpl<mlir::Value> &indices)
752+
-> mlir::Value {
753+
// Inline the elemental and get the condition from it.
754+
auto yield = inlineElementalOp(loc, builder, elemental, indices);
755+
mlir::Value cond = builder.create<fir::ConvertOp>(
756+
loc, builder.getI1Type(), yield.getElementValue());
757+
yield->erase();
758+
759+
// Conditionally set the reduction variable.
760+
return builder.create<mlir::arith::AndIOp>(loc, reduction, cond);
761+
};
762+
} else if constexpr (std::is_same_v<Op, hlfir::CountOp>) {
733763
init = builder.createIntegerConstant(loc, op.getType(), 0);
734764
genBodyFn = [elemental](fir::FirOpBuilder builder, mlir::Location loc,
735765
mlir::Value reduction,
@@ -800,6 +830,8 @@ class OptimizedBufferizationPass
800830
patterns.insert<BroadcastAssignBufferization>(context);
801831
patterns.insert<VariableAssignBufferization>(context);
802832
patterns.insert<ReductionElementalConversion<hlfir::CountOp>>(context);
833+
patterns.insert<ReductionElementalConversion<hlfir::AnyOp>>(context);
834+
patterns.insert<ReductionElementalConversion<hlfir::AllOp>>(context);
803835

804836
if (mlir::failed(mlir::applyPatternsAndFoldGreedily(
805837
func, std::move(patterns), config))) {

flang/test/HLFIR/all-elemental.fir

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
// RUN: fir-opt %s -opt-bufferization | FileCheck %s
2+
3+
func.func @_QFPtest(%arg0: !fir.ref<!fir.array<4x7xi32>> {fir.bindc_name = "b"}, %arg1: !fir.ref<i32> {fir.bindc_name = "row"}, %arg2: !fir.ref<i32> {fir.bindc_name = "val"}) -> !fir.logical<4> {
4+
%c1 = arith.constant 1 : index
5+
%c4 = arith.constant 4 : index
6+
%c7 = arith.constant 7 : index
7+
%0 = fir.shape %c4, %c7 : (index, index) -> !fir.shape<2>
8+
%1:2 = hlfir.declare %arg0(%0) {uniq_name = "_QFFtestEb"} : (!fir.ref<!fir.array<4x7xi32>>, !fir.shape<2>) -> (!fir.ref<!fir.array<4x7xi32>>, !fir.ref<!fir.array<4x7xi32>>)
9+
%2:2 = hlfir.declare %arg1 {uniq_name = "_QFFtestErow"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
10+
%3 = fir.alloca !fir.logical<4> {bindc_name = "test", uniq_name = "_QFFtestEtest"}
11+
%4:2 = hlfir.declare %3 {uniq_name = "_QFFtestEtest"} : (!fir.ref<!fir.logical<4>>) -> (!fir.ref<!fir.logical<4>>, !fir.ref<!fir.logical<4>>)
12+
%5:2 = hlfir.declare %arg2 {uniq_name = "_QFFtestEval"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
13+
%6 = fir.load %2#0 : !fir.ref<i32>
14+
%7 = fir.convert %6 : (i32) -> i64
15+
%8 = fir.shape %c7 : (index) -> !fir.shape<1>
16+
%9 = hlfir.designate %1#0 (%7, %c1:%c7:%c1) shape %8 : (!fir.ref<!fir.array<4x7xi32>>, i64, index, index, index, !fir.shape<1>) -> !fir.box<!fir.array<7xi32>>
17+
%10 = fir.load %5#0 : !fir.ref<i32>
18+
%11 = hlfir.elemental %8 unordered : (!fir.shape<1>) -> !hlfir.expr<7x!fir.logical<4>> {
19+
^bb0(%arg3: index):
20+
%14 = hlfir.designate %9 (%arg3) : (!fir.box<!fir.array<7xi32>>, index) -> !fir.ref<i32>
21+
%15 = fir.load %14 : !fir.ref<i32>
22+
%16 = arith.cmpi sge, %15, %10 : i32
23+
%17 = fir.convert %16 : (i1) -> !fir.logical<4>
24+
hlfir.yield_element %17 : !fir.logical<4>
25+
}
26+
%12 = hlfir.all %11 : (!hlfir.expr<7x!fir.logical<4>>) -> !fir.logical<4>
27+
hlfir.assign %12 to %4#0 : !fir.logical<4>, !fir.ref<!fir.logical<4>>
28+
hlfir.destroy %11 : !hlfir.expr<7x!fir.logical<4>>
29+
%13 = fir.load %4#1 : !fir.ref<!fir.logical<4>>
30+
return %13 : !fir.logical<4>
31+
}
32+
// CHECK-LABEL: func.func @_QFPtest(%arg0: !fir.ref<!fir.array<4x7xi32>> {fir.bindc_name = "b"}, %arg1: !fir.ref<i32> {fir.bindc_name = "row"}, %arg2: !fir.ref<i32> {fir.bindc_name = "val"}) -> !fir.logical<4> {
33+
// CHECK-NEXT: %true = arith.constant true
34+
// CHECK-NEXT: %c1 = arith.constant 1 : index
35+
// CHECK-NEXT: %c4 = arith.constant 4 : index
36+
// CHECK-NEXT: %c7 = arith.constant 7 : index
37+
// CHECK-NEXT: %[[V0:.*]] = fir.shape %c4, %c7 : (index, index) -> !fir.shape<2>
38+
// CHECK-NEXT: %[[V1:.*]]:2 = hlfir.declare %arg0(%[[V0]])
39+
// CHECK-NEXT: %[[V2:.*]]:2 = hlfir.declare %arg1
40+
// CHECK-NEXT: %[[V3:.*]] = fir.alloca !fir.logical<4>
41+
// CHECK-NEXT: %[[V4:.*]]:2 = hlfir.declare %[[V3]]
42+
// CHECK-NEXT: %[[V5:.*]]:2 = hlfir.declare %arg2
43+
// CHECK-NEXT: %[[V6:.*]] = fir.load %[[V2]]#0 : !fir.ref<i32>
44+
// CHECK-NEXT: %[[V7:.*]] = fir.convert %[[V6]] : (i32) -> i64
45+
// CHECK-NEXT: %[[V8:.*]] = fir.shape %c7 : (index) -> !fir.shape<1>
46+
// CHECK-NEXT: %[[V9:.*]] = hlfir.designate %[[V1]]#0 (%[[V7]], %c1:%c7:%c1) shape %[[V8]] : (!fir.ref<!fir.array<4x7xi32>>, i64, index, index, index, !fir.shape<1>) -> !fir.box<!fir.array<7xi32>>
47+
// CHECK-NEXT: %[[V10:.*]] = fir.load %[[V5]]#0 : !fir.ref<i32>
48+
// CHECK-NEXT: %[[V11:.*]] = fir.do_loop %arg3 = %c1 to %c7 step %c1 iter_args(%arg4 = %true) -> (i1) {
49+
// CHECK-NEXT: %[[V14:.*]] = hlfir.designate %[[V9]] (%arg3) : (!fir.box<!fir.array<7xi32>>, index) -> !fir.ref<i32>
50+
// CHECK-NEXT: %[[V15:.*]] = fir.load %[[V14]] : !fir.ref<i32>
51+
// CHECK-NEXT: %[[V16:.*]] = arith.cmpi sge, %[[V15]], %[[V10]] : i32
52+
// CHECK-NEXT: %[[V17:.*]] = arith.andi %arg4, %[[V16]] : i1
53+
// CHECK-NEXT: fir.result %[[V17]] : i1
54+
// CHECK-NEXT: }
55+
// CHECK-NEXT: %[[V12:.*]] = fir.convert %[[V11]] : (i1) -> !fir.logical<4>
56+
// CHECK-NEXT: hlfir.assign %[[V12]] to %[[V4]]#0 : !fir.logical<4>, !fir.ref<!fir.logical<4>>
57+
// CHECK-NEXT: %[[V13:.*]] = fir.load %[[V4]]#1 : !fir.ref<!fir.logical<4>>
58+
// CHECK-NEXT: return %[[V13]] : !fir.logical<4>
59+
60+
61+
func.func @_QFPtest_dim(%arg0: !fir.ref<!fir.array<4x7xi32>> {fir.bindc_name = "b"}, %arg1: !fir.ref<i32> {fir.bindc_name = "row"}, %arg2: !fir.ref<i32> {fir.bindc_name = "val"}) -> !fir.array<4x!fir.logical<4>> {
62+
%c2_i32 = arith.constant 2 : i32
63+
%c1 = arith.constant 1 : index
64+
%c4 = arith.constant 4 : index
65+
%c7 = arith.constant 7 : index
66+
%0 = fir.shape %c4, %c7 : (index, index) -> !fir.shape<2>
67+
%1:2 = hlfir.declare %arg0(%0) {uniq_name = "_QFFtestEb"} : (!fir.ref<!fir.array<4x7xi32>>, !fir.shape<2>) -> (!fir.ref<!fir.array<4x7xi32>>, !fir.ref<!fir.array<4x7xi32>>)
68+
%2:2 = hlfir.declare %arg1 {uniq_name = "_QFFtestErow"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
69+
%3 = fir.alloca !fir.array<4x!fir.logical<4>> {bindc_name = "test", uniq_name = "_QFFtestEtest"}
70+
%4 = fir.shape %c4 : (index) -> !fir.shape<1>
71+
%5:2 = hlfir.declare %3(%4) {uniq_name = "_QFFtestEtest"} : (!fir.ref<!fir.array<4x!fir.logical<4>>>, !fir.shape<1>) -> (!fir.ref<!fir.array<4x!fir.logical<4>>>, !fir.ref<!fir.array<4x!fir.logical<4>>>)
72+
%6:2 = hlfir.declare %arg2 {uniq_name = "_QFFtestEval"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
73+
%7 = hlfir.designate %1#0 (%c1:%c4:%c1, %c1:%c7:%c1) shape %0 : (!fir.ref<!fir.array<4x7xi32>>, index, index, index, index, index, index, !fir.shape<2>) -> !fir.ref<!fir.array<4x7xi32>>
74+
%8 = fir.load %6#0 : !fir.ref<i32>
75+
%9 = hlfir.elemental %0 unordered : (!fir.shape<2>) -> !hlfir.expr<4x7x!fir.logical<4>> {
76+
^bb0(%arg3: index, %arg4: index):
77+
%12 = hlfir.designate %7 (%arg3, %arg4) : (!fir.ref<!fir.array<4x7xi32>>, index, index) -> !fir.ref<i32>
78+
%13 = fir.load %12 : !fir.ref<i32>
79+
%14 = arith.cmpi sge, %13, %8 : i32
80+
%15 = fir.convert %14 : (i1) -> !fir.logical<4>
81+
hlfir.yield_element %15 : !fir.logical<4>
82+
}
83+
%10 = hlfir.all %9 dim %c2_i32 : (!hlfir.expr<4x7x!fir.logical<4>>, i32) -> !hlfir.expr<4x!fir.logical<4>>
84+
hlfir.assign %10 to %5#0 : !hlfir.expr<4x!fir.logical<4>>, !fir.ref<!fir.array<4x!fir.logical<4>>>
85+
hlfir.destroy %10 : !hlfir.expr<4x!fir.logical<4>>
86+
hlfir.destroy %9 : !hlfir.expr<4x7x!fir.logical<4>>
87+
%11 = fir.load %5#1 : !fir.ref<!fir.array<4x!fir.logical<4>>>
88+
return %11 : !fir.array<4x!fir.logical<4>>
89+
}
90+
// CHECK-LABEL: func.func @_QFPtest_dim(
91+
// CHECK: %10 = hlfir.all %9 dim %c2_i32

0 commit comments

Comments
 (0)