Skip to content

Commit 1b7b40b

Browse files
authored
[mlir][Linalg] Support lowerUnPack for identity out_dims_perm cases. (#79594)
1 parent 89cd345 commit 1b7b40b

File tree

2 files changed

+40
-2
lines changed

2 files changed

+40
-2
lines changed

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -377,8 +377,11 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
377377
FailureOr<LowerUnPackOpResult> linalg::lowerUnPack(RewriterBase &rewriter,
378378
tensor::UnPackOp unPackOp) {
379379
// 1. Filter out NYI cases.
380-
if (!unPackOp.getOuterDimsPerm().empty())
381-
return rewriter.notifyMatchFailure(unPackOp, "outer dims perm NYI");
380+
if (!unPackOp.getOuterDimsPerm().empty() &&
381+
!isIdentityPermutation(unPackOp.getOuterDimsPerm())) {
382+
return rewriter.notifyMatchFailure(unPackOp,
383+
"non-identity outer dims perm NYI");
384+
}
382385

383386
Location loc = unPackOp->getLoc();
384387
OpBuilder::InsertionGuard g(rewriter);

mlir/test/Dialect/Linalg/transform-lower-pack.mlir

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,41 @@ module attributes {transform.with_named_sequence} {
163163

164164
// -----
165165

166+
// CHECK-LABEL: func.func @unpack_with_identity_outer_dims_perm(
167+
func.func @unpack_with_identity_outer_dims_perm(%arg0: tensor<17x2x16x16x32x8xf32>, %arg1: tensor<129x47x16x16xf32>) -> tensor<129x47x16x16xf32> {
168+
%cst_0 = arith.constant 0.0 : f32
169+
// CHECK-SAME: %[[ARG0:.*]]: tensor<17x2x16x16x32x8xf32>, %[[ARG1:.*]]: tensor<129x47x16x16xf32>
170+
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<17x8x2x32x16x16xf32>
171+
// CHECK: %[[TRAN:.*]] = linalg.transpose
172+
// CHECK-SAME: ins(%[[ARG0]] : tensor<17x2x16x16x32x8xf32>)
173+
// CHECK-SAME: outs(%[[EMPTY]] : tensor<17x8x2x32x16x16xf32>)
174+
// CHECK-SAME: permutation = [0, 5, 1, 4, 2, 3]
175+
// CHECK: %[[CLP:.*]] = tensor.collapse_shape %[[TRAN]] {{\[}}[0, 1], [2, 3], [4], [5]]
176+
// CHECK-SAME: : tensor<17x8x2x32x16x16xf32> into tensor<136x64x16x16xf32>
177+
// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[CLP]][0, 0, 0, 0] [129, 47, 16, 16] [1, 1, 1, 1]
178+
// CHECK-SAME: : tensor<136x64x16x16xf32> to tensor<129x47x16x16xf32>
179+
// CHECK: linalg.copy ins(%[[SLICE]] : tensor<129x47x16x16xf32>)
180+
// CHECK-SAME: outs(%[[ARG1]] : tensor<129x47x16x16xf32>)
181+
%unpack = tensor.unpack %arg0 outer_dims_perm = [0, 1, 2, 3] inner_dims_pos = [1, 0] inner_tiles = [32, 8] into %arg1
182+
: tensor<17x2x16x16x32x8xf32> -> tensor<129x47x16x16xf32>
183+
return %unpack : tensor<129x47x16x16xf32>
184+
}
185+
186+
module attributes {transform.with_named_sequence} {
187+
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
188+
%unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
189+
: (!transform.any_op) -> !transform.op<"tensor.unpack">
190+
transform.structured.lower_unpack %unpack : (!transform.op<"tensor.unpack">)
191+
-> (!transform.op<"tensor.empty">,
192+
!transform.op<"linalg.transpose">,
193+
!transform.op<"tensor.collapse_shape">,
194+
!transform.op<"tensor.extract_slice">)
195+
transform.yield
196+
}
197+
}
198+
199+
// -----
200+
166201
// When an unpack is a plain 'unpad', lower it to a simple extract_slice.
167202
// CHECK-LABEL: func.func @unpack_as_pad(
168203
func.func @unpack_as_pad(%arg0: tensor<1x1x1x1x136x64x16x16xf32>, %arg1: tensor<129x47x16x16xf32>) -> tensor<129x47x16x16xf32> {

0 commit comments

Comments
 (0)