Skip to content

Commit aedabce

Browse files
committed
Fix merge conflicts
1 parent a960703 commit aedabce

File tree

2 files changed

+204
-1
lines changed

2 files changed

+204
-1
lines changed

mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "mlir/Dialect/Linalg/IR/Linalg.h"
1010
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1111
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
12+
#include "mlir/Dialect/Utils/IndexingUtils.h"
1213
#include "mlir/IR/PatternMatch.h"
1314
#include "llvm/Support/Debug.h"
1415

@@ -223,11 +224,52 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
223224
return success();
224225
}
225226
};
227+
228+
/// Fold 'transpose' -> 'pack' into 'pack' since 'pack' already has transpose
229+
/// semantics.
230+
struct FoldConsumerPackWithProducerLinalgTransposeOp
231+
: public OpRewritePattern<PackOp> {
232+
using OpRewritePattern<PackOp>::OpRewritePattern;
233+
234+
LogicalResult matchAndRewrite(PackOp packOp,
235+
PatternRewriter &rewriter) const override {
236+
auto transposeOp = packOp.getSource().getDefiningOp<linalg::TransposeOp>();
237+
238+
if (!transposeOp)
239+
return failure();
240+
241+
auto transposePermutation = transposeOp.getPermutation();
242+
auto outerDimsPerm = packOp.getOuterDimsPerm();
243+
auto innerDimsPos = packOp.getInnerDimsPos();
244+
SmallVector<int64_t> newInnerDimsPosVec;
245+
SmallVector<int64_t> newOuterDimsPermVec =
246+
llvm::to_vector(transposePermutation);
247+
248+
if (!outerDimsPerm.empty())
249+
applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);
250+
251+
for (auto dim : innerDimsPos) {
252+
newInnerDimsPosVec.push_back(llvm::find(transposePermutation, dim) -
253+
transposePermutation.begin());
254+
}
255+
256+
Value output = packOp.createDestinationTensor(
257+
rewriter, packOp.getLoc(), transposeOp.getOperand(0),
258+
packOp.getMixedTiles(), newInnerDimsPosVec, newOuterDimsPermVec);
259+
260+
rewriter.replaceOpWithNewOp<PackOp>(
261+
packOp, transposeOp.getOperand(0), output, newInnerDimsPosVec,
262+
packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPermVec);
263+
264+
return success();
265+
}
266+
};
226267
} // namespace
227268

