Skip to content

Commit f2d3d82

Browse files
[mlir][linalg][Transform] Fix use-after-free in SplitOp::apply (llvm#96390)
Detected with ASAN. `Operation::getLoc()` was called after erasing the operation. Reverts 48cf6b6, which attempted to fix the use-after-free. (But the use-after-free is still there when the `hasFailed` branch is taken.)
1 parent 09c0337 commit f2d3d82

File tree

1 file changed

+13
-11
lines changed

1 file changed

+13
-11
lines changed

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2314,7 +2314,7 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
23142314
}
23152315
} else {
23162316
chunkSizes.resize(payload.size(),
2317-
rewriter.getIndexAttr(getStaticChunkSizes()));
2317+
rewriter.getIndexAttr(getStaticChunkSizes()));
23182318
}
23192319

23202320
auto checkStructuredOpAndDimensions =
@@ -2327,18 +2327,18 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
23272327

23282328
if (getDimension() >= linalgOp.getNumLoops()) {
23292329
auto diag = emitSilenceableError() << "dimension " << getDimension()
2330-
<< " does not exist in target op";
2330+
<< " does not exist in target op";
23312331
diag.attachNote(loc) << "target op";
23322332
return diag;
23332333
}
23342334
return DiagnosedSilenceableFailure::success();
23352335
};
23362336

23372337
auto checkFailureInSplitting =
2338-
[&](bool hasFailed, Operation *op) -> DiagnosedSilenceableFailure {
2338+
[&](bool hasFailed, Location loc) -> DiagnosedSilenceableFailure {
23392339
if (hasFailed) {
23402340
auto diag = emitDefiniteFailure() << "internal failure in splitting";
2341-
diag.attachNote(op->getLoc()) << "target op";
2341+
diag.attachNote(loc) << "target op";
23422342
return diag;
23432343
}
23442344
return DiagnosedSilenceableFailure::success();
@@ -2368,6 +2368,7 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
23682368
break;
23692369

23702370
linalgOp = cast<LinalgOp>(target);
2371+
Location loc = target->getLoc();
23712372

23722373
rewriter.setInsertionPoint(linalgOp);
23732374
std::tie(head, tail) = linalg::splitOp(
@@ -2376,7 +2377,7 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
23762377

23772378
// Propagate errors.
23782379
DiagnosedSilenceableFailure diag =
2379-
checkFailureInSplitting(!head && !tail, target);
2380+
checkFailureInSplitting(!head && !tail, loc);
23802381
if (diag.isDefiniteFailure())
23812382
return diag;
23822383

@@ -2395,6 +2396,7 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
23952396
Operation *noSecondPart = nullptr;
23962397
for (const auto &pair : llvm::zip(payload, chunkSizes)) {
23972398
Operation *target = std::get<0>(pair);
2399+
Location loc = target->getLoc();
23982400
LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
23992401
DiagnosedSilenceableFailure diag =
24002402
checkStructuredOpAndDimensions(linalgOp, target->getLoc());
@@ -2409,7 +2411,7 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
24092411

24102412
// Propagate errors.
24112413
DiagnosedSilenceableFailure diagSplit =
2412-
checkFailureInSplitting(!first.back() && !second.back(), target);
2414+
checkFailureInSplitting(!first.back() && !second.back(), loc);
24132415
if (diagSplit.isDefiniteFailure())
24142416
return diag;
24152417

@@ -2718,8 +2720,8 @@ transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
27182720

27192721
auto getI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
27202722
return llvm::map_to_vector(values, [&](int64_t value) -> Attribute {
2721-
return builder.getI64IntegerAttr(value);
2722-
});
2723+
return builder.getI64IntegerAttr(value);
2724+
});
27232725
};
27242726
transformResults.setParams(cast<OpResult>(getTileSizes()),
27252727
getI64AttrsFromI64(spec->tileSizes));
@@ -2756,9 +2758,9 @@ transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
27562758
}
27572759

27582760
auto getDefiningOps = [&](ArrayRef<Value> values) {
2759-
return llvm::map_to_vector(values, [&](Value value) -> Operation * {
2760-
return value.getDefiningOp();
2761-
});
2761+
return llvm::map_to_vector(values, [&](Value value) -> Operation * {
2762+
return value.getDefiningOp();
2763+
});
27622764
};
27632765

27642766
transformResults.set(cast<OpResult>(getTileSizes()),

0 commit comments

Comments
 (0)