Skip to content

Commit f397bdf

Browse files
authored
[mlir][tensor] Fold consumer linalg transpose with producer tensor pack (#74206)
Partial fix to iree-org/iree#15367
1 parent 8bea83b commit f397bdf

File tree

2 files changed

+294
-1
lines changed

2 files changed

+294
-1
lines changed

mlir/lib/Dialect/Tensor/Transforms/FoldIntoPackAndUnpackPatterns.cpp

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
910
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1011
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
1112
#include "mlir/IR/PatternMatch.h"
@@ -81,10 +82,71 @@ struct FoldUnpackWithExtractSliceOp : public OpRewritePattern<ExtractSliceOp> {
8182
return success();
8283
}
8384
};
85+
86+
/// Fold 'pack' -> 'transpose' into 'pack' since 'pack' already has transpose
87+
/// semantics.
88+
struct FoldProducerPackWithConsumerLinalgTransposeOp
89+
: public OpRewritePattern<linalg::TransposeOp> {
90+
using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
91+
92+
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
93+
PatternRewriter &rewriter) const override {
94+
auto packOp = transposeOp.getOperand(0).getDefiningOp<PackOp>();
95+
96+
if (!packOp)
97+
return failure();
98+
99+
auto innerDimsPos = packOp.getInnerDimsPos();
100+
auto mixedInnerTiles = packOp.getMixedTiles();
101+
auto outerDimsPerm = packOp.getOuterDimsPerm();
102+
auto transposePerm = transposeOp.getPermutation();
103+
SmallVector<int64_t> newOuterDimsPermVec;
104+
SmallVector<int64_t> newInnerDimsPosVec;
105+
SmallVector<OpFoldResult> newMixedInnerTilesVec;
106+
int64_t srcRank = packOp.getSourceRank();
107+
108+
// Process transpose operation for non-tiled outer dimensions
109+
for (unsigned int i = 0; i < srcRank; ++i) {
110+
int64_t remappedPosition = transposePerm[i];
111+
112+
// If tensor.pack has outer_dims_perm attribute, then consider it during
113+
// index remapping.
114+
if (!outerDimsPerm.empty()) {
115+
if (transposePerm[i] >= srcRank) {
116+
return rewriter.notifyMatchFailure(
117+
transposeOp,
118+
"Cannot fold in tensor.pack if a tile dimension was transposed "
119+
"with a non-tile dimension in linalg.transpose.");
120+
}
121+
remappedPosition = outerDimsPerm[remappedPosition];
122+
}
123+
124+
newOuterDimsPermVec.push_back(remappedPosition);
125+
}
126+
127+
// Process transpose operation for tiled inner dimensions
128+
for (unsigned int i = srcRank; i < transposePerm.size(); ++i) {
129+
int64_t remappedPosition = transposePerm[i] - srcRank;
130+
newMixedInnerTilesVec.push_back(mixedInnerTiles[remappedPosition]);
131+
newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]);
132+
}
133+
134+
Value output = packOp.createDestinationTensor(
135+
rewriter, transposeOp.getLoc(), packOp.getSource(),
136+
newMixedInnerTilesVec, newInnerDimsPosVec, newOuterDimsPermVec);
137+
138+
rewriter.replaceOpWithNewOp<PackOp>(
139+
transposeOp, packOp.getSource(), output, newInnerDimsPosVec,
140+
newMixedInnerTilesVec, packOp.getPaddingValue(), newOuterDimsPermVec);
141+
142+
return success();
143+
}
144+
};
84145
} // namespace
85146

