Skip to content

Commit 6bcfab3

Browse files
committed
[flang][hlfir] allow recursive intrinsic lowering
We need to allow recursive application of intrinsic lowering patterns, otherwise we cannot lower nested calls of the same intrinsic e.g. matmul(matmul(a, b), c). matmul(matmul(a, b), matmul(c, d)) requires hlfir.associate of hlfir expr with more than one use (TODO). Differential Revision: https://reviews.llvm.org/D152284
1 parent 89227b6 commit 6bcfab3

File tree

2 files changed

+53
-2
lines changed

2 files changed

+53
-2
lines changed

flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIRIntrinsics.cpp

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===- LowerHLFIRIntrinsics.cpp - Bufferize HLFIR ------------------------===//
1+
//===- LowerHLFIRIntrinsics.cpp - Transformational intrinsics to FIR ------===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
@@ -37,7 +37,22 @@ namespace {
3737
/// runtime calls
3838
template <class OP>
3939
class HlfirIntrinsicConversion : public mlir::OpRewritePattern<OP> {
40-
using mlir::OpRewritePattern<OP>::OpRewritePattern;
40+
public:
41+
explicit HlfirIntrinsicConversion(mlir::MLIRContext *ctx)
42+
: mlir::OpRewritePattern<OP>{ctx} {
43+
// required for cases where intrinsics are chained together e.g.
44+
// matmul(matmul(a, b), c)
45+
// because converting the inner operation then invalidates the
46+
// outer operation: causing the pattern to apply recursively.
47+
//
48+
// This is safe because we always progress with each iteration. Circular
49+
// applications of operations are not expressible in MLIR because we use
50+
// an SSA form and one must come first. E.g.
51+
// %a = hlfir.matmul %b %d
52+
// %b = hlfir.matmul %a %d
53+
// cannot be written.
54+
mlir::OpConversionPattern<OP>::setHasBoundedRewriteRecursion(true);
55+
}
4156

4257
protected:
4358
struct IntrinsicArgument {

flang/test/HLFIR/matmul-lowering.fir

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,3 +43,39 @@ func.func @_QPmatmul1(%arg0: !fir.box<!fir.array<?x?xi32>> {fir.bindc_name = "lh
4343
// CHECK: hlfir.destroy %[[ASEXPR]]
4444
// CHECK-NEXT: return
4545
// CHECK-NEXT: }
46+
47+
// nested matmuls leading to recursive pattern application
48+
func.func @_QPtest(%arg0: !fir.ref<!fir.array<3x3xf32>> {fir.bindc_name = "a"}, %arg1: !fir.ref<!fir.array<3x3xf32>> {fir.bindc_name = "b"}, %arg2: !fir.ref<!fir.array<3x3xf32>> {fir.bindc_name = "c"}, %arg3: !fir.ref<!fir.array<3x3xf32>> {fir.bindc_name = "out"}) {
49+
%c3 = arith.constant 3 : index
50+
%c3_0 = arith.constant 3 : index
51+
%0 = fir.shape %c3, %c3_0 : (index, index) -> !fir.shape<2>
52+
%1:2 = hlfir.declare %arg0(%0) {uniq_name = "_QFtestEa"} : (!fir.ref<!fir.array<3x3xf32>>, !fir.shape<2>) -> (!fir.ref<!fir.array<3x3xf32>>, !fir.ref<!fir.array<3x3xf32>>)
53+
%c3_1 = arith.constant 3 : index
54+
%c3_2 = arith.constant 3 : index
55+
%2 = fir.shape %c3_1, %c3_2 : (index, index) -> !fir.shape<2>
56+
%3:2 = hlfir.declare %arg1(%2) {uniq_name = "_QFtestEb"} : (!fir.ref<!fir.array<3x3xf32>>, !fir.shape<2>) -> (!fir.ref<!fir.array<3x3xf32>>, !fir.ref<!fir.array<3x3xf32>>)
57+
%c3_3 = arith.constant 3 : index
58+
%c3_4 = arith.constant 3 : index
59+
%4 = fir.shape %c3_3, %c3_4 : (index, index) -> !fir.shape<2>
60+
%5:2 = hlfir.declare %arg2(%4) {uniq_name = "_QFtestEc"} : (!fir.ref<!fir.array<3x3xf32>>, !fir.shape<2>) -> (!fir.ref<!fir.array<3x3xf32>>, !fir.ref<!fir.array<3x3xf32>>)
61+
%c3_5 = arith.constant 3 : index
62+
%c3_6 = arith.constant 3 : index
63+
%6 = fir.shape %c3_5, %c3_6 : (index, index) -> !fir.shape<2>
64+
%7:2 = hlfir.declare %arg3(%6) {uniq_name = "_QFtestEout"} : (!fir.ref<!fir.array<3x3xf32>>, !fir.shape<2>) -> (!fir.ref<!fir.array<3x3xf32>>, !fir.ref<!fir.array<3x3xf32>>)
65+
%8 = hlfir.matmul %1#0 %3#0 {fastmath = #arith.fastmath<contract>} : (!fir.ref<!fir.array<3x3xf32>>, !fir.ref<!fir.array<3x3xf32>>) -> !hlfir.expr<3x3xf32>
66+
%9 = hlfir.matmul %8 %5#0 {fastmath = #arith.fastmath<contract>} : (!hlfir.expr<3x3xf32>, !fir.ref<!fir.array<3x3xf32>>) -> !hlfir.expr<3x3xf32>
67+
hlfir.assign %9 to %7#0 : !hlfir.expr<3x3xf32>, !fir.ref<!fir.array<3x3xf32>>
68+
hlfir.destroy %9 : !hlfir.expr<3x3xf32>
69+
hlfir.destroy %8 : !hlfir.expr<3x3xf32>
70+
return
71+
}
72+
// just check that we apply the patterns successfully. The details are checked above
73+
// CHECK-LABEL: func.func @_QPtest(
74+
// CHECK: %arg0: !fir.ref<!fir.array<3x3xf32>> {fir.bindc_name = "a"},
75+
// CHECK-SAME: %arg1: !fir.ref<!fir.array<3x3xf32>> {fir.bindc_name = "b"},
76+
// CHECK-SAME: %arg2: !fir.ref<!fir.array<3x3xf32>> {fir.bindc_name = "c"},
77+
// CHECK-SAME: %arg3: !fir.ref<!fir.array<3x3xf32>> {fir.bindc_name = "out"}) {
78+
// CHECK: fir.call @_FortranAMatmul(
79+
// CHECK: fir.call @_FortranAMatmul(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.ref<!fir.box<none>>, !fir.box<none>, !fir.box<none>, !fir.ref<i8>, i32) -> none
80+
// CHECK: return
81+
// CHECK-NEXT: }

0 commit comments

Comments
 (0)