
Commit b26ee97

[MLIR][Linalg] Support dynamic sizes in lower_unpack (#75494)
1 parent 644e6d7

File tree

2 files changed: +132 -9 lines changed

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 9 additions & 9 deletions
@@ -380,17 +380,11 @@ FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
   if (!unPackOp.getOuterDimsPerm().empty())
     return rewriter.notifyMatchFailure(unPackOp, "outer dims perm NYI");
 
-  RankedTensorType packedTensorType = unPackOp.getSourceType();
-  if (!packedTensorType.hasStaticShape()) {
-    return rewriter.notifyMatchFailure(
-        unPackOp,
-        "non-static shape NYI, needs a more powerful tensor.expand_shape op");
-  }
-
   Location loc = unPackOp->getLoc();
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(unPackOp);
 
+  RankedTensorType packedTensorType = unPackOp.getSourceType();
   int64_t packedRank = packedTensorType.getRank();
 
   OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
@@ -434,8 +428,14 @@ FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
       RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape);
   RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
       stripMinedTensorType, packingMetadata.reassociations);
-  auto emptyOp =
-      rewriter.create<tensor::EmptyOp>(loc, stripMinedTensorType, ValueRange{});
+
+  // Get dynamic dims from input tensor based on lastDimsToInsertPositionsPerm
+  // permutation.
+  SmallVector<OpFoldResult, 4> dims =
+      tensor::getMixedSizes(rewriter, loc, unPackOp.getSource());
+  applyPermutationToVector(dims, lastDimsToInsertPositionsPerm);
+  auto emptyOp = rewriter.create<tensor::EmptyOp>(
+      loc, dims, stripMinedTensorType.getElementType());
   auto transposeOp = rewriter.create<linalg::TransposeOp>(
       loc, unPackOp.getSource(), emptyOp, lastDimsToInsertPositionsPerm);
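
The new lines build the transpose destination from mixed static/dynamic sizes instead of requiring a fully static source type. As a minimal standalone sketch (not part of the commit; the helper name buildPermutedEmpty and the example shape are assumptions), this is what they compute for a source of type tensor<?x?x8x16xf32> and permutation [0, 2, 1, 3]:

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/TypeUtilities.h"

using namespace mlir;

// Hypothetical helper mirroring the size handling in the patch above.
static Value buildPermutedEmpty(RewriterBase &rewriter, Location loc,
                                Value source, ArrayRef<int64_t> perm) {
  // One OpFoldResult per dimension: an IntegerAttr for each static size and
  // a fresh tensor.dim value for each dynamic one. For tensor<?x?x8x16xf32>
  // this yields {%dim0, %dim1, 8, 16}.
  SmallVector<OpFoldResult> sizes =
      tensor::getMixedSizes(rewriter, loc, source);

  // Reorder the sizes into the strip-mined (transposed) layout; with
  // perm = [0, 2, 1, 3] the vector becomes {%dim0, 8, %dim1, 16}.
  applyPermutationToVector(sizes, perm);

  // tensor.empty accepts mixed sizes directly: static entries fold into the
  // result type (here tensor<?x8x?x16xf32>), dynamic entries become operands.
  return rewriter.create<tensor::EmptyOp>(
      loc, sizes, getElementTypeOrSelf(source.getType()));
}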

mlir/test/Dialect/Linalg/transform-lower-pack.mlir

Lines changed: 123 additions & 0 deletions
@@ -464,6 +464,129 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// Check that we can lower unpack with dynamic dimensions in the input and destination.
+// CHECK-LABEL: func.func @unpack_with_dynamic_input_dest(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x8x16xf32>, %[[ARG1:.*]]: tensor<?x?xf32>)
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[DIM00:.*]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[DIM01:.*]] = tensor.dim %[[ARG0]], %[[C1]]
+// CHECK: %[[EMPTY:.*]] = tensor.empty(%[[DIM00]], %[[DIM01]]) : tensor<?x8x?x16xf32>
+// CHECK: %[[TRAN:.*]] = linalg.transpose
+// CHECK-SAME: ins(%[[ARG0]] : tensor<?x?x8x16xf32>)
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<?x8x?x16xf32>)
+// CHECK-SAME: permutation = [0, 2, 1, 3]
+// CHECK: %[[CLP:.*]] = tensor.collapse_shape %[[TRAN]] {{\[}}[0, 1], [2, 3]]
+// CHECK-SAME: : tensor<?x8x?x16xf32> into tensor<?x?xf32>
+// CHECK: %[[DIM10:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
+// CHECK: %[[DIM11:.*]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
+// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[CLP]][0, 0] [%[[DIM10]], %[[DIM11]]] [1, 1]
+// CHECK-SAME: : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK: linalg.copy ins(%[[SLICE]] : tensor<?x?xf32>)
+// CHECK-SAME: outs(%[[ARG1]] : tensor<?x?xf32>)
+func.func @unpack_with_dynamic_input_dest(%arg0: tensor<?x?x8x16xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %unpack = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 16] into %arg1 : tensor<?x?x8x16xf32> -> tensor<?x?xf32>
+  return %unpack : tensor<?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
+      : (!transform.any_op) -> !transform.op<"tensor.unpack">
+    transform.structured.lower_unpack %unpack : (!transform.op<"tensor.unpack">)
+      -> (!transform.op<"tensor.empty">,
+          !transform.op<"linalg.transpose">,
+          !transform.op<"tensor.collapse_shape">,
+          !transform.op<"tensor.extract_slice">)
+    transform.yield
+  }
+}
+
+// -----
+
+// Check that we can lower unpack with dynamic dimensions in the input, destination, and inner_tiles.
+// CHECK-LABEL: func.func @unpack_fully_dynamic(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x?x?xf32>, %[[ARG1:.*]]: tensor<?x?xf32>, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[DIM00:.*]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[DIM01:.*]] = tensor.dim %[[ARG0]], %[[C1]]
+// CHECK-DAG: %[[DIM02:.*]] = tensor.dim %[[ARG0]], %[[C2]]
+// CHECK-DAG: %[[DIM03:.*]] = tensor.dim %[[ARG0]], %[[C3]]
+// CHECK: %[[EMPTY:.*]] = tensor.empty(%[[DIM00]], %[[DIM02]], %[[DIM01]], %[[DIM03]]) : tensor<?x?x?x?xf32>
+// CHECK: %[[TRAN:.*]] = linalg.transpose
+// CHECK-SAME: ins(%[[ARG0]] : tensor<?x?x?x?xf32>)
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<?x?x?x?xf32>)
+// CHECK-SAME: permutation = [0, 2, 1, 3]
+// CHECK: %[[CLP:.*]] = tensor.collapse_shape %[[TRAN]] {{\[}}[0, 1], [2, 3]]
+// CHECK-SAME: : tensor<?x?x?x?xf32> into tensor<?x?xf32>
+// CHECK: %[[DIM10:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?x?xf32>
+// CHECK: %[[DIM11:.*]] = tensor.dim %[[ARG1]], %[[C1]] : tensor<?x?xf32>
+// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[CLP]][0, 0] [%[[DIM10]], %[[DIM11]]] [1, 1]
+// CHECK-SAME: : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK: linalg.copy ins(%[[SLICE]] : tensor<?x?xf32>)
+// CHECK-SAME: outs(%[[ARG1]] : tensor<?x?xf32>)
+func.func @unpack_fully_dynamic(%source: tensor<?x?x?x?xf32>, %dest: tensor<?x?xf32>, %tile_n : index, %tile_m : index) -> tensor<?x?xf32> {
+  %0 = tensor.unpack %source inner_dims_pos = [0, 1] inner_tiles = [%tile_n, %tile_m] into %dest : tensor<?x?x?x?xf32> -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
+      : (!transform.any_op) -> !transform.op<"tensor.unpack">
+    transform.structured.lower_unpack %unpack : (!transform.op<"tensor.unpack">)
+      -> (!transform.op<"tensor.empty">,
+          !transform.op<"linalg.transpose">,
+          !transform.op<"tensor.collapse_shape">,
+          !transform.op<"tensor.extract_slice">)
+    transform.yield
+  }
+}
+
+// -----
+
+// Check that we can lower unpack "as unpad" with dynamic dims.
+// CHECK-LABEL: func.func @unpack_as_pad_dynamic(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<1x1x1x1x?x?x?x?xf32>, %[[ARG1:.*]]: tensor<?x?x?x?xf32>
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[DIM0:.*]] = tensor.dim %[[ARG1]], %[[C0]]
+// CHECK-DAG: %[[DIM1:.*]] = tensor.dim %[[ARG1]], %[[C1]]
+// CHECK-DAG: %[[DIM2:.*]] = tensor.dim %[[ARG1]], %[[C2]]
+// CHECK-DAG: %[[DIM3:.*]] = tensor.dim %[[ARG1]], %[[C3]]
+// CHECK: %[[RES:.*]] = tensor.extract_slice %[[ARG0]]
+// offsets.
+// CHECK-SAME: [0, 0, 0, 0, 0, 0, 0, 0]
+// sizes.
+// CHECK-SAME: [1, 1, 1, 1, %[[DIM0]], %[[DIM1]], %[[DIM2]], %[[DIM3]]]
+// strides.
+// CHECK-SAME: [1, 1, 1, 1, 1, 1, 1, 1]
+// CHECK-SAME: : tensor<1x1x1x1x?x?x?x?xf32> to tensor<?x?x?x?xf32>
+func.func @unpack_as_pad_dynamic(%arg0: tensor<1x1x1x1x?x?x?x?xf32>, %arg1: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
+  %pack = tensor.unpack %arg0 inner_dims_pos = [0, 1, 2, 3] inner_tiles = [136, 64, 16, 16] into %arg1
+    : tensor<1x1x1x1x?x?x?x?xf32> -> tensor<?x?x?x?xf32>
+  return %pack : tensor<?x?x?x?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
+    %unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
+      : (!transform.any_op) -> !transform.op<"tensor.unpack">
+    transform.structured.lower_unpack %unpack : (!transform.op<"tensor.unpack">)
+      -> (!transform.op<"tensor.empty">,
+          !transform.op<"linalg.transpose">,
+          !transform.op<"tensor.collapse_shape">,
+          !transform.op<"tensor.extract_slice">)
+    transform.yield
+  }
+}
+
+// -----
+
 // At the moment, we cannot lower tensor.unpack with outer_dims_perm.
 func.func @diagnostic_unpack(%arg0: tensor<32x64xf32>, %arg1: tensor<2x4x32x8xf32>) -> tensor<32x64xf32> {
   // expected-note @below {{target payload op}}
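
These transform-dialect tests drive the same C++ entry point patched above. A minimal usage sketch from C++ (the wrapper lowerOneUnPack is hypothetical, and the result struct's field names are assumptions inferred from the four handles the transform op returns):

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

using namespace mlir;

// Hypothetical wrapper around the patched entry point. On success, the
// returned LowerUnPackOpResult is assumed to expose the four created ops that
// transform.structured.lower_unpack yields as handles: tensor.empty,
// linalg.transpose, tensor.collapse_shape, and tensor.extract_slice.
static LogicalResult lowerOneUnPack(RewriterBase &rewriter,
                                    tensor::UnPackOp unPackOp) {
  FailureOr<linalg::LowerUnPackOpResult> res =
      linalg::lowerUnPack(rewriter, unPackOp);
  if (failed(res))
    return failure(); // e.g. "outer dims perm NYI" when outer_dims_perm is set.
  return success();
}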
