Skip to content

Commit c80fb59

Browse files
committed
fold-transpose-unpack-partial-tile
1 parent 26ba186 commit c80fb59

File tree

4 files changed

+48
-15
lines changed

4 files changed

+48
-15
lines changed

mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2076,7 +2076,8 @@ def Tensor_UnPackOp : Tensor_RelayoutOp<"unpack"> {
20762076
let extraClassDeclaration = commonExtraClassDeclaration # [{
20772077
static Value createDestinationTensor(OpBuilder &b, Location loc,
20782078
Value source, ArrayRef<OpFoldResult> innerTileSizes,
2079-
ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm);
2079+
ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm,
2080+
SmallVector<OpFoldResult> mixedSizes = {});
20802081

20812082
/// Build and return a new UnPackOp that is a clone of the current UnPackOp
20822083
/// whose (innerDimsPos, innerTiles) (resp. outerDimsPerm) are permuted by

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4360,15 +4360,19 @@ Value UnPackOp::createDestinationTensor(OpBuilder &b, Location loc,
43604360
Value source,
43614361
ArrayRef<OpFoldResult> innerTileSizes,
43624362
ArrayRef<int64_t> innerDimsPos,
4363-
ArrayRef<int64_t> outerDimsPerm) {
4363+
ArrayRef<int64_t> outerDimsPerm,
4364+
SmallVector<OpFoldResult> mixedSizes) {
4365+
auto srcType = llvm::cast<RankedTensorType>(source.getType());
4366+
auto elemType = srcType.getElementType();
4367+
if (!mixedSizes.empty()) {
4368+
return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
4369+
}
4370+
43644371
AffineExpr sym0, sym1;
43654372
bindSymbols(b.getContext(), sym0, sym1);
43664373
auto dimMul = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
43674374
return affine::makeComposedFoldedAffineApply(b, loc, sym0 * sym1, {v1, v2});
43684375
};
4369-
4370-
SmallVector<OpFoldResult> mixedSizes;
4371-
auto srcType = llvm::cast<RankedTensorType>(source.getType());
43724376
for (auto i :
43734377
llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) {
43744378
if (srcType.isDynamicDim(i))
@@ -4384,7 +4388,6 @@ Value UnPackOp::createDestinationTensor(OpBuilder &b, Location loc,
43844388
for (auto [dimPos, tileSize] : llvm::zip_equal(innerDimsPos, innerTileSizes))
43854389
mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize);
43864390

4387-
auto elemType = srcType.getElementType();
43884391
return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
43894392
}
43904393

mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,11 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
439439
if (failed(maybePerm))
440440
return failure();
441441

442+
SmallVector<SmallVector<OpFoldResult>> unpackOpResultDims;
443+
if (failed(reifyResultShapes(rewriter, unPackOp, unpackOpResultDims))) {
444+
return failure();
445+
}
446+
442447
SmallVector<int64_t> inverseTransposePerm =
443448
invertPermutationVector(maybePerm.value());
444449
auto outerDimsPerm = unPackOp.getOuterDimsPerm();
@@ -448,13 +453,13 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
448453
SmallVector<int64_t> newOuterDimsPermVec;
449454
SmallVector<int64_t> newInnerDimsPosVec;
450455
SmallVector<OpFoldResult> newMixedInnerTilesVec;
451-
452456
if (!checkAndPermute(inverseTransposePerm, outerDimsPerm,
453-
newOuterDimsPermVec, destRank))
457+
newOuterDimsPermVec, destRank)) {
454458
return rewriter.notifyMatchFailure(
455459
unPackOp,
456460
"Cannot fold in tensor.unpack if a tile dimension was transposed "
457461
"with a non-tile dimension in linalg.transpose.");
462+
}
458463

459464
// Process transpose operation for tiled inner dimensions
460465
for (unsigned int i = destRank; i < inverseTransposePerm.size(); ++i) {
@@ -465,7 +470,8 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp
465470

466471
Value output = unPackOp.createDestinationTensor(
467472
rewriter, unPackOp.getLoc(), linalgOp->getOperand(0),
468-
newMixedInnerTilesVec, newInnerDimsPosVec, newOuterDimsPermVec);
473+
newMixedInnerTilesVec, newInnerDimsPosVec, newOuterDimsPermVec,
474+
unpackOpResultDims[0]);
469475

