Skip to content

Commit 7ef83f5

Browse files
authored
[mlir] Add pack/unpack transpose foldings for linalg.generic ops, fix bugs (#93055)
This PR adds transpose + pack/unpack folding support for transpose ops in the form of `linalg.generic` ops. There were also some bugs with the permutation composing in the previous patterns, so this PR fixes these bugs and adds tests for them as well.
1 parent 2df68e0 commit 7ef83f5

File tree

2 files changed

+221
-41
lines changed

2 files changed

+221
-41
lines changed

mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp

Lines changed: 82 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,34 @@ static LogicalResult isPackOn1D(RewriterBase &rewriter, Operation *op,
4848
return success();
4949
}
5050

51+
// If the `linalgOp` represents a transpose, return the permutation vector for
52+
// the transpose. Otherwise, return failure.
53+
static FailureOr<SmallVector<int64_t>>
54+
getTransposeOpPermutation(linalg::LinalgOp linalgOp) {
55+
if (auto transposeOp = dyn_cast<linalg::TransposeOp>(linalgOp.getOperation()))
56+
return SmallVector<int64_t>(transposeOp.getPermutation());
57+
if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops())
58+
return failure();
59+
60+
if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
61+
return failure();
62+
auto mapRange = linalgOp.getIndexingMapsArray();
63+
if (!mapRange.front().isPermutation() || !mapRange.back().isPermutation() ||
64+
mapRange.front() == mapRange.back()) {
65+
return failure();
66+
}
67+
if (!llvm::hasSingleElement(linalgOp.getBlock()->getOperations()))
68+
return failure();
69+
AffineMap outMap = mapRange.back();
70+
AffineMap inMap = mapRange.front();
71+
// To get the permutation, look at each output index and find which
72+
// dimension in the input we're reading from for that index.
73+
return llvm::map_to_vector(outMap.getResults(),
74+
[&](AffineExpr expr) -> int64_t {
75+
return *inMap.getResultPosition(expr);
76+
});
77+
}
78+
5179
/// Packing one-dimensional tensor can be expressed as an expand shape op.
5280
struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
5381
using OpRewritePattern<PackOp>::OpRewritePattern;
@@ -246,14 +274,10 @@ static bool checkAndPermute(ArrayRef<int64_t> permutation,
246274

247275
for (unsigned int i = 0; i < rank; ++i) {
248276
int64_t remappedPosition = permutation[i];
249-
250-
if (!inVec.empty()) {
251-
if (remappedPosition >= rank) {
252-
return false;
253-
}
277+
if (remappedPosition >= rank)
278+
return false;
279+
if (!inVec.empty())
254280
remappedPosition = inVec[remappedPosition];
255-
}
256-
257281
resVec.push_back(remappedPosition);
258282
}
259283

@@ -263,20 +287,25 @@ static bool checkAndPermute(ArrayRef<int64_t> permutation,
263287
/// Fold 'pack' -> 'transpose' into 'pack' since 'pack' already has transpose
264288
/// semantics.
265289
struct FoldProducerPackWithConsumerLinalgTransposeOp
266-
: public OpRewritePattern<linalg::TransposeOp> {
267-
using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
290+
: public OpInterfaceRewritePattern<linalg::LinalgOp> {
291+
using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
268292

269-
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
293+
LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
270294
PatternRewriter &rewriter) const override {
271-
auto packOp = transposeOp.getOperand(0).getDefiningOp<PackOp>();
295+
auto packOp = linalgOp->getOperand(0).getDefiningOp<PackOp>();
272296

273297
if (!packOp)
274298
return failure();
275299

300+
FailureOr<SmallVector<int64_t>> maybePerm =
301+
getTransposeOpPermutation(linalgOp);
302+
if (failed(maybePerm))
303+
return failure();
304+
276305
auto innerDimsPos = packOp.getInnerDimsPos();
277306
auto mixedInnerTiles = packOp.getMixedTiles();
278307
auto outerDimsPerm = packOp.getOuterDimsPerm();
279-
auto transposePerm = transposeOp.getPermutation();
308+
auto transposePerm = maybePerm.value();
280309
SmallVector<int64_t> newOuterDimsPermVec;
281310
SmallVector<int64_t> newInnerDimsPosVec;
282311
SmallVector<OpFoldResult> newMixedInnerTilesVec;
@@ -285,7 +314,7 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
285314
if (!checkAndPermute(transposePerm, outerDimsPerm, newOuterDimsPermVec,
286315
srcRank))
287316
return rewriter.notifyMatchFailure(
288-
transposeOp,
317+
linalgOp,
289318
"Cannot fold in tensor.pack if a tile dimension was transposed "
290319
"with a non-tile dimension in linalg.transpose.");
291320

@@ -297,11 +326,11 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
297326
}
298327

299328
Value output = packOp.createDestinationTensor(
300-
rewriter, transposeOp.getLoc(), packOp.getSource(),
301-
newMixedInnerTilesVec, newInnerDimsPosVec, newOuterDimsPermVec);
329+
rewriter, linalgOp.getLoc(), packOp.getSource(), newMixedInnerTilesVec,
330+
newInnerDimsPosVec, newOuterDimsPermVec);
302331

303332
rewriter.replaceOpWithNewOp<PackOp>(
304-
transposeOp, packOp.getSource(), output, newInnerDimsPosVec,
333+
linalgOp, packOp.getSource(), output, newInnerDimsPosVec,
305334
newMixedInnerTilesVec, packOp.getPaddingValue(), newOuterDimsPermVec);
306335

307336
return success();
@@ -316,12 +345,16 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
316345

317346
LogicalResult matchAndRewrite(PackOp packOp,
318347
PatternRewriter &rewriter) const override {
319-
auto transposeOp = packOp.getSource().getDefiningOp<linalg::TransposeOp>();
348+
auto linalgOp = packOp.getSource().getDefiningOp<linalg::LinalgOp>();
349+
if (!linalgOp)
350+
return failure();
320351

321-
if (!transposeOp)
352+
FailureOr<SmallVector<int64_t>> maybePerm =
353+
getTransposeOpPermutation(linalgOp);
354+
if (failed(maybePerm))
322355
return failure();
323356

324-
auto transposePermutation = transposeOp.getPermutation();
357+
auto transposePermutation = maybePerm.value();
325358
auto outerDimsPerm = packOp.getOuterDimsPerm();
326359
auto innerDimsPos = packOp.getInnerDimsPos();
327360
SmallVector<int64_t> newInnerDimsPosVec;
@@ -337,11 +370,11 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
337370
newInnerDimsPosVec.push_back(transposePermutation[dim]);
338371

339372
Value output = packOp.createDestinationTensor(
340-
rewriter, packOp.getLoc(), transposeOp.getOperand(0),
373+
rewriter, packOp.getLoc(), linalgOp->getOperand(0),
341374
packOp.getMixedTiles(), newInnerDimsPosVec, newOuterDimsPermVec);
342375

343376
rewriter.replaceOpWithNewOp<PackOp>(
344-
packOp, transposeOp.getOperand(0), output, newInnerDimsPosVec,
377+
packOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
345378
packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPermVec);
346379

347380
return success();
@@ -351,34 +384,38 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
351384
/// Fold 'unpack' -> 'transpose' into 'unpack' since 'unpack' already has
352385
/// transpose semantics.
353386
struct FoldProducerUnPackWithConsumerLinalgTransposeOp
354-
: public OpRewritePattern<linalg::TransposeOp> {
355-
using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
387+
: public OpInterfaceRewritePattern<linalg::LinalgOp> {
388+
using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
356389

357-
LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
390+
LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
358391
PatternRewriter &rewriter) const override {
359-
auto unPackOp = transposeOp.getOperand(0).getDefiningOp<UnPackOp>();
392+
auto unPackOp = linalgOp->getOperand(0).getDefiningOp<UnPackOp>();
360393

361394
if (!unPackOp)
362395
return failure();
363396

364-
auto transposePermutation = transposeOp.getPermutation();
397+
FailureOr<SmallVector<int64_t>> maybePerm =
398+
getTransposeOpPermutation(linalgOp);
399+
if (failed(maybePerm))
400+
return failure();
401+
365402
auto outerDimsPerm = unPackOp.getOuterDimsPerm();
366403
auto innerDimsPos = unPackOp.getInnerDimsPos();
367404
SmallVector<int64_t> newInnerDimsPosVec;
368405
SmallVector<int64_t> newOuterDimsPermVec =
369-
llvm::to_vector(transposePermutation);
370-
371-
if (!outerDimsPerm.empty())
372-
applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);
406+
invertPermutationVector(maybePerm.value());
373407

374408
// Can't use applyPermutationToVector for newInnerDimsPosVec since input and
375409
// permutation rank won't necessarily be equal in all cases.
376410
for (auto dim : innerDimsPos)
377-
newInnerDimsPosVec.push_back(transposePermutation[dim]);
411+
newInnerDimsPosVec.push_back(newOuterDimsPermVec[dim]);
412+
413+
if (!outerDimsPerm.empty())
414+
applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);
378415

379416
// Reuse the destination of the transpose op.
380417
rewriter.replaceOpWithNewOp<UnPackOp>(
381-
transposeOp, unPackOp.getSource(), transposeOp.getDpsInits()[0],
418+
linalgOp, unPackOp.getSource(), linalgOp.getDpsInits()[0],
382419
newInnerDimsPosVec, unPackOp.getMixedTiles(), newOuterDimsPermVec);
383420

384421
return success();
@@ -393,13 +430,17 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
393430

394431
LogicalResult matchAndRewrite(UnPackOp unPackOp,
395432
PatternRewriter &rewriter) const override {
396-
auto transposeOp =
397-
unPackOp.getSource().getDefiningOp<linalg::TransposeOp>();
433+
auto linalgOp = unPackOp.getSource().getDefiningOp<linalg::LinalgOp>();
434+
if (!linalgOp)
435+
return failure();
398436

399-
if (!transposeOp)
437+
FailureOr<SmallVector<int64_t>> maybePerm =
438+
getTransposeOpPermutation(linalgOp);
439+
if (failed(maybePerm))
400440
return failure();
401441

402-
auto transposePermutation = transposeOp.getPermutation();
442+
SmallVector<int64_t> inverseTransposePerm =
443+
invertPermutationVector(maybePerm.value());
403444
auto outerDimsPerm = unPackOp.getOuterDimsPerm();
404445
auto innerDimsPos = unPackOp.getInnerDimsPos();
405446
int64_t destRank = unPackOp.getSourceRank() - innerDimsPos.size();
@@ -408,26 +449,26 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
408449
SmallVector<int64_t> newInnerDimsPosVec;
409450
SmallVector<OpFoldResult> newMixedInnerTilesVec;
410451

411-
if (!checkAndPermute(transposePermutation, outerDimsPerm,
452+
if (!checkAndPermute(inverseTransposePerm, outerDimsPerm,
412453
newOuterDimsPermVec, destRank))
413454
return rewriter.notifyMatchFailure(
414455
unPackOp,
415456
"Cannot fold in tensor.unpack if a tile dimension was transposed "
416457
"with a non-tile dimension in linalg.transpose.");
417458

418459
// Process transpose operation for tiled inner dimensions
419-
for (unsigned int i = destRank; i < transposePermutation.size(); ++i) {
420-
int64_t remappedPosition = transposePermutation[i] - destRank;
460+
for (unsigned int i = destRank; i < inverseTransposePerm.size(); ++i) {
461+
int64_t remappedPosition = inverseTransposePerm[i] - destRank;
421462
newMixedInnerTilesVec.push_back(mixedInnerTilesVec[remappedPosition]);
422463
newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]);
423464
}
424465

425466
Value output = unPackOp.createDestinationTensor(
426-
rewriter, unPackOp.getLoc(), transposeOp.getOperand(0),
467+
rewriter, unPackOp.getLoc(), linalgOp->getOperand(0),
427468
newMixedInnerTilesVec, newInnerDimsPosVec, newOuterDimsPermVec);
428469

429470
rewriter.replaceOpWithNewOp<UnPackOp>(
430-
unPackOp, transposeOp.getOperand(0), output, newInnerDimsPosVec,
471+
unPackOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
431472
newMixedInnerTilesVec, newOuterDimsPermVec);
432473

433474
return success();

mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -636,3 +636,142 @@ func.func @tensor_padded_unpack_linalg_transpose_fold(%arg0: tensor<71x7x4x16x16
636636
// CHECK-SAME: into %[[OUT:.+]] : tensor<71x7x4x16x16xf32> -> tensor<100x71x64xf32>
637637
// CHECK: return %[[UNPACK]] : tensor<100x71x64xf32>
638638
// CHECK: }
639+
640+
// -----
641+
642+
func.func @non_involution_transpose_unpack_fold(%arg0: tensor<2x3x5x4x16xi32>) -> tensor<5x48x8xi32> {
643+
%0 = tensor.empty() : tensor<5x2x3x16x4xi32>
644+
%transposed = linalg.transpose ins(%arg0 : tensor<2x3x5x4x16xi32>)
645+
outs(%0 : tensor<5x2x3x16x4xi32>)
646+
permutation = [2, 0, 1, 4, 3]
647+
%1 = tensor.empty() : tensor<5x48x8xi32>
648+
%unpack = tensor.unpack %transposed
649+
outer_dims_perm = [0, 2, 1]
650+
inner_dims_pos = [1, 2]
651+
inner_tiles = [16, 4] into
652+
%1 : tensor<5x2x3x16x4xi32> -> tensor<5x48x8xi32>
653+
return %unpack : tensor<5x48x8xi32>
654+
}
655+
//CHECK-LABEL: func.func @non_involution_transpose_unpack_fold(
656+
// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x5x4x16xi32>) -> tensor<5x48x8xi32> {
657+
// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<5x48x8xi32>
658+
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
659+
// CHECK-SAME: outer_dims_perm = [2, 1, 0]
660+
// CHECK-SAME: inner_dims_pos = [2, 1]
661+
// CHECK-SAME: inner_tiles = [4, 16]
662+
// CHECK-SAME:  into %[[OUT]] : tensor<2x3x5x4x16xi32> -> tensor<5x48x8xi32>
663+
// CHECK: return %[[UNPACK]] : tensor<5x48x8xi32>
664+
// CHECK: }
665+
666+
// -----
667+
668+
func.func @unpack_non_involution_transpose_fold(%arg0: tensor<57x3x56x1x64xf32>) -> tensor<3648x3x56xf32> {
669+
%0 = tensor.empty() : tensor<3x56x3648xf32>
670+
%unpack = tensor.unpack %arg0
671+
outer_dims_perm = [2, 0, 1]
672+
inner_dims_pos = [1, 2]
673+
inner_tiles = [1, 64]
674+
into %0 : tensor<57x3x56x1x64xf32> -> tensor<3x56x3648xf32>
675+
676+
%1 = tensor.empty() : tensor<3648x3x56xf32>
677+
%transposed = linalg.transpose
678+
ins(%unpack : tensor<3x56x3648xf32>)
679+
outs(%1 : tensor<3648x3x56xf32>)
680+
permutation = [2, 0, 1]
681+
return %transposed : tensor<3648x3x56xf32>
682+
}
683+
// CHECK-LABEL: func.func @unpack_non_involution_transpose_fold(
684+
// CHECK-SAME: %[[ARG0:.+]]: tensor<57x3x56x1x64xf32>) -> tensor<3648x3x56xf32> {
685+
// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<3648x3x56xf32>
686+
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
687+
// CHECK-SAME: outer_dims_perm = [0, 1, 2]
688+
// CHECK-SAME: inner_dims_pos = [2, 0]
689+
// CHECK-SAME: inner_tiles = [1, 64]
690+
// CHECK-SAME: into %[[OUT:.+]] : tensor<57x3x56x1x64xf32> -> tensor<3648x3x56xf32>
691+
// CHECK: return %[[UNPACK]] : tensor<3648x3x56xf32>
692+
// CHECK: }
693+
694+
// -----
695+
696+
func.func @transpose_unpacked_dims_no_fold(%arg0: tensor<2x16x5x4x3xi32>) -> tensor<5x32x12xi32> {
697+
%0 = tensor.empty() : tensor<5x2x3x16x4xi32>
698+
%transposed = linalg.transpose ins(%arg0 : tensor<2x16x5x4x3xi32>)
699+
outs(%0 : tensor<5x2x3x16x4xi32>)
700+
permutation = [2, 0, 4, 1, 3]
701+
%1 = tensor.empty() : tensor<5x32x12xi32>
702+
%unpack = tensor.unpack %transposed
703+
inner_dims_pos = [1, 2]
704+
inner_tiles = [16, 4] into
705+
%1 : tensor<5x2x3x16x4xi32> -> tensor<5x32x12xi32>
706+
return %unpack : tensor<5x32x12xi32>
707+
}
708+
//CHECK-LABEL: func.func @transpose_unpacked_dims_no_fold(
709+
// CHECK: linalg.transpose
710+
// CHECK: tensor.unpack
711+
712+
// -----
713+
714+
#map = affine_map<(d0, d1, d2, d3, d4)->(d1, d2, d0, d4, d3)>
715+
#map1 = affine_map<(d0, d1, d2, d3, d4)->(d0, d1, d2, d3, d4)>
716+
func.func @generic_transpose_unpack_fold(%arg0: tensor<2x3x5x4x16xi32>) -> tensor<5x48x8xi32> {
717+
%0 = tensor.empty() : tensor<5x2x3x16x4xi32>
718+
%transposed = linalg.generic {
719+
iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"],
720+
indexing_maps = [#map, #map1]}
721+
ins(%arg0 : tensor<2x3x5x4x16xi32>)
722+
outs(%0 : tensor<5x2x3x16x4xi32>) {
723+
^bb0(%in : i32, %out : i32):
724+
linalg.yield %in : i32
725+
} -> tensor<5x2x3x16x4xi32>
726+
%1 = tensor.empty() : tensor<5x48x8xi32>
727+
%unpack = tensor.unpack %transposed
728+
outer_dims_perm = [0, 2, 1]
729+
inner_dims_pos = [1, 2]
730+
inner_tiles = [16, 4] into
731+
%1 : tensor<5x2x3x16x4xi32> -> tensor<5x48x8xi32>
732+
return %unpack : tensor<5x48x8xi32>
733+
}
734+
//CHECK-LABEL: func.func @generic_transpose_unpack_fold(
735+
// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x5x4x16xi32>) -> tensor<5x48x8xi32> {
736+
// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<5x48x8xi32>
737+
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
738+
// CHECK-SAME: outer_dims_perm = [2, 1, 0]
739+
// CHECK-SAME: inner_dims_pos = [2, 1]
740+
// CHECK-SAME: inner_tiles = [4, 16]
741+
// CHECK-SAME:  into %[[OUT]] : tensor<2x3x5x4x16xi32> -> tensor<5x48x8xi32>
742+
// CHECK: return %[[UNPACK]] : tensor<5x48x8xi32>
743+
// CHECK: }
744+
745+
// -----
746+
747+
#map = affine_map<(d0, d1, d2)->(d1, d2, d0)>
748+
#map1 = affine_map<(d0, d1, d2)->(d0, d1, d2)>
749+
func.func @unpack_generic_transpose_fold(%arg0: tensor<57x3x56x1x64xf32>) -> tensor<3648x3x56xf32> {
750+
%0 = tensor.empty() : tensor<3x56x3648xf32>
751+
%unpack = tensor.unpack %arg0
752+
outer_dims_perm = [2, 0, 1]
753+
inner_dims_pos = [1, 2]
754+
inner_tiles = [1, 64]
755+
into %0 : tensor<57x3x56x1x64xf32> -> tensor<3x56x3648xf32>
756+
757+
%1 = tensor.empty() : tensor<3648x3x56xf32>
758+
%transposed = linalg.generic {
759+
iterator_types = ["parallel", "parallel", "parallel"],
760+
indexing_maps = [#map, #map1]}
761+
ins(%unpack : tensor<3x56x3648xf32>)
762+
outs(%1 : tensor<3648x3x56xf32>) {
763+
^bb0(%in : f32, %out : f32):
764+
linalg.yield %in : f32
765+
} -> tensor<3648x3x56xf32>
766+
return %transposed : tensor<3648x3x56xf32>
767+
}
768+
// CHECK-LABEL: func.func @unpack_generic_transpose_fold(
769+
// CHECK-SAME: %[[ARG0:.+]]: tensor<57x3x56x1x64xf32>) -> tensor<3648x3x56xf32> {
770+
// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<3648x3x56xf32>
771+
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
772+
// CHECK-SAME: outer_dims_perm = [0, 1, 2]
773+
// CHECK-SAME: inner_dims_pos = [2, 0]
774+
// CHECK-SAME: inner_tiles = [1, 64]
775+
// CHECK-SAME: into %[[OUT:.+]] : tensor<57x3x56x1x64xf32> -> tensor<3648x3x56xf32>
776+
// CHECK: return %[[UNPACK]] : tensor<3648x3x56xf32>
777+
// CHECK: }

0 commit comments

Comments
 (0)