Skip to content

Commit c1667f9

Browse files
authored
Fix transpose->unpack folding pattern for the partial-tile case of unpack (#107271)
Just directly create the empty tensor of appropriate shape instead of relying on `UnPackOp::createDestinationTensor` which is trying to infer the destination shape, which isn't possible in general with the set of paramters that it is taking. Signed-off-by: Benoit Jacob <[email protected]>
1 parent a43137c commit c1667f9

File tree

2 files changed

+38
-10
lines changed

2 files changed

+38
-10
lines changed

mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,11 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
439439
if (failed(maybePerm))
440440
return failure();
441441

442+
SmallVector<SmallVector<OpFoldResult>> unpackOpResultDims;
443+
if (failed(reifyResultShapes(rewriter, unPackOp, unpackOpResultDims))) {
444+
return failure();
445+
}
446+
442447
SmallVector<int64_t> inverseTransposePerm =
443448
invertPermutationVector(maybePerm.value());
444449
auto outerDimsPerm = unPackOp.getOuterDimsPerm();
@@ -448,7 +453,6 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
448453
SmallVector<int64_t> newOuterDimsPermVec;
449454
SmallVector<int64_t> newInnerDimsPosVec;
450455
SmallVector<OpFoldResult> newMixedInnerTilesVec;
451-
452456
if (!checkAndPermute(inverseTransposePerm, outerDimsPerm,
453457
newOuterDimsPermVec, destRank))
454458
return rewriter.notifyMatchFailure(
@@ -463,9 +467,10 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
463467
newInnerDimsPosVec.push_back(innerDimsPos[remappedPosition]);
464468
}
465469

466-
Value output = unPackOp.createDestinationTensor(
467-
rewriter, unPackOp.getLoc(), linalgOp->getOperand(0),
468-
newMixedInnerTilesVec, newInnerDimsPosVec, newOuterDimsPermVec);
470+
auto elemType =
471+
cast<ShapedType>(unPackOp->getResultTypes()[0]).getElementType();
472+
Value output = rewriter.create<tensor::EmptyOp>(
473+
unPackOp->getLoc(), unpackOpResultDims[0], elemType);
469474

470475
rewriter.replaceOpWithNewOp<UnPackOp>(
471476
unPackOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,

mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -550,6 +550,32 @@ func.func @linalg_transpose_tensor_unpack_fold(%arg0: tensor<1x1x4x16xi32>) -> t
550550

551551
// -----
552552

553+
func.func @linalg_transpose_tensor_unpack_fold_partial_tile(%arg0: tensor<1x1x4x16xi32>) -> tensor<15x3xi32> {
554+
%0 = tensor.empty() : tensor<1x1x16x4xi32>
555+
%transposed = linalg.transpose ins(%arg0 : tensor<1x1x4x16xi32>)
556+
outs(%0 : tensor<1x1x16x4xi32>)
557+
permutation = [1, 0, 3, 2]
558+
%1 = tensor.empty() : tensor<15x3xi32>
559+
%unpack = tensor.unpack %transposed
560+
outer_dims_perm = [0, 1]
561+
inner_dims_pos = [0, 1]
562+
inner_tiles = [16, 4] into
563+
%1 : tensor<1x1x16x4xi32> -> tensor<15x3xi32>
564+
return %unpack : tensor<15x3xi32>
565+
}
566+
//CHECK-LABEL: func.func @linalg_transpose_tensor_unpack_fold_partial_tile(
567+
// CHECK-SAME: %[[ARG0:.+]]: tensor<1x1x4x16xi32>) -> tensor<15x3xi32> {
568+
// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<15x3xi32>
569+
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
570+
// CHECK-SAME: outer_dims_perm = [1, 0]
571+
// CHECK-SAME: inner_dims_pos = [1, 0]
572+
// CHECK-SAME: inner_tiles = [4, 16]
573+
// CHECK-SAME: into %[[OUT]] : tensor<1x1x4x16xi32> -> tensor<15x3xi32>
574+
// CHECK: return %[[UNPACK]] : tensor<15x3xi32>
575+
// CHECK: }
576+
577+
// -----
578+
553579
func.func @linalg_transpose_tensor_unpack_fold_dynamic_outer_dims_tile_dims_tile_sizes(%arg0: tensor<?x?x?x?xf32>, %transpose_dest: tensor<?x?x?x?xf32>, %unpack_dest: tensor<?x?xf32>, %tile_p : index, %tile_q : index) -> tensor<?x?xf32> {
554580
%transposed = linalg.transpose
555581
ins(%arg0 : tensor<?x?x?x?xf32>)
@@ -563,17 +589,14 @@ func.func @linalg_transpose_tensor_unpack_fold_dynamic_outer_dims_tile_dims_tile
563589
into %unpack_dest : tensor<?x?x?x?xf32> -> tensor<?x?xf32>
564590
return %unpack : tensor<?x?xf32>
565591
}
566-
// CHECK: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
567592
// CHECK-LABEL: func.func @linalg_transpose_tensor_unpack_fold_dynamic_outer_dims_tile_dims_tile_sizes(
568593
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?xf32>, %[[ARG1:.+]]: tensor<?x?x?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>,
569594
// CHECK-SAME: %[[IDX1:.+]]: index, %[[IDX2:.+]]: index) -> tensor<?x?xf32> {
570595
// CHECK-DAG: %[[CST1:.+]] = arith.constant 1 : index
571596
// CHECK-DAG: %[[CST0:.+]] = arith.constant 0 : index
572-
// CHECK-DAG: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[CST0]] : tensor<?x?x?x?xf32>
573-
// CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[CST1]] : tensor<?x?x?x?xf32>
574-
// CHECK-DAG: %[[AMAP0:.+]] = affine.apply #[[$MAP]]()[%[[DIM1]], %[[IDX2]]]
575-
// CHECK-DAG: %[[AMAP1:.+]] = affine.apply #[[$MAP]]()[%[[DIM0]], %[[IDX1]]]
576-
// CHECK: %[[OUT:.+]] = tensor.empty(%[[AMAP1]], %[[AMAP0]]) : tensor<?x?xf32>
597+
// CHECK-DAG: %[[DIM0:.+]] = tensor.dim %[[ARG2]], %[[CST0]] : tensor<?x?xf32>
598+
// CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[ARG2]], %[[CST1]] : tensor<?x?xf32>
599+
// CHECK: %[[OUT:.+]] = tensor.empty(%[[DIM0]], %[[DIM1]]) : tensor<?x?xf32>
577600
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
578601
// CHECK-SAME: outer_dims_perm = [0, 1]
579602
// CHECK-SAME: inner_dims_pos = [1, 0]

0 commit comments

Comments
 (0)