
Commit 9d3057c

[mlir][Linalg] Add support for lowerPack on dynamic outer shapes.
The revision adds support for tensor.pack op decomposition when all inner tile sizes are static. The generated tensor.expand_shape op is still valid because only one of the expanding dimensions is dynamic.

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D150233
1 parent 1e0966c commit 9d3057c
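
For orientation, a condensed before/after sketch of what this enables (shapes, tiles, and permutation are taken from the new test added below; function and SSA names are illustrative, and the lowered form is approximate):

// Before: a tensor.pack with dynamic outer dims and static inner tiles [16, 32].
func.func @pack_dynamic(%source: tensor<?x?xf32>, %init: tensor<?x?x16x32xf32>) -> tensor<?x?x16x32xf32> {
  %cst = arith.constant 0.0 : f32
  %pack = tensor.pack %source padding_value(%cst : f32)
      outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 32]
      into %init : tensor<?x?xf32> -> tensor<?x?x16x32xf32>
  return %pack : tensor<?x?x16x32xf32>
}

// After lowerPack (roughly): pad each tiled dim to the next tile multiple, expand, transpose.
func.func @pack_dynamic_lowered(%source: tensor<?x?xf32>, %init: tensor<?x?x16x32xf32>) -> tensor<?x?x16x32xf32> {
  %cst = arith.constant 0.0 : f32
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %d0 = tensor.dim %source, %c0 : tensor<?x?xf32>
  %d1 = tensor.dim %source, %c1 : tensor<?x?xf32>
  %h0 = affine.apply affine_map<()[s0] -> (-s0 + (s0 ceildiv 32) * 32)>()[%d0]
  %h1 = affine.apply affine_map<()[s0] -> (-s0 + (s0 ceildiv 16) * 16)>()[%d1]
  %padded = tensor.pad %source low[0, 0] high[%h0, %h1] {
  ^bb0(%i: index, %j: index):
    tensor.yield %cst : f32
  } : tensor<?x?xf32> to tensor<?x?xf32>
  // Each reassociation group has only one dynamic dimension, so expand_shape stays valid.
  %expanded = tensor.expand_shape %padded [[0, 1], [2, 3]]
      : tensor<?x?xf32> into tensor<?x32x?x16xf32>
  %transposed = linalg.transpose ins(%expanded : tensor<?x32x?x16xf32>)
      outs(%init : tensor<?x?x16x32xf32>) permutation = [2, 0, 3, 1]
  return %transposed : tensor<?x?x16x32xf32>
}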

2 files changed (+80 -18 lines)

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 20 additions & 3 deletions
@@ -477,7 +477,8 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
   // 1. Filter out NYI cases.
   auto packedTensorType =
       packOp->getResultTypes().front().cast<RankedTensorType>();
-  if (!packedTensorType.hasStaticShape()) {
+  if (llvm::any_of(packOp.getStaticInnerTiles(),
+                   [](int64_t size) { return ShapedType::isDynamic(size); })) {
     return rewriter.notifyMatchFailure(
         packOp,
         "non-static shape NYI, needs a more powerful tensor.expand_shape op");
@@ -520,6 +521,22 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
   applyPermutationToVector(stripMinedShape, packedToStripMinedShapePerm);

   // 4. Pad the source of packOp to a shape we can expand into stripMinedShape.
+  SmallVector<OpFoldResult> lows(packOp.getSourceRank(),
+                                 rewriter.getIndexAttr(0));
+  SmallVector<OpFoldResult> highs(packOp.getSourceRank(),
+                                  rewriter.getIndexAttr(0));
+  for (auto [pos, innerSize] :
+       llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getMixedTiles())) {
+    OpFoldResult origSize = rewriter.createOrFold<tensor::DimOp>(
+        loc, packOp.getSource(),
+        rewriter.create<arith::ConstantIndexOp>(loc, pos));
+    AffineExpr s0, d0;
+    bindDims(rewriter.getContext(), d0);
+    bindSymbols(rewriter.getContext(), s0);
+    auto map = AffineMap::get(1, 1, d0.ceilDiv(s0) * s0 - d0);
+    highs[pos] = affine::makeComposedFoldedAffineApply(rewriter, loc, map,
+                                                       {origSize, innerSize});
+  }
   RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType(
       RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
       packingMetadata.reassociations);
@@ -529,8 +546,8 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
         loc, rewriter.getZeroAttr(getElementTypeOrSelf(collapsed)));
   }
   auto padOp =
-      tensor::createPadHighOp(collapsed, packOp.getSource(), paddingValue,
-                              /*nofold=*/false, loc, rewriter);
+      rewriter.create<tensor::PadOp>(loc, collapsed, packOp.getSource(), lows,
+                                     highs, paddingValue, /*nofold=*/false);

   LLVM_DEBUG(
       DBGSNL(); DBGSNL(); llvm::interleaveComma(packingMetadata.insertPositions,
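
For each tiled source dimension, the high pad amount computed above is the distance from the original size to the next multiple of the inner tile size, i.e. ceilDiv(origSize, tile) * tile - origSize. With a static inner tile this folds into a single-symbol affine map; a minimal sketch for a dimension tiled by 16 (the map and function names are illustrative, but the map itself matches MAP0 checked in the updated test below):

#pad_to_multiple_of_16 = affine_map<()[s0] -> (-s0 + (s0 ceildiv 16) * 16)>

func.func @high_pad_for_tile_16(%dim_size: index) -> index {
  // E.g. %dim_size = 100 yields 12: pad 100 up to 112, the next multiple of 16.
  %high = affine.apply #pad_to_multiple_of_16()[%dim_size]
  return %high : index
}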

mlir/test/Dialect/Linalg/transform-lower-pack.mlir

Lines changed: 60 additions & 15 deletions
@@ -1,12 +1,11 @@
-// RUN: mlir-opt %s -test-transform-dialect-interpreter --split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-transform-dialect-interpreter -cse --split-input-file | FileCheck %s

 // CHECK-LABEL: func.func @pack(
 func.func @pack(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<17x2x16x16x32x8xf32>) -> tensor<17x2x16x16x32x8xf32> {
   %cst_0 = arith.constant 0.0 : f32

   // tensor.pack is lowered to tensor.pad + tensor.expand_shape + linalg.transpose
-  // CHECK: %[[C0:.*]] = arith.constant 0 : index
-  // CHECK: tensor.pad {{.*}} low[%[[C0]], %[[C0]], %[[C0]], %[[C0]]]
+  // CHECK: tensor.pad {{.*}} low[0, 0, 0, 0]
   // CHECK:   : tensor<129x47x16x16xf32> to tensor<136x64x16x16xf32>
   // CHECK: tensor.expand_shape %{{.*}} [{{.*}}[0, 1], [2, 3], [4], [5]]
   // CHECK-SAME:   : tensor<136x64x16x16xf32> into tensor<17x8x2x32x16x16xf32>
@@ -33,8 +32,7 @@ transform.sequence failures(propagate) {
 func.func @pack(%arg0: tensor<128x8xf32>, %arg1: tensor<8x8x16x1xf32>) -> tensor<8x8x16x1xf32> {

   // tensor.pack is lowered to tensor.pad + tensor.expand_shape + linalg.transpose
-  // CHECK: %[[C0:.*]] = arith.constant 0 : index
-  // CHECK: tensor.pad {{.*}} low[%[[C0]], %[[C0]]]
+  // CHECK: tensor.pad {{.*}} low[0, 0]
   // CHECK:   : tensor<128x8xf32> to tensor<128x8xf32>
   // CHECK: tensor.expand_shape %{{.*}} [{{.*}}[0, 1], [2, 3]]
   // CHECK-SAME:   : tensor<128x8xf32> into tensor<8x16x8x1xf32>
@@ -64,8 +62,7 @@ func.func @pack_as_pad(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x13
   %cst_0 = arith.constant 0.0 : f32

   // tensor.pack is lowered to tensor.pad + tensor.insert_slice
-  // CHECK: %[[C0:.*]] = arith.constant 0 : index
-  // CHECK: %[[PAD:.*]] = tensor.pad {{.*}} low[%[[C0]], %[[C0]], %[[C0]], %[[C0]]]
+  // CHECK: %[[PAD:.*]] = tensor.pad {{.*}} low[0, 0, 0, 0]
   // CHECK:   : tensor<129x47x16x16xf32> to tensor<136x64x16x16xf32>
   // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<1x1x1x1x136x64x16x16xf32>
   // CHECK: %[[RES:.*]] = tensor.insert_slice %[[PAD]] into %[[EMPTY]]
@@ -100,8 +97,7 @@ transform.sequence failures(propagate) {
 func.func @pack_not_a_pad(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x16x16x136x64xf32>) -> tensor<1x1x16x16x136x64xf32> {
   %cst_0 = arith.constant 0.0 : f32

-  // CHECK: %[[C0:.*]] = arith.constant 0 : index
-  // CHECK: tensor.pad {{.*}} low[%[[C0]], %[[C0]], %[[C0]], %[[C0]]]
+  // CHECK: tensor.pad {{.*}} low[0, 0, 0, 0]
   // CHECK:   : tensor<129x47x16x16xf32> to tensor<136x64x16x16xf32>
   // CHECK: tensor.expand_shape %{{.*}} [{{.*}}[0, 1], [2, 3], [4], [5]]
   // CHECK-SAME:   : tensor<136x64x16x16xf32> into tensor<1x136x1x64x16x16xf32>
@@ -190,8 +186,7 @@ transform.sequence failures(propagate) {
 func.func @pack_with_outer_dims_perm(%src: tensor<100x200x128x256xi32>,
                                      %dest: tensor<200x4x16x100x16x32xi32>)
     -> tensor<200x4x16x100x16x32xi32> {
-  // CHECK: %[[C0:.*]] = arith.constant 0 : index
-  // CHECK: tensor.pad {{.*}} low[%[[C0]], %[[C0]], %[[C0]], %[[C0]]]
+  // CHECK: tensor.pad {{.*}} low[0, 0, 0, 0]
   // CHECK:   : tensor<100x200x128x256xi32> to tensor<100x200x128x256xi32>
   // CHECK: tensor.expand_shape %{{.*}} [{{.*}}[0], [1], [2, 3], [4, 5]]
   // CHECK-SAME:   : tensor<100x200x128x256xi32> into tensor<100x200x4x32x16x16xi32>
@@ -221,8 +216,7 @@ transform.sequence failures(propagate) {
 func.func @pack_with_pad_and_outer_dims_perm(%src: tensor<100x200x127x255xi32>,
                                              %dest: tensor<200x4x16x100x16x32xi32>)
     -> tensor<200x4x16x100x16x32xi32> {
-  // CHECK: %[[C0:.*]] = arith.constant 0 : index
-  // CHECK: tensor.pad {{.*}} low[%[[C0]], %[[C0]], %[[C0]], %[[C0]]]
+  // CHECK: tensor.pad {{.*}} low[0, 0, 0, 0]
   // CHECK:   : tensor<100x200x127x255xi32> to tensor<100x200x128x256xi32>
   // CHECK: tensor.expand_shape %{{.*}} [{{.*}}[0], [1], [2, 3], [4, 5]]
   // CHECK-SAME:   : tensor<100x200x128x256xi32> into tensor<100x200x4x32x16x16xi32>
@@ -250,13 +244,64 @@ transform.sequence failures(propagate) {

 // -----

+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (-s0 + (s0 ceildiv 16) * 16)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (-s0 + (s0 ceildiv 32) * 32)>
+// CHECK:      func.func @dynamic_pack_pad_transpose_inner_and_outer_dims(
+// CHECK-SAME:   %[[SRC:[a-zA-Z0-9]+]]
+func.func @dynamic_pack_pad_transpose_inner_and_outer_dims(%source: tensor<?x?xf32>) -> tensor<?x?x16x32xf32> {
+  // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+  // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+  // CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index
+  // CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index
+  // CHECK-DAG: %[[D0:.+]] = tensor.dim %[[SRC]], %[[C0]]
+  // CHECK-DAG: %[[D1:.+]] = tensor.dim %[[SRC]], %[[C1]]
+  // CHECK-DAG: %[[OUT_D0:.+]] = arith.ceildivui %[[D1]], %[[C16]] : index
+  // CHECK-DAG: %[[OUT_D1:.+]] = arith.ceildivui %[[D0]], %[[C32]] : index
+  // CHECK-DAG: %[[EMPTY:.+]] = tensor.empty(%[[OUT_D0]], %[[OUT_D1]]) : tensor<?x?x16x32xf32>
+  // CHECK-DAG: %[[H1:.+]] = affine.apply #[[MAP0]]()[%[[D1]]]
+  // CHECK-DAG: %[[H0:.+]] = affine.apply #[[MAP1]]()[%[[D0]]]
+  // CHECK:     %[[PAD:.+]] = tensor.pad %[[SRC]] low[0, 0] high[%[[H0]], %[[H1]]]
+  // CHECK:       : tensor<?x?xf32> to tensor<?x?xf32>
+  // CHECK:     %[[EXPAND:.+]] = tensor.expand_shape %[[PAD]] {{\[}}[0, 1], [2, 3]]
+  // CHECK-SAME:  : tensor<?x?xf32> into tensor<?x32x?x16xf32>
+  // CHECK:     %[[TRANSP:.+]] = linalg.transpose
+  // CHECK-SAME:  ins(%[[EXPAND]] : tensor<?x32x?x16xf32>)
+  // CHECK-SAME:  outs(%[[EMPTY]] : tensor<?x?x16x32xf32>)
+  // CHECK-SAME:  permutation = [2, 0, 3, 1]
+  // CHECK:     return %[[TRANSP]]
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %d0 = tensor.dim %source, %c0 : tensor<?x?xf32>
+  %d1 = tensor.dim %source, %c1 : tensor<?x?xf32>
+  %padding_value = arith.constant 0.0 : f32
+
+  %c16 = arith.constant 16 : index
+  %c32 = arith.constant 32 : index
+  %tiled_d0 = arith.ceildivui %d0, %c32 : index
+  %tiled_d1 = arith.ceildivui %d1, %c16 : index
+  %init_pack = tensor.empty(%tiled_d1, %tiled_d0) : tensor<?x?x16x32xf32>
+  %pack = tensor.pack %source padding_value(%padding_value : f32)
+      outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 32] into %init_pack
+      : tensor<?x?xf32> -> tensor<?x?x16x32xf32>
+  return %pack : tensor<?x?x16x32xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%module_op: !pdl.operation):
+  %pack = transform.structured.match ops{["tensor.pack"]} in %module_op
+    : (!pdl.operation) -> !transform.op<"tensor.pack">
+  transform.structured.lower_pack %pack : (!transform.op<"tensor.pack">)
+    -> (!transform.op<"tensor.pad">, !transform.op<"tensor.expand_shape">, !transform.op<"linalg.transpose">)
+}
+
+// -----
+
 // CHECK-LABEL: func.func @pack_as_pad_with_outer_dims_perm(
 func.func @pack_as_pad_with_outer_dims_perm(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<1x1x1x1x136x64x16x16xf32>) -> tensor<1x1x1x1x136x64x16x16xf32> {
   %cst_0 = arith.constant 0.0 : f32

   // tensor.pack is lowered to tensor.pad + tensor.insert_slice
-  // CHECK: %[[C0:.*]] = arith.constant 0 : index
-  // CHECK: %[[PAD:.*]] = tensor.pad {{.*}} low[%[[C0]], %[[C0]], %[[C0]], %[[C0]]]
+  // CHECK: %[[PAD:.*]] = tensor.pad {{.*}} low[0, 0, 0, 0]
   // CHECK:   : tensor<129x47x16x16xf32> to tensor<136x64x16x16xf32>
   // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<1x1x1x1x136x64x16x16xf32>
   // CHECK: %[[RES:.*]] = tensor.insert_slice %[[PAD]] into %[[EMPTY]]
