Skip to content

Commit 74ed79f

Browse files
authored
[mlir][linalg] Add linalg.transpose constant folding (#92589)
There was existing support for constant folding a `linalg.generic` that was actually a transpose. This commit adds support for the named op, `linalg.transpose`, as well by making use of the `LinalgOp` interface.
1 parent d2a103e commit 74ed79f

File tree

3 files changed

+180
-163
lines changed

3 files changed

+180
-163
lines changed

mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp

Lines changed: 32 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -23,21 +23,21 @@ using namespace mlir;
2323
using namespace mlir::linalg;
2424

2525
namespace {
26-
/// Base class for constant folding linalg.generic ops with N inputs, 1 output,
27-
/// and permutation indexing maps.
26+
/// Base class for constant folding linalg structured ops with N inputs, 1
27+
/// output, and permutation indexing maps.
2828
///
2929
/// `ConcreteType` should provide methods with signatures
3030
///
3131
/// ```c++
32-
/// bool matchIndexingMaps(GenericOp genericOp) const;
33-
/// RegionComputationFn getRegionComputeFn(GenericOp) const;
32+
/// bool matchIndexingMaps(LinalgOp linalgOp) const;
33+
/// RegionComputationFn getRegionComputeFn(LinalgOp) const;
3434
/// ```
3535
///
3636
/// The latter inspects the region and returns the computation inside as a
3737
/// functor. The functor will be invoked with constant elements for all inputs
3838
/// and should return the corresponding computed constant element for output.
3939
template <typename ConcreteType>
40-
class FoldConstantBase : public OpRewritePattern<GenericOp> {
40+
class FoldConstantBase : public OpInterfaceRewritePattern<LinalgOp> {
4141
public:
4242
struct APIntOrFloat {
4343
std::optional<APInt> apInt;
@@ -52,25 +52,26 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
5252

5353
FoldConstantBase(MLIRContext *context, const ControlFusionFn &controlFn,
5454
PatternBenefit benefit = 1)
55-
: OpRewritePattern<GenericOp>(context, benefit), controlFn(controlFn) {}
55+
: OpInterfaceRewritePattern<LinalgOp>(context, benefit),
56+
controlFn(controlFn) {}
5657

57-
LogicalResult matchAndRewrite(GenericOp genericOp,
58+
LogicalResult matchAndRewrite(LinalgOp linalgOp,
5859
PatternRewriter &rewriter) const override {
5960
// Mixed and buffer semantics aren't supported.
60-
if (!genericOp.hasPureTensorSemantics())
61+
if (!linalgOp.hasPureTensorSemantics())
6162
return failure();
6263

6364
// Only support ops generating one output for now.
64-
if (genericOp.getNumDpsInits() != 1)
65+
if (linalgOp.getNumDpsInits() != 1)
6566
return failure();
6667

67-
auto outputType = dyn_cast<ShapedType>(genericOp.getResultTypes().front());
68+
auto outputType = dyn_cast<ShapedType>(linalgOp->getResultTypes().front());
6869
// Require the output types to be static given that we are generating
6970
// constants.
7071
if (!outputType || !outputType.hasStaticShape())
7172
return failure();
7273

73-
if (!llvm::all_of(genericOp.getInputs(), [](Value input) {
74+
if (!llvm::all_of(linalgOp.getDpsInputs(), [](Value input) {
7475
return isa<ShapedType>(input.getType());
7576
}))
7677
return failure();
@@ -80,7 +81,7 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
8081
return cast<ShapedType>(value.getType()).getElementType();
8182
};
8283
if (!llvm::all_equal(
83-
llvm::map_range(genericOp->getOperands(), getOperandElementType)))
84+
llvm::map_range(linalgOp->getOperands(), getOperandElementType)))
8485
return failure();
8586

8687
// We can only handle the case where we have int/float elements.
@@ -93,43 +94,42 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
9394
// entirely in the compiler, without needing to turn all indices into
9495
// Values, and then do affine apply on them, and then match back the
9596
// constant again.
96-
if (!llvm::all_of(genericOp.getIndexingMapsArray(),
97+
if (!llvm::all_of(linalgOp.getIndexingMapsArray(),
9798
[](AffineMap map) { return map.isPermutation(); }))
9899
return failure();
99100

100-
for (OpOperand &operand : genericOp.getDpsInitsMutable()) {
101-
if (genericOp.payloadUsesValueFromOperand(&operand))
101+
for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
102+
if (linalgOp.payloadUsesValueFromOperand(&operand))
102103
return failure();
103104
}
104105

105106
// Further check the indexing maps are okay for the ConcreteType.
106-
if (!static_cast<const ConcreteType *>(this)->matchIndexingMaps(genericOp))
107+
if (!static_cast<const ConcreteType *>(this)->matchIndexingMaps(linalgOp))
107108
return failure();
108109

109110
// Defer to the concrete type to check the region and discover the
110111
// computation inside.
111112
RegionComputationFn computeFn =
112-
static_cast<const ConcreteType *>(this)->getRegionComputeFn(genericOp);
113+
static_cast<const ConcreteType *>(this)->getRegionComputeFn(linalgOp);
113114
if (!computeFn)
114115
return failure();
115116

116117
// All inputs should be constants.
117-
int numInputs = genericOp.getNumDpsInputs();
118+
int numInputs = linalgOp.getNumDpsInputs();
118119
SmallVector<DenseIntOrFPElementsAttr> inputValues(numInputs);
119-
for (const auto &en : llvm::enumerate(genericOp.getDpsInputOperands())) {
120+
for (const auto &en : llvm::enumerate(linalgOp.getDpsInputOperands())) {
120121
if (!matchPattern(en.value()->get(),
121122
m_Constant(&inputValues[en.index()])))
122123
return failure();
123124
}
124125

125126
// Identified this as a potential candidate for folding. Now check the
126127
// policy to see whether we are allowed to proceed.
127-
for (OpOperand *operand : genericOp.getDpsInputOperands()) {
128+
for (OpOperand *operand : linalgOp.getDpsInputOperands()) {
128129
if (!controlFn(operand))
129130
return failure();
130131
}
131132

132-
auto linalgOp = cast<LinalgOp>(genericOp.getOperation());
133133
SmallVector<int64_t, 4> loopBounds = linalgOp.computeStaticLoopSizes();
134134
int64_t numElements = outputType.getNumElements();
135135

@@ -155,8 +155,8 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
155155

156156
SmallVector<SmallVector<unsigned>> inputDims;
157157
for (int i = 0; i < numInputs; ++i)
158-
inputDims.push_back(getDimPositions(genericOp.getIndexingMapsArray()[i]));
159-
auto outputDims = getDimPositions(genericOp.getIndexingMapsArray().back());
158+
inputDims.push_back(getDimPositions(linalgOp.getIndexingMapsArray()[i]));
159+
auto outputDims = getDimPositions(linalgOp.getIndexingMapsArray().back());
160160
auto outputShape = outputType.getShape();
161161

162162
// Allocate small vectors for index delinearization. Initial values do not
@@ -173,7 +173,7 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
173173
APIntOrFloatArray computeFnInputs;
174174

175175
auto inputShapes = llvm::to_vector<4>(
176-
llvm::map_range(genericOp.getInputs(), [](Value value) {
176+
llvm::map_range(linalgOp.getDpsInputs(), [](Value value) {
177177
return cast<ShapedType>(value.getType()).getShape();
178178
}));
179179

@@ -254,26 +254,28 @@ class FoldConstantBase : public OpRewritePattern<GenericOp> {
254254
isFloat ? DenseElementsAttr::get(outputType, fpOutputValues)
255255
: DenseElementsAttr::get(outputType, intOutputValues);
256256

257-
rewriter.replaceOpWithNewOp<arith::ConstantOp>(genericOp, outputAttr);
257+
rewriter.replaceOpWithNewOp<arith::ConstantOp>(linalgOp, outputAttr);
258258
return success();
259259
}
260260

261261
private:
262262
ControlFusionFn controlFn;
263263
};
264264

265-
// Folds linalg.generic ops that are actually transposes on constant values.
265+
// Folds linalg.transpose (and linalg.generic ops that are actually transposes)
266+
// on constant values.
266267
struct FoldConstantTranspose : public FoldConstantBase<FoldConstantTranspose> {
268+
267269
using FoldConstantBase::FoldConstantBase;
268270

269-
bool matchIndexingMaps(GenericOp genericOp) const {
271+
bool matchIndexingMaps(LinalgOp linalgOp) const {
270272
// We should have one input and one output.
271-
return genericOp.getIndexingMapsArray().size() == 2;
273+
return linalgOp.getIndexingMapsArray().size() == 2;
272274
}
273275

274-
RegionComputationFn getRegionComputeFn(GenericOp genericOp) const {
276+
RegionComputationFn getRegionComputeFn(LinalgOp linalgOp) const {
275277
// Make sure the region only contains a yield op.
276-
Block &body = genericOp.getRegion().front();
278+
Block &body = linalgOp->getRegion(0).front();
277279
if (!llvm::hasSingleElement(body))
278280
return nullptr;
279281
auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
Lines changed: 148 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,148 @@
1+
// RUN: mlir-opt %s -linalg-fuse-elementwise-ops -split-input-file | FileCheck %s
2+
3+
// CHECK-LABEL: @transpose_fold_2d_fp32
4+
func.func @transpose_fold_2d_fp32(%init: tensor<3x2xf32>) -> tensor<3x2xf32> {
5+
%input = arith.constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
6+
// CHECK: %[[CST:.+]] = arith.constant
7+
// CHECK-SAME{LITERAL}: dense<[[0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], [2.000000e+00, 5.000000e+00]]> : tensor<3x2xf32>
8+
%1 = linalg.generic {
9+
indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
10+
iterator_types = ["parallel", "parallel"]
11+
} ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) {
12+
^bb0(%arg1: f32, %arg2: f32):
13+
linalg.yield %arg1 : f32
14+
} -> tensor<3x2xf32>
15+
// CHECK: return %[[CST]]
16+
return %1 : tensor<3x2xf32>
17+
}
18+
19+
// -----
20+
21+
// CHECK-LABEL: @transpose_fold_2d_fp64
22+
func.func @transpose_fold_2d_fp64(%init: tensor<3x2xf64>) -> tensor<3x2xf64> {
23+
%input = arith.constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf64>
24+
// CHECK: %[[CST:.+]] = arith.constant
25+
// CHECK-SAME{LITERAL}: dense<[[0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], [2.000000e+00, 5.000000e+00]]> : tensor<3x2xf64>
26+
%1 = linalg.generic {
27+
indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
28+
iterator_types = ["parallel", "parallel"]
29+
} ins(%input : tensor<2x3xf64>) outs(%init : tensor<3x2xf64>) {
30+
^bb0(%arg1: f64, %arg2: f64):
31+
linalg.yield %arg1 : f64
32+
} -> tensor<3x2xf64>
33+
// CHECK: return %[[CST]]
34+
return %1 : tensor<3x2xf64>
35+
}
36+
37+
// -----
38+
39+
// CHECK-LABEL: @transpose_fold_4d_i32
40+
func.func @transpose_fold_4d_i32(%init: tensor<3x1x4x2xi32>) -> tensor<3x1x4x2xi32> {
41+
%input = arith.constant dense<[[
42+
[[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]],
43+
[[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]
44+
]]> : tensor<1x2x3x4xi32>
45+
// CHECK: %[[CST:.+]] = arith.constant dense<[
46+
// CHECK-SAME{LITERAL}: [[[0, 12], [1, 13], [2, 14], [3, 15]]],
47+
// CHECK-SAME{LITERAL}: [[[4, 16], [5, 17], [6, 18], [7, 19]]],
48+
// CHECK-SAME{LITERAL}: [[[8, 20], [9, 21], [10, 22], [11, 23]]]
49+
// CHECK-SAME{LITERAL}: ]>
50+
%1 = linalg.generic {
51+
indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d0, d3, d1)>],
52+
iterator_types = ["parallel", "parallel", "parallel", "parallel"]
53+
} ins(%input : tensor<1x2x3x4xi32>) outs(%init : tensor<3x1x4x2xi32>) {
54+
^bb0(%arg1: i32, %arg2: i32):
55+
linalg.yield %arg1 : i32
56+
} -> tensor<3x1x4x2xi32>
57+
// CHECK: return %[[CST]]
58+
return %1 : tensor<3x1x4x2xi32>
59+
}
60+
61+
// -----
62+
63+
// CHECK-LABEL: @transpose_fold_4d_i16
64+
func.func @transpose_fold_4d_i16(%init: tensor<3x1x4x2xi16>) -> tensor<3x1x4x2xi16> {
65+
%input = arith.constant dense<[[
66+
[[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]],
67+
[[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]
68+
]]> : tensor<1x2x3x4xi16>
69+
// CHECK: %[[CST:.+]] = arith.constant dense<[
70+
// CHECK-SAME{LITERAL}: [[[0, 12], [1, 13], [2, 14], [3, 15]]],
71+
// CHECK-SAME{LITERAL}: [[[4, 16], [5, 17], [6, 18], [7, 19]]],
72+
// CHECK-SAME{LITERAL}: [[[8, 20], [9, 21], [10, 22], [11, 23]]]
73+
// CHECK-SAME{LITERAL}: ]>
74+
%1 = linalg.generic {
75+
indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d0, d3, d1)>],
76+
iterator_types = ["parallel", "parallel", "parallel", "parallel"]
77+
} ins(%input : tensor<1x2x3x4xi16>) outs(%init : tensor<3x1x4x2xi16>) {
78+
^bb0(%arg1: i16, %arg2: i16):
79+
linalg.yield %arg1 : i16
80+
} -> tensor<3x1x4x2xi16>
81+
// CHECK: return %[[CST]]
82+
return %1 : tensor<3x1x4x2xi16>
83+
}
84+
85+
// -----
86+
87+
// CHECK-LABEL: @transpose_nofold_non_cst_input
88+
func.func @transpose_nofold_non_cst_input(%input: tensor<2x3xf32>, %init: tensor<3x2xf32>) -> tensor<3x2xf32> {
89+
// CHECK: linalg.generic
90+
%1 = linalg.generic {
91+
indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
92+
iterator_types = ["parallel", "parallel"]
93+
} ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) {
94+
^bb0(%arg1: f32, %arg2: f32):
95+
linalg.yield %arg1 : f32
96+
} -> tensor<3x2xf32>
97+
return %1 : tensor<3x2xf32>
98+
}
99+
100+
// -----
101+
102+
// CHECK-LABEL: @transpose_nofold_yield_const
103+
func.func @transpose_nofold_yield_const(%init: tensor<3x2xf32>) -> tensor<3x2xf32> {
104+
%input = arith.constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
105+
%cst = arith.constant 8.0 : f32
106+
// CHECK: linalg.generic
107+
%1 = linalg.generic {
108+
indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
109+
iterator_types = ["parallel", "parallel"]
110+
} ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) {
111+
^bb0(%arg1: f32, %arg2: f32):
112+
linalg.yield %cst : f32
113+
} -> tensor<3x2xf32>
114+
return %1 : tensor<3x2xf32>
115+
}
116+
117+
// -----
118+
119+
// CHECK-LABEL: @transpose_nofold_multi_ops_in_region
120+
func.func @transpose_nofold_multi_ops_in_region(%init: tensor<3x2xf32>) -> tensor<3x2xf32> {
121+
%input = arith.constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
122+
// CHECK: linalg.generic
123+
%1 = linalg.generic {
124+
indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
125+
iterator_types = ["parallel", "parallel"]
126+
} ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) {
127+
^bb0(%arg1: f32, %arg2: f32):
128+
%add = arith.addf %arg1, %arg1 : f32
129+
linalg.yield %add : f32
130+
} -> tensor<3x2xf32>
131+
return %1 : tensor<3x2xf32>
132+
}
133+
134+
// -----
135+
136+
// CHECK-LABEL: @named_transpose_fold_2d_fp32
137+
func.func @named_transpose_fold_2d_fp32(%init: tensor<3x2xf32>) -> tensor<3x2xf32> {
138+
%input = arith.constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
139+
// CHECK: %[[CST:.+]] = arith.constant
140+
// CHECK-SAME{LITERAL}: dense<[[0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], [2.000000e+00, 5.000000e+00]]> : tensor<3x2xf32>
141+
%1 = linalg.transpose ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) permutation = [1, 0]
142+
// CHECK: return %[[CST]]
143+
return %1 : tensor<3x2xf32>
144+
}
145+
146+
// -----
147+
148+

0 commit comments

Comments
 (0)