86147
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns) {
87-
patterns.insert<FoldUnpackWithExtractSliceOp, FoldPadWithPackOp>(
148+
patterns.insert<FoldUnpackWithExtractSliceOp, FoldPadWithPackOp,
149+
FoldProducerPackWithConsumerLinalgTransposeOp>(
88150
patterns.getContext());
89151
}
90152

mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir

Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,3 +114,234 @@ func.func @pad_pack_different_padding_value(%src: tensor<16641x16xf32>) -> tenso
114114
// CHECK-LABEL: func.func @pad_pack_different_padding_value
115115
// CHECK: tensor.pad
116116
// CHECK: tensor.pack
117+
118+
// -----
119+
120+
func.func @tensor_pack_linalg_transpose_fold(%arg0: tensor<56x57x1x64xf32>) -> tensor<1x57x56x2x32xf32> {
121+
%0 = tensor.empty() : tensor<56x2x1x57x32xf32>
122+
%pack = tensor.pack %arg0
123+
outer_dims_perm = [0, 3, 2, 1]
124+
inner_dims_pos = [3]
125+
inner_tiles = [32]
126+
into %0 : tensor<56x57x1x64xf32> -> tensor<56x2x1x57x32xf32>
127+
128+
%1 = tensor.empty() : tensor<1x57x56x2x32xf32>
129+
%transposed = linalg.transpose
130+
ins(%pack : tensor<56x2x1x57x32xf32>)
131+
outs(%1 : tensor<1x57x56x2x32xf32>)
132+
permutation = [2, 3, 0, 1, 4]
133+
return %transposed : tensor<1x57x56x2x32xf32>
134+
}
135+
// CHECK: func @tensor_pack_linalg_transpose_fold(
136+
// CHECK-SAME: %[[ARG0:.+]]: tensor<56x57x1x64xf32>)
137+
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x57x56x2x32xf32>
138+
// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]]
139+
// CHECK-SAME: outer_dims_perm = [2, 1, 0, 3]
140+
// CHECK-SAME: inner_dims_pos = [3] inner_tiles = [32]
141+
// CHECK-SAME: into %[[INIT]]
142+
// CHECK: return %[[PACK]]
143+
144+
// -----
145+
146+
func.func @tensor_pack_linalg_transpose_fold_with_padding(%arg0: tensor<56x57x1x55xf32>, %padding: f32) -> tensor<1x57x56x2x32xf32> {
147+
%0 = tensor.empty() : tensor<56x2x1x57x32xf32>
148+
%pack = tensor.pack %arg0 padding_value(%padding : f32)
149+
outer_dims_perm = [0, 3, 2, 1]
150+
inner_dims_pos = [3]
151+
inner_tiles = [32]
152+
into %0 : tensor<56x57x1x55xf32> -> tensor<56x2x1x57x32xf32>
153+
154+
%1 = tensor.empty() : tensor<1x57x56x2x32xf32>
155+
%transposed = linalg.transpose
156+
ins(%pack : tensor<56x2x1x57x32xf32>)
157+
outs(%1 : tensor<1x57x56x2x32xf32>)
158+
permutation = [2, 3, 0, 1, 4]
159+
return %transposed : tensor<1x57x56x2x32xf32>
160+
}
161+
// CHECK: func @tensor_pack_linalg_transpose_fold_with_padding(
162+
// CHECK-SAME: %[[ARG0:.+]]: tensor<56x57x1x55xf32>, %[[PADDING:.+]]: f32)
163+
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x57x56x2x32xf32>
164+
// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] padding_value(%[[PADDING]] : f32)
165+
// CHECK-SAME: outer_dims_perm = [2, 1, 0, 3]
166+
// CHECK-SAME: inner_dims_pos = [3] inner_tiles = [32]
167+
// CHECK-SAME: into %[[INIT]]
168+
// CHECK: return %[[PACK]]
169+
170+
// -----
171+
172+
func.func @tensor_pack_linalg_transpose_fold_no_outer_dims_perm(%arg0: tensor<56x57x1x64xf32>) -> tensor<1x2x56x57x32xf32> {
173+
%0 = tensor.empty() : tensor<56x57x1x2x32xf32>
174+
%pack = tensor.pack %arg0
175+
inner_dims_pos = [3]
176+
inner_tiles = [32]
177+
into %0 : tensor<56x57x1x64xf32> -> tensor<56x57x1x2x32xf32>
178+
179+
%1 = tensor.empty() : tensor<1x2x56x57x32xf32>
180+
%transposed = linalg.transpose
181+
ins(%pack : tensor<56x57x1x2x32xf32>)
182+
outs(%1 : tensor<1x2x56x57x32xf32>)
183+
permutation = [2, 3, 0, 1, 4]
184+
return %transposed : tensor<1x2x56x57x32xf32>
185+
}
186+
// CHECK: func @tensor_pack_linalg_transpose_fold_no_outer_dims_perm(
187+
// CHECK-SAME: %[[ARG0:.+]]: tensor<56x57x1x64xf32>)
188+
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x2x56x57x32xf32>
189+
// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]]
190+
// CHECK-SAME: outer_dims_perm = [2, 3, 0, 1]
191+
// CHECK-SAME: inner_dims_pos = [3] inner_tiles = [32]
192+
// CHECK-SAME: into %[[INIT]]
193+
// CHECK: return %[[PACK]]
194+
195+
// -----
196+
197+
func.func @tensor_pack_linalg_transpose_fold_tile_dims_transpose(%arg0: tensor<56x72x24x128xf32>) -> tensor<12x56x4x9x32x8x2xf32> {
198+
%0 = tensor.empty() : tensor<4x9x12x56x8x2x32xf32>
199+
%pack = tensor.pack %arg0
200+
outer_dims_perm = [3, 1, 2, 0]
201+
inner_dims_pos = [1, 2, 3]
202+
inner_tiles = [8, 2, 32]
203+
into %0 : tensor<56x72x24x128xf32> -> tensor<4x9x12x56x8x2x32xf32>
204+
205+
%1 = tensor.empty() : tensor<12x56x4x9x32x8x2xf32>
206+
%transposed = linalg.transpose
207+
ins(%pack : tensor<4x9x12x56x8x2x32xf32>)
208+
outs(%1 : tensor<12x56x4x9x32x8x2xf32>)
209+
permutation = [2, 3, 0, 1, 6, 4, 5]
210+
return %transposed : tensor<12x56x4x9x32x8x2xf32>
211+
}
212+
// CHECK: func @tensor_pack_linalg_transpose_fold_tile_dims_transpose(
213+
// CHECK-SAME: %[[ARG0:.+]]: tensor<56x72x24x128xf32>)
214+
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<12x56x4x9x32x8x2xf32>
215+
// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]]
216+
// CHECK-SAME: outer_dims_perm = [2, 0, 3, 1]
217+
// CHECK-SAME: inner_dims_pos = [3, 1, 2] inner_tiles = [32, 8, 2]
218+
// CHECK-SAME: into %[[INIT]]
219+
// CHECK: return %[[PACK]]
220+
221+
// -----
222+
223+
func.func @tensor_pack_linalg_transpose_fold_tile_dims_outer_dims_transpose(%arg0: tensor<56x72x24x128xf32>) -> tensor<9x56x2x12x32x8x4xf32> {
224+
%0 = tensor.empty() : tensor<4x12x9x56x8x2x32xf32>
225+
%pack = tensor.pack %arg0
226+
outer_dims_perm = [3, 2, 1, 0]
227+
inner_dims_pos = [1, 2, 3]
228+
inner_tiles = [8, 2, 32]
229+
into %0 : tensor<56x72x24x128xf32> -> tensor<4x12x9x56x8x2x32xf32>
230+
231+
%1 = tensor.empty() : tensor<9x56x2x12x32x8x4xf32>
232+
%transposed = linalg.transpose
233+
ins(%pack : tensor<4x12x9x56x8x2x32xf32>)
234+
outs(%1 : tensor<9x56x2x12x32x8x4xf32>)
235+
permutation = [2, 3, 5, 1, 6, 4, 0]
236+
return %transposed : tensor<9x56x2x12x32x8x4xf32>
237+
}
238+
// CHECK: func @tensor_pack_linalg_transpose_fold_tile_dims_outer_dims_transpose(
239+
// CHECK-SAME: %[[ARG0:.+]]: tensor<56x72x24x128xf32>)
240+
// CHECK: tensor.pack
241+
// CHECK: linalg.transpose
242+
243+
// -----
244+
245+
func.func @tensor_pack_linalg_transpose_fold_dynamic_outer_dims(%arg0: tensor<56x?x?x64xf32>) -> tensor<?x?x56x2x32xf32> {
246+
%0 = tensor.empty() : tensor<56x2x1x57x32xf32>
247+
%pack = tensor.pack %arg0
248+
outer_dims_perm = [0, 3, 2, 1]
249+
inner_dims_pos = [3]
250+
inner_tiles = [32]
251+
into %0 : tensor<56x?x?x64xf32> -> tensor<56x2x1x57x32xf32>
252+
253+
%1 = tensor.empty() : tensor<1x57x56x2x32xf32>
254+
%transposed = linalg.transpose
255+
ins(%pack : tensor<56x2x1x57x32xf32>)
256+
outs(%1 : tensor<1x57x56x2x32xf32>)
257+
permutation = [2, 3, 0, 1, 4]
258+
259+
%return_value = tensor.cast %transposed : tensor<1x57x56x2x32xf32> to tensor<?x?x56x2x32xf32>
260+
return %return_value : tensor<?x?x56x2x32xf32>
261+
}
262+
// CHECK: func @tensor_pack_linalg_transpose_fold_dynamic_outer_dims(
263+
// CHECK-SAME: %[[ARG0:.+]]: tensor<56x?x?x64xf32>)
264+
// CHECK: %[[c1:.+]] = arith.constant 1 : index
265+
// CHECK: %[[c2:.+]] = arith.constant 2 : index
266+
// CHECK: %[[dim:.+]] = tensor.dim %[[ARG0]], %[[c1]] : tensor<56x?x?x64xf32>
267+
// CHECK: %[[dim_0:.+]] = tensor.dim %[[ARG0]], %[[c2]] : tensor<56x?x?x64xf32>
268+
// CHECK: %[[INIT:.+]] = tensor.empty(%[[dim_0]], %[[dim]]) : tensor<?x?x56x2x32xf32>
269+
// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]]
270+
// CHECK-SAME: outer_dims_perm = [2, 1, 0, 3]
271+
// CHECK-SAME: inner_dims_pos = [3] inner_tiles = [32]
272+
// CHECK-SAME: into %[[INIT]]
273+
// CHECK: return %[[PACK]]
274+
275+
// -----
276+
277+
func.func @tensor_pack_linalg_transpose_fold_dynamic_outer_and_tile_dims(%arg0: tensor<56x?x?x128xf32>) -> tensor<?x?x56x9x32x8x2xf32> {
278+
%0 = tensor.empty() : tensor<56x9x12x4x8x2x32xf32>
279+
%pack = tensor.pack %arg0
280+
inner_dims_pos = [1, 2, 3]
281+
inner_tiles = [8, 2, 32]
282+
into %0 : tensor<56x?x?x128xf32> -> tensor<56x9x12x4x8x2x32xf32>
283+
284+
%1 = tensor.empty() : tensor<12x4x56x9x32x8x2xf32>
285+
%transposed = linalg.transpose
286+
ins(%pack : tensor<56x9x12x4x8x2x32xf32>)
287+
outs(%1 : tensor<12x4x56x9x32x8x2xf32>)
288+
permutation = [2, 3, 0, 1, 6, 4, 5]
289+
290+
%return_value = tensor.cast %transposed : tensor<12x4x56x9x32x8x2xf32> to tensor<?x?x56x9x32x8x2xf32>
291+
return %return_value : tensor<?x?x56x9x32x8x2xf32>
292+
}
293+
// CHECK: #[[map:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
294+
// CHECK: #[[map1:.+]] = affine_map<()[s0] -> (s0 ceildiv 2)>
295+
// CHECK: module {
296+
// CHECK: func.func @tensor_pack_linalg_transpose_fold_dynamic_outer_and_tile_dims(
297+
// CHECK-SAME: %[[ARG0:.+]]: tensor<56x?x?x128xf32>)
298+
// CHECK: %[[c1:.+]] = arith.constant 1 : index
299+
// CHECK: %[[c2:.+]] = arith.constant 2 : index
300+
// CHECK: %[[dim:.+]] = tensor.dim %[[ARG0]], %[[c1]] : tensor<56x?x?x128xf32>
301+
// CHECK: %[[dim_0:.+]] = tensor.dim %[[ARG0]], %[[c2]] : tensor<56x?x?x128xf32>
302+
// CHECK: %[[mapped_dim1:.+]] = affine.apply #[[map:.+]]()[%[[dim]]]
303+
// CHECK: %[[mapped_dim2:.+]] = affine.apply #[[map1:.+]]()[%[[dim_0]]]
304+
// CHECK: %[[INIT:.+]] = tensor.empty(%[[mapped_dim2]], %[[mapped_dim1]]) : tensor<?x4x56x?x32x8x2xf32>
305+
// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [2, 3, 0, 1] inner_dims_pos = [3, 1, 2] inner_tiles = [32, 8, 2] into %[[INIT]] : tensor<56x?x?x128xf32> -> tensor<?x4x56x?x32x8x2xf32>
306+
// CHECK: %[[CAST:.+]] = tensor.cast %[[PACK]] : tensor<?x4x56x?x32x8x2xf32> to tensor<?x?x56x9x32x8x2xf32>
307+
// CHECK: return %[[CAST]] : tensor<?x?x56x9x32x8x2xf32>
308+
// CHECK: }
309+
310+
// -----
311+
312+
func.func @tensor_pack_linalg_transpose_fold_dynamic_outer_dims_tile_dims_tile_sizes(%arg0: tensor<?x?x?x?xf32>, %pack_dest: tensor<?x?x?x?x?x?x?xf32>, %transpose_dest: tensor<?x?x?x?x?x?x?xf32>, %tile_p : index, %tile_q : index, %tile_r : index) -> tensor<?x?x?x?x?x?x?xf32> {
313+
%pack = tensor.pack %arg0
314+
outer_dims_perm = [3, 0, 2, 1]
315+
inner_dims_pos = [1, 2, 3]
316+
inner_tiles = [%tile_p, %tile_q, %tile_r]
317+
into %pack_dest : tensor<?x?x?x?xf32> -> tensor<?x?x?x?x?x?x?xf32>
318+
319+
%transposed = linalg.transpose
320+
ins(%pack : tensor<?x?x?x?x?x?x?xf32>)
321+
outs(%transpose_dest : tensor<?x?x?x?x?x?x?xf32>)
322+
permutation = [2, 3, 0, 1, 6, 4, 5]
323+
324+
return %transposed : tensor<?x?x?x?x?x?x?xf32>
325+
}
326+
// CHECK: #[[map:.+]] = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
327+
// CHECK: module {
328+
// CHECK: func.func @tensor_pack_linalg_transpose_fold_dynamic_outer_dims_tile_dims_tile_sizes(
329+
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?xf32>,
330+
// CHECK-SAME: %[[PACK_DEST:.+]]: tensor<?x?x?x?x?x?x?xf32>, %[[TRANSPOSE_DEST:.+]]: tensor<?x?x?x?x?x?x?xf32>,
331+
// CHECK-SAME: %[[ARG1:.+]]: index, %[[ARG2:.+]]: index,
332+
// CHECK-SAME: %[[ARG3:.+]]: index)
333+
// CHECK: %[[c0:.+]] = arith.constant 0 : index
334+
// CHECK: %[[c1:.+]] = arith.constant 1 : index
335+
// CHECK: %[[c2:.+]] = arith.constant 2 : index
336+
// CHECK: %[[c3:.+]] = arith.constant 3 : index
337+
// CHECK: %[[dim:.+]] = tensor.dim %[[ARG0]], %[[c0]] : tensor<?x?x?x?xf32>
338+
// CHECK: %[[dim_0:.+]] = tensor.dim %[[ARG0]], %[[c1]] : tensor<?x?x?x?xf32>
339+
// CHECK: %[[dim_1:.+]] = tensor.dim %[[ARG0]], %[[c2]] : tensor<?x?x?x?xf32>
340+
// CHECK: %[[dim_2:.+]] = tensor.dim %[[ARG0]], %[[c3]] : tensor<?x?x?x?xf32>
341+
// CHECK: %[[mapped_dim0:.+]] = affine.apply #[[map:.+]]()[%[[dim_2]], %[[ARG3]]]
342+
// CHECK: %[[mapped_dim1:.+]] = affine.apply #[[map:.+]]()[%[[dim_0]], %[[ARG1]]]
343+
// CHECK: %[[mapped_dim2:.+]] = affine.apply #[[map:.+]]()[%[[dim_1]], %[[ARG2]]]
344+
// CHECK: %[[INIT:.+]] = tensor.empty(%[[mapped_dim2]], %[[mapped_dim1]], %[[mapped_dim0]], %[[dim]], %[[ARG3]], %[[ARG1]], %[[ARG2]]) : tensor<?x?x?x?x?x?x?xf32>
345+
// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [2, 1, 3, 0] inner_dims_pos = [3, 1, 2] inner_tiles = [%[[ARG3]], %[[ARG1]], %[[ARG2]]] into %[[INIT]] : tensor<?x?x?x?xf32> -> tensor<?x?x?x?x?x?x?xf32>
346+
// CHECK: return %[[PACK]] : tensor<?x?x?x?x?x?x?xf32>
347+
// CHECK: }

0 commit comments

Comments
 (0)