Commit a44b787

[MLIR][linalg] Fix unpack rewriter for dynamic shapes (llvm#67096)
Prior to this patch, `GeneralizeOuterUnitDimsUnPackOpPattern` would hit an assertion when it tried to create a `tensor.empty` operation with dynamic shapes. The problem stems from using the wrong builder for `tensor.empty`: each dynamic dimension must be backed by an SSA value. Passing the dynamic dimensions to the `tensor.empty` builder fixes this.
1 parent: 39d7f70 · commit: a44b787
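For context, `tensor.empty` expects one index operand per dynamic dimension in its result type, which is why the rewriter must collect the `tensor.dim` values it creates and hand them to the builder. A minimal IR sketch of the shape involved (value names are illustrative; `%src` stands for the unpack source, as in the test below):

%c0   = arith.constant 0 : index
%d0   = tensor.dim %src, %c0 : tensor<?x1x1x1x8x32xf32>
// One index operand per '?' in the result type; here a single dynamic dim.
%init = tensor.empty(%d0) : tensor<?x32x8xf32>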

2 files changed: +30, -3 lines
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 7 additions & 3 deletions
@@ -1256,15 +1256,18 @@ LogicalResult GeneralizeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
   SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
   SmallVector<OpFoldResult> readSizes;
   SmallVector<int64_t> readShape;
+  SmallVector<Value> dynamicDims;
   for (auto i : llvm::seq<unsigned>(0, destRank)) {
     if (dimAndTileMapping.count(i)) {
       readSizes.push_back(oneIdxAttr);
       continue;
     }
 
     if (ShapedType::isDynamic(srcShape[i])) {
-      readSizes.push_back(
-          rewriter.create<tensor::DimOp>(loc, source, i).getResult());
+      Value dynamicDim =
+          rewriter.create<tensor::DimOp>(loc, source, i).getResult();
+      readSizes.push_back(dynamicDim);
+      dynamicDims.push_back(dynamicDim);
     } else {
       readSizes.push_back(rewriter.getIndexAttr(srcShape[i]));
     }
@@ -1292,7 +1295,8 @@ LogicalResult GeneralizeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
   SmallVector<int64_t> transpShape(readShape);
   applyPermutationToVector<int64_t>(transpShape, perm);
 
-  Value empty = rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType);
+  Value empty =
+      rewriter.create<tensor::EmptyOp>(loc, transpShape, elemType, dynamicDims);
   auto transposedOp =
       rewriter.create<linalg::TransposeOp>(loc, innerTile, empty, perm);
 
mlir/test/Dialect/Linalg/generalize-tensor-unpack.mlir

Lines changed: 23 additions & 0 deletions
@@ -94,3 +94,26 @@ func.func @simple_NHWC_to_NCHW(%arg0: tensor<1x16x8x32xf32>, %arg1: tensor<1x32x
 // CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]]
 // CHECK-SAME: [0, 0, 0, 0] [1, 32, 16, 8] [1, 1, 1, 1]
 // CHECK: return %[[INSERT]]
+
+// -----
+
+func.func @unpack_with_dynamic_dims(%arg0: tensor<?x1x1x1x8x32xf32>, %arg1: tensor<?x1x32x8xf32>) -> tensor<?x1x32x8xf32> {
+  %0 = tensor.unpack %arg0 inner_dims_pos = [3, 2] inner_tiles = [8, 32] into %arg1 : tensor<?x1x1x1x8x32xf32> -> tensor<?x1x32x8xf32>
+  return %0 : tensor<?x1x32x8xf32>
+}
+// CHECK-LABEL: func.func @unpack_with_dynamic_dims
+// CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[DIM0_SRC:.+]] = tensor.dim %[[SRC]], %[[C0]] : tensor<?x1x1x1x8x32xf32>
+// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0, 0, 0] [%[[DIM0_SRC]], 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1]
+// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM0_SRC]]) : tensor<?x32x8xf32>
+// CHECK: %[[TRANSP:.+]] = linalg.transpose
+// CHECK-SAME: ins(%[[TILE]] : tensor<?x8x32xf32>)
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<?x32x8xf32>)
+// CHECK-SAME: permutation = [0, 2, 1]
+// CHECK: %[[DIM0_DEST:.+]] = tensor.dim %[[DEST]], %[[C0]] : tensor<?x1x32x8xf32>
+// CHECK: %[[EXTRACT_SLICE:.+]] = tensor.extract_slice %[[TRANSP]][0, 0, 0] [%[[DIM0_DEST]], 32, 8] [1, 1, 1] : tensor<?x32x8xf32> to tensor<?x32x8xf32>
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[EXTRACT_SLICE]] into %[[DEST]]
+// CHECK-SAME: [0, 0, 0, 0] [%[[DIM0_DEST]], 1, 32, 8] [1, 1, 1, 1]
+// CHECK: return %[[INSERT]]
