
Commit d2908ca

Revert "[mlir] Add pack/unpack transpose foldings for linalg.generic ops, fix bugs (llvm#93055)"
This reverts commit 7ef83f5.
1 parent 144ebdd commit d2908ca
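
For context, the reverted patch had taught the pack/unpack transpose folding patterns to also recognize transpose-like linalg.generic ops (all-parallel iterators, permutation indexing maps, and a body that only yields its input), not just the named linalg.transpose op. A minimal sketch of such a transpose-like generic op, with illustrative shapes that are not taken from the patch:

#in  = affine_map<(d0, d1) -> (d1, d0)>
#out = affine_map<(d0, d1) -> (d0, d1)>
// Both indexing maps are permutations and the body only yields the input
// element, so this generic op is just a transpose of %arg0.
func.func @transpose_like(%arg0: tensor<8x16xf32>) -> tensor<16x8xf32> {
  %empty = tensor.empty() : tensor<16x8xf32>
  %res = linalg.generic {
      indexing_maps = [#in, #out],
      iterator_types = ["parallel", "parallel"]}
      ins(%arg0 : tensor<8x16xf32>)
      outs(%empty : tensor<16x8xf32>) {
    ^bb0(%in: f32, %out: f32):
      linalg.yield %in : f32
  } -> tensor<16x8xf32>
  return %res : tensor<16x8xf32>
}

After this revert, only linalg.transpose producers and consumers of tensor.pack and tensor.unpack are folded again, as the pattern changes and test deletions below show.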

2 files changed: +41, -221 lines

mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp

Lines changed: 41 additions & 82 deletions
@@ -48,34 +48,6 @@ static LogicalResult isPackOn1D(RewriterBase &rewriter, Operation *op,
   return success();
 }
 
-// If the `linalgOp` represents a transpose, return the permutation vector for
-// the transpose. Otherwise, return failure.
-static FailureOr<SmallVector<int64_t>>
-getTransposeOpPermutation(linalg::LinalgOp linalgOp) {
-  if (auto transposeOp = dyn_cast<linalg::TransposeOp>(linalgOp.getOperation()))
-    return SmallVector<int64_t>(transposeOp.getPermutation());
-  if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops())
-    return failure();
-
-  if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
-    return failure();
-  auto mapRange = linalgOp.getIndexingMapsArray();
-  if (!mapRange.front().isPermutation() || !mapRange.back().isPermutation() ||
-      mapRange.front() == mapRange.back()) {
-    return failure();
-  }
-  if (!llvm::hasSingleElement(linalgOp.getBlock()->getOperations()))
-    return failure();
-  AffineMap outMap = mapRange.back();
-  AffineMap inMap = mapRange.front();
-  // To get the permutation, look at each output index and find which
-  // dimension in the input we're reading from for that index.
-  return llvm::map_to_vector(outMap.getResults(),
-                             [&](AffineExpr expr) -> int64_t {
-                               return *inMap.getResultPosition(expr);
-                             });
-}
-
 /// Packing one-dimensional tensor can be expressed as an expand shape op.
 struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
   using OpRewritePattern<PackOp>::OpRewritePattern;
@@ -274,10 +246,14 @@ static bool checkAndPermute(ArrayRef<int64_t> permutation,
 
   for (unsigned int i = 0; i < rank; ++i) {
     int64_t remappedPosition = permutation[i];
-    if (remappedPosition >= rank)
-      return false;
-    if (!inVec.empty())
+
+    if (!inVec.empty()) {
+      if (remappedPosition >= rank) {
+        return false;
+      }
       remappedPosition = inVec[remappedPosition];
+    }
+
     resVec.push_back(remappedPosition);
   }
 
@@ -287,25 +263,20 @@ static bool checkAndPermute(ArrayRef<int64_t> permutation,
 /// Fold 'pack' -> 'transpose' into 'pack' since 'pack' already has transpose
 /// semantics.
 struct FoldProducerPackWithConsumerLinalgTransposeOp
-    : public OpInterfaceRewritePattern<linalg::LinalgOp> {
-  using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
+    : public OpRewritePattern<linalg::TransposeOp> {
+  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
+  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
                                 PatternRewriter &rewriter) const override {
-    auto packOp = linalgOp->getOperand(0).getDefiningOp<PackOp>();
+    auto packOp = transposeOp.getOperand(0).getDefiningOp<PackOp>();
 
     if (!packOp)
       return failure();
 
-    FailureOr<SmallVector<int64_t>> maybePerm =
-        getTransposeOpPermutation(linalgOp);
-    if (failed(maybePerm))
-      return failure();
-
     auto innerDimsPos = packOp.getInnerDimsPos();
     auto mixedInnerTiles = packOp.getMixedTiles();
     auto outerDimsPerm = packOp.getOuterDimsPerm();
-    auto transposePerm = maybePerm.value();
+    auto transposePerm = transposeOp.getPermutation();
     SmallVector<int64_t> newOuterDimsPermVec;
     SmallVector<int64_t> newInnerDimsPosVec;
     SmallVector<OpFoldResult> newMixedInnerTilesVec;
@@ -314,7 +285,7 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
     if (!checkAndPermute(transposePerm, outerDimsPerm, newOuterDimsPermVec,
                          srcRank))
       return rewriter.notifyMatchFailure(
-          linalgOp,
+          transposeOp,
           "Cannot fold in tensor.pack if a tile dimension was transposed "
           "with a non-tile dimension in linalg.transpose.");
 
@@ -326,11 +297,11 @@ struct FoldProducerPackWithConsumerLinalgTransposeOp
     }
 
     Value output = packOp.createDestinationTensor(
-        rewriter, linalgOp.getLoc(), packOp.getSource(), newMixedInnerTilesVec,
-        newInnerDimsPosVec, newOuterDimsPermVec);
+        rewriter, transposeOp.getLoc(), packOp.getSource(),
+        newMixedInnerTilesVec, newInnerDimsPosVec, newOuterDimsPermVec);
 
     rewriter.replaceOpWithNewOp<PackOp>(
-        linalgOp, packOp.getSource(), output, newInnerDimsPosVec,
+        transposeOp, packOp.getSource(), output, newInnerDimsPosVec,
         newMixedInnerTilesVec, packOp.getPaddingValue(), newOuterDimsPermVec);
 
     return success();
@@ -345,16 +316,12 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
 
   LogicalResult matchAndRewrite(PackOp packOp,
                                 PatternRewriter &rewriter) const override {
-    auto linalgOp = packOp.getSource().getDefiningOp<linalg::LinalgOp>();
-    if (!linalgOp)
-      return failure();
+    auto transposeOp = packOp.getSource().getDefiningOp<linalg::TransposeOp>();
 
-    FailureOr<SmallVector<int64_t>> maybePerm =
-        getTransposeOpPermutation(linalgOp);
-    if (failed(maybePerm))
+    if (!transposeOp)
       return failure();
 
-    auto transposePermutation = maybePerm.value();
+    auto transposePermutation = transposeOp.getPermutation();
     auto outerDimsPerm = packOp.getOuterDimsPerm();
     auto innerDimsPos = packOp.getInnerDimsPos();
     SmallVector<int64_t> newInnerDimsPosVec;
@@ -370,11 +337,11 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
       newInnerDimsPosVec.push_back(transposePermutation[dim]);
 
     Value output = packOp.createDestinationTensor(
-        rewriter, packOp.getLoc(), linalgOp->getOperand(0),
+        rewriter, packOp.getLoc(), transposeOp.getOperand(0),
         packOp.getMixedTiles(), newInnerDimsPosVec, newOuterDimsPermVec);
 
     rewriter.replaceOpWithNewOp<PackOp>(
-        packOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
+        packOp, transposeOp.getOperand(0), output, newInnerDimsPosVec,
         packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPermVec);
 
     return success();
@@ -384,38 +351,34 @@ struct FoldConsumerPackWithProducerLinalgTransposeOp
 /// Fold 'unpack' -> 'transpose' into 'unpack' since 'unpack' already has
 /// transpose semantics.
 struct FoldProducerUnPackWithConsumerLinalgTransposeOp
-    : public OpInterfaceRewritePattern<linalg::LinalgOp> {
-  using OpInterfaceRewritePattern<linalg::LinalgOp>::OpInterfaceRewritePattern;
+    : public OpRewritePattern<linalg::TransposeOp> {
+  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
+  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
                                 PatternRewriter &rewriter) const override {
-    auto unPackOp = linalgOp->getOperand(0).getDefiningOp<UnPackOp>();
+    auto unPackOp = transposeOp.getOperand(0).getDefiningOp<UnPackOp>();
 
     if (!unPackOp)
      return failure();
 
-    FailureOr<SmallVector<int64_t>> maybePerm =
-        getTransposeOpPermutation(linalgOp);
-    if (failed(maybePerm))
-      return failure();
-
+    auto transposePermutation = transposeOp.getPermutation();
     auto outerDimsPerm = unPackOp.getOuterDimsPerm();
     auto innerDimsPos = unPackOp.getInnerDimsPos();
     SmallVector<int64_t> newInnerDimsPosVec;
     SmallVector<int64_t> newOuterDimsPermVec =
-        invertPermutationVector(maybePerm.value());
+        llvm::to_vector(transposePermutation);
+
+    if (!outerDimsPerm.empty())
+      applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);
 
     // Can't use applyPermutationToVector for newInnerDimsPosVec since input and
     // permutation rank won't necessarily be equal in all cases.
     for (auto dim : innerDimsPos)
-      newInnerDimsPosVec.push_back(newOuterDimsPermVec[dim]);
-
-    if (!outerDimsPerm.empty())
-      applyPermutationToVector(newOuterDimsPermVec, outerDimsPerm);
+      newInnerDimsPosVec.push_back(transposePermutation[dim]);
 
     // Reuse the destination of the transpose op.
     rewriter.replaceOpWithNewOp<UnPackOp>(
-        linalgOp, unPackOp.getSource(), linalgOp.getDpsInits()[0],
+        transposeOp, unPackOp.getSource(), transposeOp.getDpsInits()[0],
         newInnerDimsPosVec, unPackOp.getMixedTiles(), newOuterDimsPermVec);
 
     return success();
@@ -430,17 +393,13 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
 
   LogicalResult matchAndRewrite(UnPackOp unPackOp,
                                 PatternRewriter &rewriter) const override {
-    auto linalgOp = unPackOp.getSource().getDefiningOp<linalg::LinalgOp>();
-    if (!linalgOp)
-      return failure();
+    auto transposeOp =
+        unPackOp.getSource().getDefiningOp<linalg::TransposeOp>();
 
-    FailureOr<SmallVector<int64_t>> maybePerm =
-        getTransposeOpPermutation(linalgOp);
-    if (failed(maybePerm))
+    if (!transposeOp)
       return failure();
 
-    SmallVector<int64_t> inverseTransposePerm =
-        invertPermutationVector(maybePerm.value());
+    auto transposePermutation = transposeOp.getPermutation();
     auto outerDimsPerm = unPackOp.getOuterDimsPerm();
     auto innerDimsPos = unPackOp.getInnerDimsPos();
     int64_t destRank = unPackOp.getSourceRank() - innerDimsPos.size();
@@ -449,26 +408,26 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
     SmallVector<int64_t> newInnerDimsPosVec;
     SmallVector<OpFoldResult> newMixedInnerTilesVec;
 
-    if (!checkAndPermute(inverseTransposePerm, outerDimsPerm,
+    if (!checkAndPermute(transposePermutation, outerDimsPerm,
                          newOuterDimsPermVec, destRank))
       return rewriter.notifyMatchFailure(
           unPackOp,
           "Cannot fold in tensor.unpack if a tile dimension was transposed "
          "with a non-tile dimension in linalg.transpose.");
 
     // Process transpose operation for tiled inner dimensions
-    for (unsigned int i = destRank; i < inverseTransposePerm.size(); ++i) {
-      int64_t remappedPosition = inverseTransposePerm[i] - destRank;
+    for (unsigned int i = destRank; i < transposePermutation.size(); ++i) {
+      int64_t remappedPosition = transposePermutation[i] - destRank;
       newMixedInnerTilesVec.push_back(mixedInnerTilesVec[remappedPosition]);
       newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]);
     }
 
     Value output = unPackOp.createDestinationTensor(
-        rewriter, unPackOp.getLoc(), linalgOp->getOperand(0),
+        rewriter, unPackOp.getLoc(), transposeOp.getOperand(0),
         newMixedInnerTilesVec, newInnerDimsPosVec, newOuterDimsPermVec);
 
     rewriter.replaceOpWithNewOp<UnPackOp>(
-        unPackOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,
+        unPackOp, transposeOp.getOperand(0), output, newInnerDimsPosVec,
         newMixedInnerTilesVec, newOuterDimsPermVec);
 
     return success();
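
For reference, the kind of IR these patterns continue to fold after the revert pairs a tensor.pack (or tensor.unpack) with a named linalg.transpose. A minimal sketch, with hypothetical shapes and names not taken from the test suite:

// tensor.pack followed by linalg.transpose; the pattern folds the
// permutation into the pack's outer_dims_perm instead of keeping a
// separate transpose op.
func.func @pack_then_transpose(%src: tensor<32x64xf32>) -> tensor<4x2x16x16xf32> {
  %dest = tensor.empty() : tensor<2x4x16x16xf32>
  %pack = tensor.pack %src
      inner_dims_pos = [0, 1]
      inner_tiles = [16, 16]
      into %dest : tensor<32x64xf32> -> tensor<2x4x16x16xf32>
  %init = tensor.empty() : tensor<4x2x16x16xf32>
  %t = linalg.transpose ins(%pack : tensor<2x4x16x16xf32>)
                        outs(%init : tensor<4x2x16x16xf32>)
                        permutation = [1, 0, 2, 3]
  return %t : tensor<4x2x16x16xf32>
}

Because the permutation keeps the tile dimensions in place and only swaps the outer dimensions, checkAndPermute should succeed and the pair can rewrite into a single tensor.pack with outer_dims_perm = [1, 0].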

mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir

Lines changed: 0 additions & 139 deletions
@@ -636,142 +636,3 @@ func.func @tensor_padded_unpack_linalg_transpose_fold(%arg0: tensor<71x7x4x16x16
 // CHECK-SAME:   into %[[OUT:.+]] : tensor<71x7x4x16x16xf32> -> tensor<100x71x64xf32>
 // CHECK:        return %[[UNPACK]] : tensor<100x71x64xf32>
 // CHECK:      }
-
-// -----
-
-func.func @non_involution_transpose_unpack_fold(%arg0: tensor<2x3x5x4x16xi32>) -> tensor<5x48x8xi32> {
-  %0 = tensor.empty() : tensor<5x2x3x16x4xi32>
-  %transposed = linalg.transpose ins(%arg0 : tensor<2x3x5x4x16xi32>)
-                outs(%0 : tensor<5x2x3x16x4xi32>)
-                permutation = [2, 0, 1, 4, 3]
-  %1 = tensor.empty() : tensor<5x48x8xi32>
-  %unpack = tensor.unpack %transposed
-            outer_dims_perm = [0, 2, 1]
-            inner_dims_pos = [1, 2]
-            inner_tiles = [16, 4] into
-            %1 : tensor<5x2x3x16x4xi32> -> tensor<5x48x8xi32>
-  return %unpack : tensor<5x48x8xi32>
-}
-//CHECK-LABEL:  func.func @non_involution_transpose_unpack_fold(
-// CHECK-SAME:   %[[ARG0:.+]]: tensor<2x3x5x4x16xi32>) -> tensor<5x48x8xi32> {
-// CHECK:        %[[OUT:.+]] = tensor.empty() : tensor<5x48x8xi32>
-// CHECK:        %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
-// CHECK-SAME:   outer_dims_perm = [2, 1, 0]
-// CHECK-SAME:   inner_dims_pos = [2, 1]
-// CHECK-SAME:   inner_tiles = [4, 16]
-// CHEKC-SAME:   into %[[OUT]] : tensor<2x3x5x4x16xi32> -> tensor<5x48x8xi32>
-// CHECK:        return %[[UNPACK]] : tensor<5x48x8xi32>
-// CHECK:      }
-
-// -----
-
-func.func @unpack_non_involution_transpose_fold(%arg0: tensor<57x3x56x1x64xf32>) -> tensor<3648x3x56xf32> {
-  %0 = tensor.empty() : tensor<3x56x3648xf32>
-  %unpack = tensor.unpack %arg0
-            outer_dims_perm = [2, 0, 1]
-            inner_dims_pos = [1, 2]
-            inner_tiles = [1, 64]
-            into %0 : tensor<57x3x56x1x64xf32> -> tensor<3x56x3648xf32>
-
-  %1 = tensor.empty() : tensor<3648x3x56xf32>
-  %transposed = linalg.transpose
-                ins(%unpack : tensor<3x56x3648xf32>)
-                outs(%1 : tensor<3648x3x56xf32>)
-                permutation = [2, 0, 1]
-  return %transposed : tensor<3648x3x56xf32>
-}
-// CHECK-LABEL: func.func @unpack_non_involution_transpose_fold(
-// CHECK-SAME:   %[[ARG0:.+]]: tensor<57x3x56x1x64xf32>) -> tensor<3648x3x56xf32> {
-// CHECK:        %[[OUT:.+]] = tensor.empty() : tensor<3648x3x56xf32>
-// CHECK:        %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
-// CHECK-SAME:   outer_dims_perm = [0, 1, 2]
-// CHECK-SAME:   inner_dims_pos = [2, 0]
-// CHECK-SAME:   inner_tiles = [1, 64]
-// CHECK-SAME:   into %[[OUT:.+]] : tensor<57x3x56x1x64xf32> -> tensor<3648x3x56xf32>
-// CHECK:        return %[[UNPACK]] : tensor<3648x3x56xf32>
-// CHECK:      }
-
-// -----
-
-func.func @transpose_unpacked_dims_no_fold(%arg0: tensor<2x16x5x4x3xi32>) -> tensor<5x32x12xi32> {
-  %0 = tensor.empty() : tensor<5x2x3x16x4xi32>
-  %transposed = linalg.transpose ins(%arg0 : tensor<2x16x5x4x3xi32>)
-                outs(%0 : tensor<5x2x3x16x4xi32>)
-                permutation = [2, 0, 4, 1, 3]
-  %1 = tensor.empty() : tensor<5x32x12xi32>
-  %unpack = tensor.unpack %transposed
-            inner_dims_pos = [1, 2]
-            inner_tiles = [16, 4] into
-            %1 : tensor<5x2x3x16x4xi32> -> tensor<5x32x12xi32>
-  return %unpack : tensor<5x32x12xi32>
-}
-//CHECK-LABEL:  func.func @transpose_unpacked_dims_no_fold(
-// CHECK:        linalg.transpose
-// CHECK:        tensor.unpack
-
-// -----
-
-#map = affine_map<(d0, d1, d2, d3, d4)->(d1, d2, d0, d4, d3)>
-#map1 = affine_map<(d0, d1, d2, d3, d4)->(d0, d1, d2, d3, d4)>
-func.func @generic_transpose_unpack_fold(%arg0: tensor<2x3x5x4x16xi32>) -> tensor<5x48x8xi32> {
-  %0 = tensor.empty() : tensor<5x2x3x16x4xi32>
-  %transposed = linalg.generic {
-      iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"],
-      indexing_maps = [#map, #map1]}
-      ins(%arg0 : tensor<2x3x5x4x16xi32>)
-      outs(%0 : tensor<5x2x3x16x4xi32>) {
-    ^bb0(%in : i32, %out : i32):
-      linalg.yield %in : i32
-  } -> tensor<5x2x3x16x4xi32>
-  %1 = tensor.empty() : tensor<5x48x8xi32>
-  %unpack = tensor.unpack %transposed
-            outer_dims_perm = [0, 2, 1]
-            inner_dims_pos = [1, 2]
-            inner_tiles = [16, 4] into
-            %1 : tensor<5x2x3x16x4xi32> -> tensor<5x48x8xi32>
-  return %unpack : tensor<5x48x8xi32>
-}
-//CHECK-LABEL:  func.func @generic_transpose_unpack_fold(
-// CHECK-SAME:   %[[ARG0:.+]]: tensor<2x3x5x4x16xi32>) -> tensor<5x48x8xi32> {
-// CHECK:        %[[OUT:.+]] = tensor.empty() : tensor<5x48x8xi32>
-// CHECK:        %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
-// CHECK-SAME:   outer_dims_perm = [2, 1, 0]
-// CHECK-SAME:   inner_dims_pos = [2, 1]
-// CHECK-SAME:   inner_tiles = [4, 16]
-// CHEKC-SAME:   into %[[OUT]] : tensor<2x3x5x4x16xi32> -> tensor<5x48x8xi32>
-// CHECK:        return %[[UNPACK]] : tensor<5x48x8xi32>
-// CHECK:      }
-
-// -----
-
-#map = affine_map<(d0, d1, d2)->(d1, d2, d0)>
-#map1 = affine_map<(d0, d1, d2)->(d0, d1, d2)>
-func.func @unpack_generic_transpose_fold(%arg0: tensor<57x3x56x1x64xf32>) -> tensor<3648x3x56xf32> {
-  %0 = tensor.empty() : tensor<3x56x3648xf32>
-  %unpack = tensor.unpack %arg0
-            outer_dims_perm = [2, 0, 1]
-            inner_dims_pos = [1, 2]
-            inner_tiles = [1, 64]
-            into %0 : tensor<57x3x56x1x64xf32> -> tensor<3x56x3648xf32>
-
-  %1 = tensor.empty() : tensor<3648x3x56xf32>
-  %transposed = linalg.generic {
-      iterator_types = ["parallel", "parallel", "parallel"],
-      indexing_maps = [#map, #map1]}
-      ins(%unpack : tensor<3x56x3648xf32>)
-      outs(%1 : tensor<3648x3x56xf32>) {
-    ^bb0(%in : f32, %out : f32):
-      linalg.yield %in : f32
-  } -> tensor<3648x3x56xf32>
-  return %transposed : tensor<3648x3x56xf32>
-}
-// CHECK-LABEL: func.func @unpack_generic_transpose_fold(
-// CHECK-SAME:   %[[ARG0:.+]]: tensor<57x3x56x1x64xf32>) -> tensor<3648x3x56xf32> {
-// CHECK:        %[[OUT:.+]] = tensor.empty() : tensor<3648x3x56xf32>
-// CHECK:        %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
-// CHECK-SAME:   outer_dims_perm = [0, 1, 2]
-// CHECK-SAME:   inner_dims_pos = [2, 0]
-// CHECK-SAME:   inner_tiles = [1, 64]
-// CHECK-SAME:   into %[[OUT:.+]] : tensor<57x3x56x1x64xf32> -> tensor<3648x3x56xf32>
-// CHECK:        return %[[UNPACK]] : tensor<3648x3x56xf32>
-// CHECK:      }
