Skip to content

Commit f07718b

Browse files
[mlir][transform] Improve error when merging of modules fails. (#69331)
This resolved #69112.
1 parent 1a7061c commit f07718b

File tree

4 files changed

+34
-27
lines changed

4 files changed

+34
-27
lines changed

mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,8 @@ TransformOpInterface findTransformEntryPoint(
8181
// function to clone (or move) `other` in order to improve efficiency.
8282
// This might primarily make sense if we can also prune the symbols that
8383
// are merged to a subset (such as those that are actually used).
84-
LogicalResult mergeSymbolsInto(Operation *target,
85-
OwningOpRef<Operation *> other);
84+
InFlightDiagnostic mergeSymbolsInto(Operation *target,
85+
OwningOpRef<Operation *> other);
8686
} // namespace detail
8787

8888
/// Standalone util to apply the named sequence `transformRoot` to `payload` IR.

mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -337,11 +337,14 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
337337
diag.attachNote(target->getLoc()) << "pass anchor op";
338338
return diag;
339339
}
340-
if (failed(detail::mergeSymbolsInto(
341-
SymbolTable::getNearestSymbolTable(transformRoot),
342-
transformLibraryModule->get()->clone())))
343-
return emitError(transformRoot->getLoc(),
344-
"failed to merge library symbols into transform root");
340+
InFlightDiagnostic diag = detail::mergeSymbolsInto(
341+
SymbolTable::getNearestSymbolTable(transformRoot),
342+
transformLibraryModule->get()->clone());
343+
if (failed(diag)) {
344+
diag.attachNote(transformRoot->getLoc())
345+
<< "failed to merge library symbols into transform root";
346+
return diag;
347+
}
345348
}
346349

347350
// Step 4

mlir/lib/Dialect/Transform/Transforms/TransformInterpreterUtils.cpp

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -177,8 +177,8 @@ LogicalResult transform::detail::assembleTransformLibraryFromPaths(
177177
for (OwningOpRef<ModuleOp> &parsedLibrary : parsedLibraries) {
178178
if (failed(transform::detail::mergeSymbolsInto(
179179
mergedParsedLibraries.get(), std::move(parsedLibrary))))
180-
return mergedParsedLibraries->emitError()
181-
<< "failed to verify merged transform module";
180+
return parsedLibrary->emitError()
181+
<< "failed to merge symbols into shared library module";
182182
}
183183
}
184184

