Skip to content

Commit 8c6da76

Browse files
[mlir][Transform] Fix applyToOne corner case when no op is matched.
Such situations manifest themselves with an empty payload which ends up producing empty results. In such cases, we still want to match the transform op contract and return as many empty SmallVector<Operation*> as the op requires. Differential Revision: https://reviews.llvm.org/D128456
1 parent fbf611e commit 8c6da76

File tree

2 files changed

+42
-6
lines changed

2 files changed

+42
-6
lines changed

mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -824,21 +824,35 @@ mlir::transform::TransformEachOpTrait<OpTy>::apply(
824824
decltype(&OpTy::applyToOne)>::template arg_t<0>;
825825
ArrayRef<Operation *> targets =
826826
state.getPayloadOps(this->getOperation()->getOperand(0));
827+
// Handle the corner case where no target is specified.
828+
// This is typically the case when the matcher fails to apply and we need to
829+
// propagate gracefully.
830+
// In this case, we fill all results with an empty vector.
831+
if (targets.empty()) {
832+
SmallVector<Operation *> emptyResult;
833+
for (auto r : this->getOperation()->getResults())
834+
transformResults.set(r.template cast<OpResult>(), emptyResult);
835+
return DiagnosedSilenceableFailure::success();
836+
}
837+
827838
SmallVector<SmallVector<Operation *>, 1> results;
828839
// In the multi-result case, collect the number of results each transform
829840
// produced.
830841
DiagnosedSilenceableFailure result = detail::applyTransformToEach(
831842
targets, results, [&](TransformOpType specificOp) {
832843
return static_cast<OpTy *>(this)->applyToOne(specificOp, state);
833844
});
845+
// Propagate the failure (definite or silencable) if any.
834846
if (!result.succeeded())
835847
return result;
836-
if (results.empty())
848+
849+
// Legitimately no results, bail early.
850+
if (results.empty() && OpTy::template hasTrait<OpTrait::ZeroResults>())
837851
return DiagnosedSilenceableFailure::success();
838852

839853
// Ensure all applications return the same number of results.
840854
// Variadic cases are much trickier to handle in a generic fashion.
841-
int64_t nRes = results[0].size();
855+
int64_t nRes = results.empty() ? 0 : results[0].size();
842856
if (llvm::any_of(results, [&](const auto &r) {
843857
return static_cast<int64_t>(r.size()) != nRes;
844858
})) {
@@ -849,6 +863,8 @@ mlir::transform::TransformEachOpTrait<OpTy>::apply(
849863
"generic `apply` instead of the specialized `applyToOne`";
850864
}
851865
// Ensure the number of results agrees with what the transform op expects.
866+
// Unless we see empty results, in which case we just want to propagate the
867+
// emptiness.
852868
if (this->getOperation()->getNumResults() != nRes) {
853869
InFlightDiagnostic diag = static_cast<OpTy *>(this)->emitError()
854870
<< "unexpected number of results (got " << nRes
@@ -857,10 +873,6 @@ mlir::transform::TransformEachOpTrait<OpTy>::apply(
857873
return DiagnosedSilenceableFailure::definiteFailure();
858874
}
859875

860-
// If no results, bail early.
861-
if (OpTy::template hasTrait<OpTrait::ZeroResults>())
862-
return DiagnosedSilenceableFailure::success();
863-
864876
// Perform transposition of M applications producing N results each into N
865877
// results for each of the M applications.
866878
SmallVector<SmallVector<Operation *, 1>> transposedResults =

mlir/test/Dialect/Transform/test-interpreter.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -436,3 +436,27 @@ transform.with_pdl_patterns {
436436
%1:2 = transform.test_correct_number_of_multi_results %0
437437
}
438438
}
439+
440+
// -----
441+
442+
func.func @foo() {
443+
"wrong_op_name" () : () -> ()
444+
return
445+
}
446+
447+
transform.with_pdl_patterns {
448+
^bb0(%arg0: !pdl.operation):
449+
pdl.pattern @some : benefit(1) {
450+
%0 = pdl.operands
451+
%1 = pdl.types
452+
%2 = pdl.operation "op"(%0 : !pdl.range<value>) -> (%1 : !pdl.range<type>)
453+
pdl.rewrite %2 with "transform.dialect"
454+
}
455+
456+
transform.sequence %arg0 {
457+
^bb0(%arg1: !pdl.operation):
458+
%0 = pdl_match @some in %arg1
459+
// Transform fails to match any but still produces 2 results.
460+
%1:2 = transform.test_correct_number_of_multi_results %0
461+
}
462+
}

0 commit comments

Comments
 (0)