228269
/// Registers the pack/unpack folding rewrite patterns on `patterns`, each
/// constructed with the MLIRContext returned by `patterns.getContext()`.
/// In this diff view the `-` line is the previous template argument list and
/// the `+` lines are its replacement, which additionally registers
/// FoldConsumerPackWithProducerLinalgTransposeOp.
void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns) {
229270
patterns.insert<FoldUnpackWithExtractSliceOp, FoldPadWithPackOp,
230-
FoldProducerPackWithConsumerLinalgTransposeOp>(
271+
FoldProducerPackWithConsumerLinalgTransposeOp,
272+
FoldConsumerPackWithProducerLinalgTransposeOp>(
231273
patterns.getContext());
232274
}
233275

mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,3 +345,164 @@ func.func @tensor_pack_linalg_transpose_fold_dynamic_outer_dims_tile_dims_tile_s
345345
// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [2, 1, 3, 0] inner_dims_pos = [3, 1, 2] inner_tiles = [%[[ARG3]], %[[ARG1]], %[[ARG2]]] into %[[INIT]] : tensor<?x?x?x?xf32> -> tensor<?x?x?x?x?x?x?xf32>
346346
// CHECK: return %[[PACK]] : tensor<?x?x?x?x?x?x?xf32>
347347
// CHECK: }
348+
349+
// -----
350+
351+
// Static shapes: a linalg.transpose with permutation [2, 0, 1, 3] feeds a
// tensor.pack with outer_dims_perm = [0, 2, 1, 3], one inner tile of 32 on
// dim 3, and no padding value. The only tiled dim (3) is a fixed point of the
// transpose permutation.
func.func @linalg_transpose_tensor_pack_fold(%arg0: tensor<56x57x1x64xf32>) -> tensor<1x57x56x2x32xf32> {
352+
%0 = tensor.empty() : tensor<1x56x57x64xf32>
353+
%transposed = linalg.transpose
354+
ins(%arg0 : tensor<56x57x1x64xf32>)
355+
outs(%0 : tensor<1x56x57x64xf32>)
356+
permutation = [2, 0, 1, 3]
357+
358+
%1 = tensor.empty() : tensor<1x57x56x2x32xf32>
359+
%pack = tensor.pack %transposed
360+
outer_dims_perm = [0, 2, 1, 3]
361+
inner_dims_pos = [3]
362+
inner_tiles = [32]
363+
into %1 : tensor<1x56x57x64xf32> -> tensor<1x57x56x2x32xf32>
364+
return %pack : tensor<1x57x56x2x32xf32>
365+
}
366+
// CHECK-LABEL: func @linalg_transpose_tensor_pack_fold(
367+
// CHECK-SAME: %[[ARG0:.+]]: tensor<56x57x1x64xf32>)
368+
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x57x56x2x32xf32>
369+
// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]]
370+
// CHECK-SAME: outer_dims_perm = [2, 1, 0, 3]
371+
// CHECK-SAME: inner_dims_pos = [3] inner_tiles = [32]
372+
// CHECK-SAME: into %[[INIT]]
373+
// CHECK: return %[[PACK]]
374+
375+
// -----
376+
377+
// Same transpose/pack structure as the unpadded static case, but the tiled
// dim has size 55 (not a multiple of the tile size 32), so the pack carries a
// padding_value operand.
func.func @linalg_transpose_tensor_pack_fold_with_padding(%arg0: tensor<56x57x1x55xf32>, %padding: f32) -> tensor<1x57x56x2x32xf32> {
378+
%0 = tensor.empty() : tensor<1x56x57x55xf32>
379+
%transpose = linalg.transpose
380+
ins(%arg0 : tensor<56x57x1x55xf32>)
381+
outs(%0 : tensor<1x56x57x55xf32>)
382+
permutation = [2, 0, 1, 3]
383+
384+
%1 = tensor.empty() : tensor<1x57x56x2x32xf32>
385+
%pack = tensor.pack %transpose padding_value(%padding : f32)
386+
outer_dims_perm = [0, 2, 1, 3]
387+
inner_dims_pos = [3]
388+
inner_tiles = [32]
389+
into %1 : tensor<1x56x57x55xf32> -> tensor<1x57x56x2x32xf32>
390+
return %pack : tensor<1x57x56x2x32xf32>
391+
}
392+
// CHECK-LABEL: func @linalg_transpose_tensor_pack_fold_with_padding(
393+
// CHECK-SAME: %[[ARG0:.+]]: tensor<56x57x1x55xf32>, %[[PADDING:.+]]: f32)
394+
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x57x56x2x32xf32>
395+
// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] padding_value(%[[PADDING]] : f32)
396+
// CHECK-SAME: outer_dims_perm = [2, 1, 0, 3]
397+
// CHECK-SAME: inner_dims_pos = [3] inner_tiles = [32]
398+
// CHECK-SAME: into %[[INIT]]
399+
// CHECK: return %[[PACK]]
400+
401+
// -----
402+
403+
// The pack has no outer_dims_perm attribute; only inner_dims_pos = [3] with a
// tile of 32 is given. The transpose permutation is [2, 0, 1, 3].
func.func @linalg_transpose_tensor_pack_fold_no_outer_dims_perm(%arg0: tensor<56x57x1x64xf32>) -> tensor<1x56x57x2x32xf32> {
404+
%0 = tensor.empty() : tensor<1x56x57x64xf32>
405+
%transposed = linalg.transpose
406+
ins(%arg0 : tensor<56x57x1x64xf32>)
407+
outs(%0 : tensor<1x56x57x64xf32>)
408+
permutation = [2, 0, 1, 3]
409+
410+
%1 = tensor.empty() : tensor<1x56x57x2x32xf32>
411+
%pack = tensor.pack %transposed
412+
inner_dims_pos = [3]
413+
inner_tiles = [32]
414+
into %1 : tensor<1x56x57x64xf32> -> tensor<1x56x57x2x32xf32>
415+
return %pack : tensor<1x56x57x2x32xf32>
416+
}
417+
// CHECK-LABEL: func @linalg_transpose_tensor_pack_fold_no_outer_dims_perm(
418+
// CHECK-SAME: %[[ARG0:.+]]: tensor<56x57x1x64xf32>)
419+
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x56x57x2x32xf32>
420+
// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]]
421+
// CHECK-SAME: outer_dims_perm = [2, 0, 1, 3]
422+
// CHECK-SAME: inner_dims_pos = [3] inner_tiles = [32]
423+
// CHECK-SAME: into %[[INIT]]
424+
// CHECK: return %[[PACK]]
425+
426+
// -----
427+
428+
// Static shapes with three tiled dims: transpose permutation [2, 3, 0, 1],
// pack outer_dims_perm = [3, 0, 2, 1], inner_dims_pos = [1, 3, 2].
// NOTE: [2, 3, 0, 1] is its own inverse, so this case cannot tell apart
// remapping inner_dims_pos through the permutation from remapping it through
// the inverse permutation.
func.func @linalg_transpose_tensor_pack_fold_complex_inner_dims_change(%arg0: tensor<25x30x35x40xf32>, %transpose_dest: tensor<35x40x25x30xf32>, %pack_dest: tensor<3x35x5x8x5x10x5xf32>) -> tensor<3x35x5x8x5x10x5xf32> {
429+
%transposed = linalg.transpose
430+
ins(%arg0 : tensor<25x30x35x40xf32>)
431+
outs(%transpose_dest : tensor<35x40x25x30xf32>)
432+
permutation = [2, 3, 0, 1]
433+
434+
%pack = tensor.pack %transposed
435+
outer_dims_perm = [3, 0, 2, 1]
436+
inner_dims_pos = [1, 3, 2]
437+
inner_tiles = [5, 10, 5]
438+
into %pack_dest : tensor<35x40x25x30xf32> -> tensor<3x35x5x8x5x10x5xf32>
439+
return %pack : tensor<3x35x5x8x5x10x5xf32>
440+
}
441+
// CHECK-LABEL: func.func @linalg_transpose_tensor_pack_fold_complex_inner_dims_change(
442+
// CHECK-SAME: %[[ARG0:.+]]: tensor<25x30x35x40xf32>,
443+
// CHECK-SAME: %[[ARG1:.+]]: tensor<35x40x25x30xf32>,
444+
// CHECK-SAME: %[[ARG2:.+]]: tensor<3x35x5x8x5x10x5xf32>) -> tensor<3x35x5x8x5x10x5xf32> {
445+
// CHECK: %[[VAL0:.+]] = tensor.empty() : tensor<3x35x5x8x5x10x5xf32>
446+
// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]]
447+
// CHECK-SAME: outer_dims_perm = [1, 2, 0, 3]
448+
// CHECK-SAME: inner_dims_pos = [3, 1, 0]
449+
// CHECK-SAME: inner_tiles = [5, 10, 5]
450+
// CHECK-SAME: into %[[VAL0]]
451+
// CHECK: return %[[PACK]]
452+
453+
// -----
454+
455+
// Fully dynamic variant: every tensor dim is `?` and the three inner tile
// sizes are index-typed function arguments. Uses the same permutations as the
// static three-tile case (transpose [2, 3, 0, 1], outer_dims_perm
// [3, 0, 2, 1], inner_dims_pos [1, 3, 2]).
func.func @linalg_transpose_tensor_pack_fold_dynamic_outer_dims_tile_dims_tile_sizes(%arg0: tensor<?x?x?x?xf32>, %transpose_dest: tensor<?x?x?x?xf32>, %pack_dest: tensor<?x?x?x?x?x?x?xf32>, %tile_p : index, %tile_q : index, %tile_r : index) -> tensor<?x?x?x?x?x?x?xf32> {
456+
%transposed = linalg.transpose
457+
ins(%arg0 : tensor<?x?x?x?xf32>)
458+
outs(%transpose_dest : tensor<?x?x?x?xf32>)
459+
permutation = [2, 3, 0, 1]
460+
461+
%pack = tensor.pack %transposed
462+
outer_dims_perm = [3, 0, 2, 1]
463+
inner_dims_pos = [1, 3, 2]
464+
inner_tiles = [%tile_p, %tile_q, %tile_r]
465+
into %pack_dest : tensor<?x?x?x?xf32> -> tensor<?x?x?x?x?x?x?xf32>
466+
return %pack : tensor<?x?x?x?x?x?x?xf32>
467+
}
468+
// CHECK: #[[map:.+]] = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
469+
// CHECK-LABEL: func.func @linalg_transpose_tensor_pack_fold_dynamic_outer_dims_tile_dims_tile_sizes(
470+
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?xf32>, %[[ARG1:.+]]: tensor<?x?x?x?xf32>,
471+
// CHECK-SAME: %[[ARG2:.+]]: tensor<?x?x?x?x?x?x?xf32>, %[[ARG3:.+]]: index, %[[ARG4:.+]]: index, %[[ARG5:.+]]: index) -> tensor<?x?x?x?x?x?x?xf32> {
472+
// CHECK: %[[C0:.+]] = arith.constant 0 : index
473+
// CHECK: %[[C1:.+]] = arith.constant 1 : index
474+
// CHECK: %[[C2:.+]] = arith.constant 2 : index
475+
// CHECK: %[[C3:.+]] = arith.constant 3 : index
476+
// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?x?x?xf32>
477+
// CHECK: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?x?xf32>
478+
// CHECK: %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?x?xf32>
479+
// CHECK: %[[DIM2:.+]] = tensor.dim %[[ARG0]], %[[C3]] : tensor<?x?x?x?xf32>
480+
// CHECK: %[[VAL0:.+]] = affine.apply #[[map:.+]]()[%[[DIM2]], %[[ARG3]]]
481+
// CHECK: %[[VAL1:.+]] = affine.apply #[[map:.+]]()[%[[DIM0]], %[[ARG4]]]
482+
// CHECK: %[[VAL2:.+]] = affine.apply #[[map:.+]]()[%[[DIM]], %[[ARG5]]]
483+
// CHECK: %[[VAL3:.+]] = tensor.empty(%[[VAL1]], %[[DIM1]], %[[VAL2]], %[[VAL0]], %[[ARG3]], %[[ARG4]], %[[ARG5]]) : tensor<?x?x?x?x?x?x?xf32>
484+
// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] outer_dims_perm = [1, 2, 0, 3] inner_dims_pos = [3, 1, 0] inner_tiles = [%[[ARG3]], %[[ARG4]], %[[ARG5]]] into %[[VAL3]] : tensor<?x?x?x?xf32> -> tensor<?x?x?x?x?x?x?xf32>
485+
// CHECK: return %[[PACK]] : tensor<?x?x?x?x?x?x?xf32>
486+
487+
// -----
488+
489+
// Negative case: a tensor.cast sits between the transpose and the pack, so
// the pack source is not produced directly by linalg.transpose. The CHECK
// lines after this function expect both linalg.transpose and tensor.pack to
// remain in the output.
func.func @linalg_transpose_tensor_cast_tensor_pack_fold(%arg0: tensor<56x57x1x64xf32>) -> tensor<1x57x56x2x32xf32> {
490+
%0 = tensor.empty() : tensor<1x56x57x64xf32>
491+
%transposed = linalg.transpose
492+
ins(%arg0 : tensor<56x57x1x64xf32>)
493+
outs(%0 : tensor<1x56x57x64xf32>)
494+
permutation = [2, 0, 1, 3]
495+
496+
%transposed_cast = tensor.cast %transposed : tensor<1x56x57x64xf32> to tensor<?x56x57x64xf32>
497+
%1 = tensor.empty() : tensor<1x57x56x2x32xf32>
498+
%pack = tensor.pack %transposed_cast
499+
outer_dims_perm = [0, 2, 1, 3]
500+
inner_dims_pos = [3]
501+
inner_tiles = [32]
502+
into %1 : tensor<?x56x57x64xf32> -> tensor<1x57x56x2x32xf32>
503+
return %pack : tensor<1x57x56x2x32xf32>
504+
}
505+
// CHECK-LABEL: func @linalg_transpose_tensor_cast_tensor_pack_fold(
506+
// CHECK-SAME: %[[ARG0:.+]]: tensor<56x57x1x64xf32>)
507+
// CHECK: linalg.transpose
508+
// CHECK: tensor.pack

0 commit comments

Comments
 (0)