Skip to content

Commit 034e365

Browse files
committed
[MLIR][Linalg] improve silenceable failure msg for lower_pack (NFC)
Adjust the silenceable failure message as we lower `tensor.unpack` as a combination of `linalg.transpose` + `tensor.collapse_shape` and `tensor.extract_slice`.
1 parent 1d56138 commit 034e365

File tree

2 files changed

+33
-5
lines changed

2 files changed

+33
-5
lines changed

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1125,8 +1125,11 @@ DiagnosedSilenceableFailure transform::LowerUnPackOp::applyToOne(
11251125
rewriter.setInsertionPoint(target);
11261126
FailureOr<LowerUnPackOpResult> res = lowerUnPack(rewriter, target);
11271127
if (failed(res)) {
1128-
return mlir::emitSilenceableFailure(target->getLoc())
1129-
<< "cannot rewrite to pad + expand + transpose";
1128+
DiagnosedSilenceableFailure diag =
1129+
emitSilenceableError()
1130+
<< "cannot lower to transpose + collapse + extract";
1131+
diag.attachNote(target->getLoc()) << "target payload op";
1132+
return diag;
11301133
}
11311134
transformResults.push_back(res->emptyOp);
11321135
transformResults.push_back(res->transposeOp);

mlir/test/Dialect/Linalg/transform-lower-pack.mlir

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt %s -transform-interpreter -cse --split-input-file | FileCheck %s
1+
// RUN: mlir-opt %s -transform-interpreter -cse -verify-diagnostics -split-input-file | FileCheck %s
22

33
// CHECK-LABEL: func.func @pack(
44
func.func @pack(%arg0: tensor<129x47x16x16xf32>, %arg1: tensor<17x2x16x16x32x8xf32>) -> tensor<17x2x16x16x32x8xf32> {
@@ -143,9 +143,9 @@ func.func @unpack(%arg0: tensor<17x2x16x16x32x8xf32>, %arg1: tensor<129x47x16x16
143143
// CHECK-SAME: : tensor<136x64x16x16xf32> to tensor<129x47x16x16xf32>
144144
// CHECK: linalg.copy ins(%[[SLICE]] : tensor<129x47x16x16xf32>)
145145
// CHECK-SAME: outs(%[[ARG1]] : tensor<129x47x16x16xf32>)
146-
%pack = tensor.unpack %arg0 inner_dims_pos = [1, 0] inner_tiles = [32, 8] into %arg1
146+
%unpack = tensor.unpack %arg0 inner_dims_pos = [1, 0] inner_tiles = [32, 8] into %arg1
147147
: tensor<17x2x16x16x32x8xf32> -> tensor<129x47x16x16xf32>
148-
return %pack : tensor<129x47x16x16xf32>
148+
return %unpack : tensor<129x47x16x16xf32>
149149
}
150150

151151
module attributes {transform.with_named_sequence} {
@@ -162,6 +162,7 @@ module attributes {transform.with_named_sequence} {
162162
}
163163

164164
// -----
165+
165166
// When an unpack is a plain 'unpad', lower it to a simple extract_slice.
166167
// CHECK-LABEL: func.func @unpack_as_pad(
167168
func.func @unpack_as_pad(%arg0: tensor<1x1x1x1x136x64x16x16xf32>, %arg1: tensor<129x47x16x16xf32>) -> tensor<129x47x16x16xf32> {
@@ -460,3 +461,27 @@ module attributes {transform.with_named_sequence} {
460461
transform.yield
461462
}
462463
}
464+
465+
// -----
466+
467+
// At the moment, we cannot lower tensor.unpack with outer_dims_perm.
468+
func.func @diagnostic_unpack(%arg0: tensor<32x64xf32>, %arg1: tensor<2x4x32x8xf32>) -> tensor<32x64xf32> {
469+
// expected-note @below {{target payload op}}
470+
%unpack = tensor.unpack %arg1 outer_dims_perm = [1, 0]
471+
inner_dims_pos = [1, 0] inner_tiles = [32, 8] into %arg0 : tensor<2x4x32x8xf32> -> tensor<32x64xf32>
472+
return %unpack : tensor<32x64xf32>
473+
}
474+
475+
module attributes {transform.with_named_sequence} {
476+
transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) {
477+
%unpack = transform.structured.match ops{["tensor.unpack"]} in %module_op
478+
: (!transform.any_op) -> !transform.op<"tensor.unpack">
479+
// expected-error @below {{cannot lower to transpose + collapse + extract}}
480+
transform.structured.lower_unpack %unpack : (!transform.op<"tensor.unpack">)
481+
-> (!transform.op<"tensor.empty">,
482+
!transform.op<"linalg.transpose">,
483+
!transform.op<"tensor.collapse_shape">,
484+
!transform.op<"tensor.extract_slice">)
485+
transform.yield
486+
}
487+
}

0 commit comments

Comments
 (0)