Skip to content

Commit 6d2406c

Browse files
ingomueller-netzahiraam
authored andcommitted
[mlir][transform] Fix crash in transform.get_parent_op. (llvm#66492)
The previous implementation crashed if run on a `builtin.module` using an `op_name` filter (because the initial value of `parent` in the while loop was a `nullptr`). This PR fixes the crash and adds a test.
1 parent 2494321 commit 6d2406c

File tree

2 files changed

+15
-2
lines changed

2 files changed

+15
-2
lines changed

mlir/lib/Dialect/Transform/IR/TransformOps.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1233,15 +1233,16 @@ transform::GetParentOp::apply(transform::TransformRewriter &rewriter,
12331233
DenseSet<Operation *> resultSet;
12341234
for (Operation *target : state.getPayloadOps(getTarget())) {
12351235
Operation *parent = target->getParentOp();
1236-
do {
1236+
while (parent) {
12371237
bool checkIsolatedFromAbove =
12381238
!getIsolatedFromAbove() ||
12391239
parent->hasTrait<OpTrait::IsIsolatedFromAbove>();
12401240
bool checkOpName = !getOpName().has_value() ||
12411241
parent->getName().getStringRef() == *getOpName();
12421242
if (checkIsolatedFromAbove && checkOpName)
12431243
break;
1244-
} while ((parent = parent->getParentOp()));
1244+
parent = parent->getParentOp();
1245+
}
12451246
if (!parent) {
12461247
DiagnosedSilenceableFailure diag =
12471248
emitSilenceableError()

mlir/test/Dialect/Transform/test-interpreter.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1891,6 +1891,18 @@ transform.sequence failures(propagate) {
18911891
test_print_number_of_associated_payload_ir_ops %4 : !transform.any_op
18921892
}
18931893

1894+
1895+
// -----
1896+
1897+
// expected-note @below {{target op}}
1898+
module {
1899+
transform.sequence failures(propagate) {
1900+
^bb0(%arg0: !transform.any_op):
1901+
// expected-error @below{{could not find a parent op that matches all requirements}}
1902+
%3 = get_parent_op %arg0 {op_name = "builtin.module"} : (!transform.any_op) -> !transform.any_op
1903+
}
1904+
}
1905+
18941906
// -----
18951907

18961908
func.func @cast(%arg0: f32) -> f64 {

0 commit comments

Comments
 (0)