
Commit e422446

[mlir] Add pack transpose foldings for linalg.generic transpose ops and fix bugs

1 parent 5345901 commit e422446

2 files changed: +226 −39 lines
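
In short: a transpose written as a linalg.generic (not only the named linalg.transpose op) can now fold into an adjacent tensor.pack / tensor.unpack, and the unpack foldings now use the inverse of the transpose permutation. A condensed before/after sketch of the new generic case, abridged from the generic_transpose_unpack_fold test added below (%dest stands in for the fresh tensor.empty destination the pattern materializes; other SSA names follow that test):

// Before: a linalg.generic that merely permutes dimensions feeds the unpack.
%transposed = linalg.generic { ... }  // pure transpose of %arg0, permutation [2, 0, 1, 4, 3]
    ins(%arg0 : tensor<2x3x5x4x16xi32>) outs(%0 : tensor<5x2x3x16x4xi32>) { ... }
%unpack = tensor.unpack %transposed
    outer_dims_perm = [0, 2, 1] inner_dims_pos = [1, 2] inner_tiles = [16, 4]
    into %1 : tensor<5x2x3x16x4xi32> -> tensor<5x48x8xi32>

// After: the transpose is absorbed into the unpack's own permutation fields.
%unpack = tensor.unpack %arg0
    outer_dims_perm = [2, 1, 0] inner_dims_pos = [2, 1] inner_tiles = [4, 16]
    into %dest : tensor<2x3x5x4x16xi32> -> tensor<5x48x8xi32>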

mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp

Lines changed: 87 additions & 39 deletions
@@ -48,6 +48,34 @@ static LogicalResult isPackOn1D(RewriterBase &rewriter, Operation *op,
   return success();
 }
 
+// If the `linalgOp` represents a transpose, return the permutation vector for
+// the transpose. Otherwise, return failure.
+static FailureOr<SmallVector<int64_t>>
+getTransposeOpPermutation(linalg::LinalgOp linalgOp) {
+  if (auto transposeOp = dyn_cast<linalg::TransposeOp>(linalgOp.getOperation()))
+    return SmallVector<int64_t>(transposeOp.getPermutation());
+  if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops())
+    return failure();
+
+  if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
+    return failure();
+  auto mapRange = linalgOp.getIndexingMapsArray();
+  if (mapRange.size() != 2 || !mapRange.front().isPermutation() ||
+      !mapRange.back().isPermutation() || mapRange.front() == mapRange.back()) {
+    return failure();
+  }
+  if (!llvm::hasSingleElement(linalgOp.getBlock()->getOperations()))
+    return failure();
+  AffineMap outMap = mapRange.back();
+  AffineMap inMap = mapRange.front();
+  // To get the permutation, look at each output index and find which
+  // dimension in the input we're reading from for that index.
+  return llvm::map_to_vector(outMap.getResults(),
+                             [&](AffineExpr expr) -> int64_t {
+                               return *inMap.getResultPosition(expr);
+                             });
+}
+
 /// Packing one-dimensional tensor can be expressed as an expand shape op.
 struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
   using OpRewritePattern<PackOp>::OpRewritePattern;
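
For concreteness, this is the shape of linalg.generic the new helper accepts, written out from the indexing maps used by the unpack_generic_transpose_fold test below (the SSA names here are illustrative, not from the patch):

#in  = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
#out = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// One input, one init, all-parallel iterators, both maps are (distinct)
// permutations, and the body is a lone linalg.yield of the input.
%t = linalg.generic {
    iterator_types = ["parallel", "parallel", "parallel"],
    indexing_maps = [#in, #out]}
    ins(%src : tensor<3x56x3648xf32>)
    outs(%init : tensor<3648x3x56xf32>) {
  ^bb0(%in : f32, %out : f32):
    linalg.yield %in : f32
} -> tensor<3648x3x56xf32>
// Walking the output map's results (d0, d1, d2) and looking up each one's
// position among the input map's results (d1, d2, d0) gives 2, 0, 1, so the
// recovered permutation is [2, 0, 1], i.e. this generic is equivalent to
// linalg.transpose ... permutation = [2, 0, 1].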
@@ -246,14 +274,10 @@ static bool checkAndPermute(ArrayRef<int64_t> permutation,
 
   for (unsigned int i = 0; i < rank; ++i) {
     int64_t remappedPosition = permutation[i];
-
-    if (!inVec.empty()) {
-      if (remappedPosition >= rank) {
-        return false;
-      }
+    if (remappedPosition >= rank)
+      return false;
+    if (!inVec.empty())
       remappedPosition = inVec[remappedPosition];
-    }
-
     resVec.push_back(remappedPosition);
   }
 
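The hoisted bounds check is the bug-fix half of this hunk: previously the `remappedPosition >= rank` test only ran when the op carried an explicit outer_dims_perm (`inVec` non-empty), so a transpose that mixes a tiled dimension into the outer dimensions could slip through when outer_dims_perm was absent. Worked by hand from the new transpose_unpacked_dims_no_fold test below (not part of the patch): the inverse of permutation [2, 0, 4, 1, 3] is [1, 3, 0, 4, 2], and with destRank = 3 the entry at index 1 is 3 >= 3, so checkAndPermute now rejects the fold, where the old code could have gone on to build an out-of-range outer_dims_perm.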

@@ -263,20 +287,26 @@ static bool checkAndPermute(ArrayRef<int64_t> permutation,
 /// Fold 'pack' -> 'transpose' into 'pack' since 'pack' already has transpose
 /// semantics.
 struct FoldProducerPackWithConsumerLinalgTransposeOp
-    : public OpRewritePattern<linalg::TransposeOp> {
-  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
+    : public OpInterfaceRewritePattern<linalg::LinalgOp> {
+  using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
 
-  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
+  LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
                                 PatternRewriter &rewriter) const override {
-    auto packOp = transposeOp.getOperand(0).getDefiningOp<PackOp>();
+    auto packOp = linalgOp->getOperand(0).getDefiningOp<PackOp>();
 
     if (!packOp)
       return failure();
 
+    FailureOr<SmallVector<int64_t>> maybePerm =
+        getTransposeOpPermutation(linalgOp);
+    if (failed(maybePerm)) {
+      return failure();
+    }
+
     auto innerDimsPos = packOp.getInnerDimsPos();
     auto mixedInnerTiles = packOp.getMixedTiles();
     auto outerDimsPerm = packOp.getOuterDimsPerm();
-    auto transposePerm = transposeOp.getPermutation();
+    auto transposePerm = maybePerm.value();
     SmallVector<int64_t> newOuterDimsPermVec;
     SmallVector<int64_t> newInnerDimsPosVec;
     SmallVector<OpFoldResult> newMixedInnerTilesVec;
@@ -285,7 +315,7 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
     if (!checkAndPermute(transposePerm, outerDimsPerm, newOuterDimsPermVec,
                          srcRank))
       return rewriter.notifyMatchFailure(
-          transposeOp,
+          linalgOp,
           "Cannot fold in tensor.pack if a tile dimension was transposed "
          "with a non-tile dimension in linalg.transpose.");
 
@@ -297,11 +327,11 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
     }
 
     Value output = packOp.createDestinationTensor(
-        rewriter, transposeOp.getLoc(), packOp.getSource(),
-        newMixedInnerTilesVec, newInnerDimsPosVec, newOuterDimsPermVec);
+        rewriter, linalgOp.getLoc(), packOp.getSource(), newMixedInnerTilesVec,
+        newInnerDimsPosVec, newOuterDimsPermVec);
 
     rewriter.replaceOpWithNewOp<PackOp>(
-        transposeOp, packOp.getSource(), output, newInnerDimsPosVec,
+        linalgOp, packOp.getSource(), output, newInnerDimsPosVec,
         newMixedInnerTilesVec, packOp.getPaddingValue(), newOuterDimsPermVec);
 
     return success();
@@ -316,12 +346,17 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
 
   LogicalResult matchAndRewrite(PackOp packOp,
                                 PatternRewriter &rewriter) const override {
-    auto transposeOp = packOp.getSource().getDefiningOp<linalg::TransposeOp>();
+    auto linalgOp = packOp.getSource().getDefiningOp<linalg::LinalgOp>();
+    if (!linalgOp)
+      return failure();
 
-    if (!transposeOp)
+    FailureOr<SmallVector<int64_t>> maybePerm =
+        getTransposeOpPermutation(linalgOp);
+    if (failed(maybePerm)) {
       return failure();
+    }
 
-    auto transposePermutation = transposeOp.getPermutation();
+    auto transposePermutation = maybePerm.value();
     auto outerDimsPerm = packOp.getOuterDimsPerm();
     auto innerDimsPos = packOp.getInnerDimsPos();
     SmallVector<int64_t> newInnerDimsPosVec;
@@ -337,11 +372,11 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
       newInnerDimsPosVec.push_back(transposePermutation[dim]);
 
     Value output = packOp.createDestinationTensor(
-        rewriter, packOp.getLoc(), transposeOp.getOperand(0),
+        rewriter, packOp.getLoc(), linalgOp->getOperand(0),
         packOp.getMixedTiles(), newInnerDimsPosVec, newOuterDimsPermVec);
 
     rewriter.replaceOpWithNewOp<PackOp>(
-        packOp, transposeOp.getOperand(0), output, newInnerDimsPosVec,
+        packOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
         packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPermVec);
 
     return success();
@@ -351,34 +386,41 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
 /// Fold 'unpack' -> 'transpose' into 'unpack' since 'unpack' already has
 /// transpose semantics.
 struct FoldProducerUnPackWithConsumerLinalgTransposeOp
-    : public OpRewritePattern<linalg::TransposeOp> {
-  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
+    : public OpInterfaceRewritePattern<linalg::LinalgOp> {
+  using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
 
-  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
+  LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
                                 PatternRewriter &rewriter) const override {
-    auto unPackOp = transposeOp.getOperand(0).getDefiningOp<UnPackOp>();
+    auto unPackOp = linalgOp->getOperand(0).getDefiningOp<UnPackOp>();
 
     if (!unPackOp)
       return failure();
 
-    auto transposePermutation = transposeOp.getPermutation();
+    FailureOr<SmallVector<int64_t>> maybePerm =
+        getTransposeOpPermutation(linalgOp);
+    if (failed(maybePerm)) {
+      return failure();
+    }
+
+    auto transposePermutation = maybePerm.value();
+    SmallVector<int64_t> inverseTransposePerm =
+        invertPermutationVector(transposePermutation);
     auto outerDimsPerm = unPackOp.getOuterDimsPerm();
     auto innerDimsPos = unPackOp.getInnerDimsPos();
     SmallVector<int64_t> newInnerDimsPosVec;
-    SmallVector<int64_t> newOuterDimsPermVec =
-        llvm::to_vector(transposePermutation);
+    SmallVector<int64_t> newOuterDimsPermVec = inverseTransposePerm;
 
     if (!outerDimsPerm.empty())
       applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);
 
     // Can't use applyPermutationToVector for newInnerDimsPosVec since input and
     // permutation rank won't necessarily be equal in all cases.
     for (auto dim : innerDimsPos)
-      newInnerDimsPosVec.push_back(transposePermutation[dim]);
+      newInnerDimsPosVec.push_back(inverseTransposePerm[dim]);
 
     // Reuse the destination of the transpose op.
     rewriter.replaceOpWithNewOp<UnPackOp>(
-        transposeOp, unPackOp.getSource(), transposeOp.getDpsInits()[0],
+        linalgOp, unPackOp.getSource(), linalgOp.getDpsInits()[0],
         newInnerDimsPosVec, unPackOp.getMixedTiles(), newOuterDimsPermVec);
 
     return success();
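The switch from the raw permutation to its inverse is the other bug-fix in this change: the rewritten unpack must describe, for each of its result dimensions, which dimension of the original unpack output it comes from, and that is the inverse mapping. Worked by hand from the unpack_non_involution_transpose_fold test below (not part of the patch): p = [2, 0, 1], so p^-1 = [1, 2, 0]; applying the producer's outer_dims_perm = [2, 0, 1] to p^-1 gives [p^-1[2], p^-1[0], p^-1[1]] = [0, 1, 2] for the new outer_dims_perm, and inner_dims_pos = [1, 2] becomes [p^-1[1], p^-1[2]] = [2, 0], which is exactly what the CHECK lines expect.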
@@ -393,13 +435,19 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
 
   LogicalResult matchAndRewrite(UnPackOp unPackOp,
                                 PatternRewriter &rewriter) const override {
-    auto transposeOp =
-        unPackOp.getSource().getDefiningOp<linalg::TransposeOp>();
+    auto linalgOp = unPackOp.getSource().getDefiningOp<linalg::LinalgOp>();
+    if (!linalgOp)
+      return failure();
 
-    if (!transposeOp)
+    FailureOr<SmallVector<int64_t>> maybePerm =
+        getTransposeOpPermutation(linalgOp);
+    if (failed(maybePerm)) {
       return failure();
+    }
 
-    auto transposePermutation = transposeOp.getPermutation();
+    auto transposePermutation = maybePerm.value();
+    SmallVector<int64_t> inverseTransposePerm =
+        invertPermutationVector(transposePermutation);
     auto outerDimsPerm = unPackOp.getOuterDimsPerm();
     auto innerDimsPos = unPackOp.getInnerDimsPos();
     int64_t destRank = unPackOp.getSourceRank() - innerDimsPos.size();
@@ -408,26 +456,26 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
     SmallVector<int64_t> newInnerDimsPosVec;
     SmallVector<OpFoldResult> newMixedInnerTilesVec;
 
-    if (!checkAndPermute(transposePermutation, outerDimsPerm,
+    if (!checkAndPermute(inverseTransposePerm, outerDimsPerm,
                          newOuterDimsPermVec, destRank))
       return rewriter.notifyMatchFailure(
           unPackOp,
           "Cannot fold in tensor.unpack if a tile dimension was transposed "
           "with a non-tile dimension in linalg.transpose.");
 
     // Process transpose operation for tiled inner dimensions
-    for (unsigned int i = destRank; i < transposePermutation.size(); ++i) {
-      int64_t remappedPosition = transposePermutation[i] - destRank;
+    for (unsigned int i = destRank; i < inverseTransposePerm.size(); ++i) {
+      int64_t remappedPosition = inverseTransposePerm[i] - destRank;
       newMixedInnerTilesVec.push_back(mixedInnerTilesVec[remappedPosition]);
       newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]);
     }
 
     Value output = unPackOp.createDestinationTensor(
-        rewriter, unPackOp.getLoc(), transposeOp.getOperand(0),
+        rewriter, unPackOp.getLoc(), linalgOp->getOperand(0),
         newMixedInnerTilesVec, newInnerDimsPosVec, newOuterDimsPermVec);
 
     rewriter.replaceOpWithNewOp<UnPackOp>(
-        unPackOp, transposeOp.getOperand(0), output, newInnerDimsPosVec,
+        unPackOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
         newMixedInnerTilesVec, newOuterDimsPermVec);
 
     return success();
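The same inversion shows up in the transpose-before-unpack direction handled here. Worked by hand from the non_involution_transpose_unpack_fold test below (not part of the patch): p = [2, 0, 1, 4, 3], so p^-1 = [1, 2, 0, 4, 3], and destRank = 5 - 2 = 3. The first three entries of p^-1, [1, 2, 0], stay below destRank and, remapped through outer_dims_perm = [0, 2, 1], give the new outer_dims_perm [2, 1, 0]; the last two entries, 4 and 3, map to tile slots 1 and 0, giving inner_dims_pos = [2, 1] and inner_tiles = [4, 16], matching the CHECK lines.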

mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir

Lines changed: 139 additions & 0 deletions
@@ -636,3 +636,142 @@ func.func @tensor_padded_unpack_linalg_transpose_fold(%arg0: tensor<71x7x4x16x16
 // CHECK-SAME: into %[[OUT:.+]] : tensor<71x7x4x16x16xf32> -> tensor<100x71x64xf32>
 // CHECK: return %[[UNPACK]] : tensor<100x71x64xf32>
 // CHECK: }
+
+// -----
+
+func.func @non_involution_transpose_unpack_fold(%arg0: tensor<2x3x5x4x16xi32>) -> tensor<5x48x8xi32> {
+  %0 = tensor.empty() : tensor<5x2x3x16x4xi32>
+  %transposed = linalg.transpose ins(%arg0 : tensor<2x3x5x4x16xi32>)
+      outs(%0 : tensor<5x2x3x16x4xi32>)
+      permutation = [2, 0, 1, 4, 3]
+  %1 = tensor.empty() : tensor<5x48x8xi32>
+  %unpack = tensor.unpack %transposed
+      outer_dims_perm = [0, 2, 1]
+      inner_dims_pos = [1, 2]
+      inner_tiles = [16, 4] into
+      %1 : tensor<5x2x3x16x4xi32> -> tensor<5x48x8xi32>
+  return %unpack : tensor<5x48x8xi32>
+}
+// CHECK-LABEL: func.func @non_involution_transpose_unpack_fold(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x5x4x16xi32>) -> tensor<5x48x8xi32> {
+// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<5x48x8xi32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [2, 1, 0]
+// CHECK-SAME: inner_dims_pos = [2, 1]
+// CHECK-SAME: inner_tiles = [4, 16]
+// CHECK-SAME: into %[[OUT]] : tensor<2x3x5x4x16xi32> -> tensor<5x48x8xi32>
+// CHECK: return %[[UNPACK]] : tensor<5x48x8xi32>
+// CHECK: }
+
+// -----
+
+func.func @unpack_non_involution_transpose_fold(%arg0: tensor<57x3x56x1x64xf32>) -> tensor<3648x3x56xf32> {
+  %0 = tensor.empty() : tensor<3x56x3648xf32>
+  %unpack = tensor.unpack %arg0
+      outer_dims_perm = [2, 0, 1]
+      inner_dims_pos = [1, 2]
+      inner_tiles = [1, 64]
+      into %0 : tensor<57x3x56x1x64xf32> -> tensor<3x56x3648xf32>
+
+  %1 = tensor.empty() : tensor<3648x3x56xf32>
+  %transposed = linalg.transpose
+      ins(%unpack : tensor<3x56x3648xf32>)
+      outs(%1 : tensor<3648x3x56xf32>)
+      permutation = [2, 0, 1]
+  return %transposed : tensor<3648x3x56xf32>
+}
+// CHECK-LABEL: func.func @unpack_non_involution_transpose_fold(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<57x3x56x1x64xf32>) -> tensor<3648x3x56xf32> {
+// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<3648x3x56xf32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [0, 1, 2]
+// CHECK-SAME: inner_dims_pos = [2, 0]
+// CHECK-SAME: inner_tiles = [1, 64]
+// CHECK-SAME: into %[[OUT:.+]] : tensor<57x3x56x1x64xf32> -> tensor<3648x3x56xf32>
+// CHECK: return %[[UNPACK]] : tensor<3648x3x56xf32>
+// CHECK: }
+
+// -----
+
+func.func @transpose_unpacked_dims_no_fold(%arg0: tensor<2x16x5x4x3xi32>) -> tensor<5x32x12xi32> {
+  %0 = tensor.empty() : tensor<5x2x3x16x4xi32>
+  %transposed = linalg.transpose ins(%arg0 : tensor<2x16x5x4x3xi32>)
+      outs(%0 : tensor<5x2x3x16x4xi32>)
+      permutation = [2, 0, 4, 1, 3]
+  %1 = tensor.empty() : tensor<5x32x12xi32>
+  %unpack = tensor.unpack %transposed
+      inner_dims_pos = [1, 2]
+      inner_tiles = [16, 4] into
+      %1 : tensor<5x2x3x16x4xi32> -> tensor<5x32x12xi32>
+  return %unpack : tensor<5x32x12xi32>
+}
+// CHECK-LABEL: func.func @transpose_unpacked_dims_no_fold(
+// CHECK: linalg.transpose
+// CHECK: tensor.unpack
+
+// -----
+
+#map = affine_map<(d0, d1, d2, d3, d4)->(d1, d2, d0, d4, d3)>
+#map1 = affine_map<(d0, d1, d2, d3, d4)->(d0, d1, d2, d3, d4)>
+func.func @generic_transpose_unpack_fold(%arg0: tensor<2x3x5x4x16xi32>) -> tensor<5x48x8xi32> {
+  %0 = tensor.empty() : tensor<5x2x3x16x4xi32>
+  %transposed = linalg.generic {
+      iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"],
+      indexing_maps = [#map, #map1]}
+      ins(%arg0 : tensor<2x3x5x4x16xi32>)
+      outs(%0 : tensor<5x2x3x16x4xi32>) {
+  ^bb0(%in : i32, %out : i32):
+    linalg.yield %in : i32
+  } -> tensor<5x2x3x16x4xi32>
+  %1 = tensor.empty() : tensor<5x48x8xi32>
+  %unpack = tensor.unpack %transposed
+      outer_dims_perm = [0, 2, 1]
+      inner_dims_pos = [1, 2]
+      inner_tiles = [16, 4] into
+      %1 : tensor<5x2x3x16x4xi32> -> tensor<5x48x8xi32>
+  return %unpack : tensor<5x48x8xi32>
+}
+// CHECK-LABEL: func.func @generic_transpose_unpack_fold(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<2x3x5x4x16xi32>) -> tensor<5x48x8xi32> {
+// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<5x48x8xi32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [2, 1, 0]
+// CHECK-SAME: inner_dims_pos = [2, 1]
+// CHECK-SAME: inner_tiles = [4, 16]
+// CHECK-SAME: into %[[OUT]] : tensor<2x3x5x4x16xi32> -> tensor<5x48x8xi32>
+// CHECK: return %[[UNPACK]] : tensor<5x48x8xi32>
+// CHECK: }
+
+// -----
+
+#map = affine_map<(d0, d1, d2)->(d1, d2, d0)>
+#map1 = affine_map<(d0, d1, d2)->(d0, d1, d2)>
+func.func @unpack_generic_transpose_fold(%arg0: tensor<57x3x56x1x64xf32>) -> tensor<3648x3x56xf32> {
+  %0 = tensor.empty() : tensor<3x56x3648xf32>
+  %unpack = tensor.unpack %arg0
+      outer_dims_perm = [2, 0, 1]
+      inner_dims_pos = [1, 2]
+      inner_tiles = [1, 64]
+      into %0 : tensor<57x3x56x1x64xf32> -> tensor<3x56x3648xf32>
+
+  %1 = tensor.empty() : tensor<3648x3x56xf32>
+  %transposed = linalg.generic {
+      iterator_types = ["parallel", "parallel", "parallel"],
+      indexing_maps = [#map, #map1]}
+      ins(%unpack : tensor<3x56x3648xf32>)
+      outs(%1 : tensor<3648x3x56xf32>) {
+  ^bb0(%in : f32, %out : f32):
+    linalg.yield %in : f32
+  } -> tensor<3648x3x56xf32>
+  return %transposed : tensor<3648x3x56xf32>
+}
+// CHECK-LABEL: func.func @unpack_generic_transpose_fold(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<57x3x56x1x64xf32>) -> tensor<3648x3x56xf32> {
+// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<3648x3x56xf32>
+// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [0, 1, 2]
+// CHECK-SAME: inner_dims_pos = [2, 0]
+// CHECK-SAME: inner_tiles = [1, 64]
+// CHECK-SAME: into %[[OUT:.+]] : tensor<57x3x56x1x64xf32> -> tensor<3648x3x56xf32>
+// CHECK: return %[[UNPACK]] : tensor<3648x3x56xf32>
+// CHECK: }