470476
rewriter.replaceOpWithNewOp<UnPackOp>(
471477
unPackOp, linalgOp->getOperand(0), output, newInnerDimsPosVec,

mlir/test/Dialect/Tensor/fold-into-pack-and-unpack.mlir

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -550,6 +550,32 @@ func.func @linalg_transpose_tensor_unpack_fold(%arg0: tensor<1x1x4x16xi32>) -> t
550550

551551
// -----
552552

553+
func.func @linalg_transpose_tensor_unpack_fold_partial_tile(%arg0: tensor<1x1x4x16xi32>) -> tensor<15x3xi32> {
554+
%0 = tensor.empty() : tensor<1x1x16x4xi32>
555+
%transposed = linalg.transpose ins(%arg0 : tensor<1x1x4x16xi32>)
556+
outs(%0 : tensor<1x1x16x4xi32>)
557+
permutation = [1, 0, 3, 2]
558+
%1 = tensor.empty() : tensor<15x3xi32>
559+
%unpack = tensor.unpack %transposed
560+
outer_dims_perm = [0, 1]
561+
inner_dims_pos = [0, 1]
562+
inner_tiles = [16, 4] into
563+
%1 : tensor<1x1x16x4xi32> -> tensor<15x3xi32>
564+
return %unpack : tensor<15x3xi32>
565+
}
566+
// CHECK-LABEL: func.func @linalg_transpose_tensor_unpack_fold_partial_tile(
567+
// CHECK-SAME: %[[ARG0:.+]]: tensor<1x1x4x16xi32>) -> tensor<15x3xi32> {
568+
// CHECK: %[[OUT:.+]] = tensor.empty() : tensor<15x3xi32>
569+
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
570+
// CHECK-SAME: outer_dims_perm = [1, 0]
571+
// CHECK-SAME: inner_dims_pos = [1, 0]
572+
// CHECK-SAME: inner_tiles = [4, 16]
573+
// CHECK-SAME: into %[[OUT]] : tensor<1x1x4x16xi32> -> tensor<15x3xi32>
574+
// CHECK: return %[[UNPACK]] : tensor<15x3xi32>
575+
// CHECK: }
576+
577+
// -----
578+
553579
func.func @linalg_transpose_tensor_unpack_fold_dynamic_outer_dims_tile_dims_tile_sizes(%arg0: tensor<?x?x?x?xf32>, %transpose_dest: tensor<?x?x?x?xf32>, %unpack_dest: tensor<?x?xf32>, %tile_p : index, %tile_q : index) -> tensor<?x?xf32> {
554580
%transposed = linalg.transpose
555581
ins(%arg0 : tensor<?x?x?x?xf32>)
@@ -563,17 +589,14 @@ func.func @linalg_transpose_tensor_unpack_fold_dynamic_outer_dims_tile_dims_tile
563589
into %unpack_dest : tensor<?x?x?x?xf32> -> tensor<?x?xf32>
564590
return %unpack : tensor<?x?xf32>
565591
}
566-
// CHECK: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
567592
// CHECK-LABEL: func.func @linalg_transpose_tensor_unpack_fold_dynamic_outer_dims_tile_dims_tile_sizes(
568593
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?x?xf32>, %[[ARG1:.+]]: tensor<?x?x?x?xf32>, %[[ARG2:.+]]: tensor<?x?xf32>,
569594
// CHECK-SAME: %[[IDX1:.+]]: index, %[[IDX2:.+]]: index) -> tensor<?x?xf32> {
570595
// CHECK-DAG: %[[CST1:.+]] = arith.constant 1 : index
571596
// CHECK-DAG: %[[CST0:.+]] = arith.constant 0 : index
572-
// CHECK-DAG: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[CST0]] : tensor<?x?x?x?xf32>
573-
// CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[CST1]] : tensor<?x?x?x?xf32>
574-
// CHECK-DAG: %[[AMAP0:.+]] = affine.apply #[[$MAP]]()[%[[DIM1]], %[[IDX2]]]
575-
// CHECK-DAG: %[[AMAP1:.+]] = affine.apply #[[$MAP]]()[%[[DIM0]], %[[IDX1]]]
576-
// CHECK: %[[OUT:.+]] = tensor.empty(%[[AMAP1]], %[[AMAP0]]) : tensor<?x?xf32>
597+
// CHECK-DAG: %[[DIM0:.+]] = tensor.dim %[[ARG2]], %[[CST0]] : tensor<?x?xf32>
598+
// CHECK-DAG: %[[DIM1:.+]] = tensor.dim %[[ARG2]], %[[CST1]] : tensor<?x?xf32>
599+
// CHECK: %[[OUT:.+]] = tensor.empty(%[[DIM0]], %[[DIM1]]) : tensor<?x?xf32>
577600
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
578601
// CHECK-SAME: outer_dims_perm = [0, 1]
579602
// CHECK-SAME: inner_dims_pos = [1, 0]

0 commit comments

Comments
 (0)