@@ -197,8 +197,8 @@ static bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2) {
197197
/// Merge `func1` into `func2`. The two ops must be inside the same parent op
198198
/// and mergable according to `canMergeInto`. The function erases `func1` such
199199
/// that only `func2` exists when the function returns.
200-
static LogicalResult mergeInto(FunctionOpInterface func1,
201-
FunctionOpInterface func2) {
200+
static InFlightDiagnostic mergeInto(FunctionOpInterface func1,
201+
FunctionOpInterface func2) {
202202
assert(canMergeInto(func1, func2));
203203
assert(func1->getParentOp() == func2->getParentOp() &&
204204
"expected func1 and func2 to be in the same parent op");
@@ -241,10 +241,10 @@ static LogicalResult mergeInto(FunctionOpInterface func1,
241241
assert(func1.isExternal());
242242
func1->erase();
243243

244-
return success();
244+
return InFlightDiagnostic();
245245
}
246246

247-
LogicalResult
247+
InFlightDiagnostic
248248
transform::detail::mergeSymbolsInto(Operation *target,
249249
OwningOpRef<Operation *> other) {
250250
assert(target->hasTrait<OpTrait::SymbolTable>() &&
@@ -301,7 +301,7 @@ transform::detail::mergeSymbolsInto(Operation *target,
301301
auto renameToUnique =
302302
[&](SymbolOpInterface op, SymbolOpInterface otherOp,
303303
SymbolTable &symbolTable,
304-
SymbolTable &otherSymbolTable) -> LogicalResult {
304+
SymbolTable &otherSymbolTable) -> InFlightDiagnostic {
305305
LLVM_DEBUG(llvm::dbgs() << ", renaming\n");
306306
FailureOr<StringAttr> maybeNewName =
307307
symbolTable.renameToUnique(op, {&otherSymbolTable});
@@ -313,19 +313,21 @@ transform::detail::mergeSymbolsInto(Operation *target,
313313
}
314314
LLVM_DEBUG(DBGS() << " renamed to @" << maybeNewName->getValue()
315315
<< "\n");
316-
return success();
316+
return InFlightDiagnostic();
317317
};
318318

319319
if (symbolOp.isPrivate()) {
320-
if (failed(renameToUnique(symbolOp, collidingOp, *symbolTable,
321-
*otherSymbolTable)))
322-
return failure();
320+
InFlightDiagnostic diag = renameToUnique(
321+
symbolOp, collidingOp, *symbolTable, *otherSymbolTable);
322+
if (failed(diag))
323+
return diag;
323324
continue;
324325
}
325326
if (collidingOp.isPrivate()) {
326-
if (failed(renameToUnique(collidingOp, symbolOp, *otherSymbolTable,
327-
*symbolTable)))
328-
return failure();
327+
InFlightDiagnostic diag = renameToUnique(
328+
collidingOp, symbolOp, *otherSymbolTable, *symbolTable);
329+
if (failed(diag))
330+
return diag;
329331
continue;
330332
}
331333
LLVM_DEBUG(llvm::dbgs() << ", emitting error\n");
@@ -394,8 +396,10 @@ transform::detail::mergeSymbolsInto(Operation *target,
394396
assert(targetSymbolTable.lookup(funcOp.getName()) == collidingFuncOp);
395397

396398
// Do the actual merging.
397-
if (failed(mergeInto(funcOp, collidingFuncOp))) {
398-
return failure();
399+
{
400+
InFlightDiagnostic diag = mergeInto(funcOp, collidingFuncOp);
401+
if (failed(diag))
402+
return diag;
399403
}
400404
}
401405
}
@@ -405,7 +409,7 @@ transform::detail::mergeSymbolsInto(Operation *target,
405409
<< "failed to verify target op after merging symbols";
406410

407411
LLVM_DEBUG(DBGS() << "done merging ops\n");
408-
return success();
412+
return InFlightDiagnostic();
409413
}
410414

411415
LogicalResult transform::applyTransformNamedSequence(

mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ module attributes {transform.with_named_sequence} {
88
// expected-error @below {{external definition has a mismatching signature}}
99
transform.named_sequence private @print_message(!transform.op<"builtin.module"> {transform.readonly})
1010

11-
// expected-error @below {{failed to merge library symbols into transform root}}
11+
// expected-note @below {{failed to merge library symbols into transform root}}
1212
transform.sequence failures(propagate) {
1313
^bb0(%arg0: !transform.op<"builtin.module">):
1414
include @print_message failures(propagate) (%arg0) : (!transform.op<"builtin.module">) -> ()
@@ -33,7 +33,7 @@ module attributes {transform.with_named_sequence} {
3333
// expected-error @below {{external definition has mismatching consumption annotations for argument #0}}
3434
transform.named_sequence private @consuming(%arg0: !transform.any_op {transform.readonly})
3535

36-
// expected-error @below {{failed to merge library symbols into transform root}}
36+
// expected-note @below {{failed to merge library symbols into transform root}}
3737
transform.sequence failures(suppress) {
3838
^bb0(%arg0: !transform.any_op):
3939
include @consuming failures(suppress) (%arg0) : (!transform.any_op) -> ()
@@ -49,7 +49,7 @@ module attributes {transform.with_named_sequence} {
4949
transform.yield
5050
}
5151

52-
// expected-error @below {{failed to merge library symbols into transform root}}
52+
// expected-note @below {{failed to merge library symbols into transform root}}
5353
transform.sequence failures(suppress) {
5454
^bb0(%arg0: !transform.any_op):
5555
include @print_message failures(propagate) (%arg0) : (!transform.any_op) -> ()

0 commit comments

Comments
 (0)