-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][linalg][Transform] Fix use-after-free in SplitOp::apply
#96390
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][linalg][Transform] Fix use-after-free in SplitOp::apply
#96390
Conversation
@llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) ChangesFull diff: https://github.com/llvm/llvm-project/pull/96390.diff 1 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 37467db568c27..4eb334f8bbbfa 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2314,7 +2314,7 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
}
} else {
chunkSizes.resize(payload.size(),
- rewriter.getIndexAttr(getStaticChunkSizes()));
+ rewriter.getIndexAttr(getStaticChunkSizes()));
}
auto checkStructuredOpAndDimensions =
@@ -2327,7 +2327,7 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
if (getDimension() >= linalgOp.getNumLoops()) {
auto diag = emitSilenceableError() << "dimension " << getDimension()
- << " does not exist in target op";
+ << " does not exist in target op";
diag.attachNote(loc) << "target op";
return diag;
}
@@ -2368,6 +2368,7 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
break;
linalgOp = cast<LinalgOp>(target);
+ Location loc = target->getLoc();
rewriter.setInsertionPoint(linalgOp);
std::tie(head, tail) = linalg::splitOp(
@@ -2376,7 +2377,7 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
// Propagate errors.
DiagnosedSilenceableFailure diag =
- checkFailureInSplitting(!head && !tail, target->getLoc());
+ checkFailureInSplitting(!head && !tail, loc);
if (diag.isDefiniteFailure())
return diag;
@@ -2395,6 +2396,7 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
Operation *noSecondPart = nullptr;
for (const auto &pair : llvm::zip(payload, chunkSizes)) {
Operation *target = std::get<0>(pair);
+ Location loc = target->getLoc();
LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
DiagnosedSilenceableFailure diag =
checkStructuredOpAndDimensions(linalgOp, target->getLoc());
@@ -2408,8 +2410,8 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
getDimension(), std::get<1>(pair));
// Propagate errors.
- DiagnosedSilenceableFailure diagSplit = checkFailureInSplitting(
- !first.back() && !second.back(), target->getLoc());
+ DiagnosedSilenceableFailure diagSplit =
+ checkFailureInSplitting(!first.back() && !second.back(), loc);
if (diagSplit.isDefiniteFailure())
return diag;
@@ -2718,8 +2720,8 @@ transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
auto getI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
return llvm::map_to_vector(values, [&](int64_t value) -> Attribute {
- return builder.getI64IntegerAttr(value);
- });
+ return builder.getI64IntegerAttr(value);
+ });
};
transformResults.setParams(cast<OpResult>(getTileSizes()),
getI64AttrsFromI64(spec->tileSizes));
@@ -2756,9 +2758,9 @@ transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
}
auto getDefiningOps = [&](ArrayRef<Value> values) {
- return llvm::map_to_vector(values, [&](Value value) -> Operation * {
- return value.getDefiningOp();
- });
+ return llvm::map_to_vector(values, [&](Value value) -> Operation * {
+ return value.getDefiningOp();
+ });
};
transformResults.set(cast<OpResult>(getTileSizes()),
|
@llvm/pr-subscribers-mlir-linalg Author: Matthias Springer (matthias-springer) ChangesFull diff: https://github.com/llvm/llvm-project/pull/96390.diff 1 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 37467db568c27..4eb334f8bbbfa 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2314,7 +2314,7 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
}
} else {
chunkSizes.resize(payload.size(),
- rewriter.getIndexAttr(getStaticChunkSizes()));
+ rewriter.getIndexAttr(getStaticChunkSizes()));
}
auto checkStructuredOpAndDimensions =
@@ -2327,7 +2327,7 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
if (getDimension() >= linalgOp.getNumLoops()) {
auto diag = emitSilenceableError() << "dimension " << getDimension()
- << " does not exist in target op";
+ << " does not exist in target op";
diag.attachNote(loc) << "target op";
return diag;
}
@@ -2368,6 +2368,7 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
break;
linalgOp = cast<LinalgOp>(target);
+ Location loc = target->getLoc();
rewriter.setInsertionPoint(linalgOp);
std::tie(head, tail) = linalg::splitOp(
@@ -2376,7 +2377,7 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
// Propagate errors.
DiagnosedSilenceableFailure diag =
- checkFailureInSplitting(!head && !tail, target->getLoc());
+ checkFailureInSplitting(!head && !tail, loc);
if (diag.isDefiniteFailure())
return diag;
@@ -2395,6 +2396,7 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
Operation *noSecondPart = nullptr;
for (const auto &pair : llvm::zip(payload, chunkSizes)) {
Operation *target = std::get<0>(pair);
+ Location loc = target->getLoc();
LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
DiagnosedSilenceableFailure diag =
checkStructuredOpAndDimensions(linalgOp, target->getLoc());
@@ -2408,8 +2410,8 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
getDimension(), std::get<1>(pair));
// Propagate errors.
- DiagnosedSilenceableFailure diagSplit = checkFailureInSplitting(
- !first.back() && !second.back(), target->getLoc());
+ DiagnosedSilenceableFailure diagSplit =
+ checkFailureInSplitting(!first.back() && !second.back(), loc);
if (diagSplit.isDefiniteFailure())
return diag;
@@ -2718,8 +2720,8 @@ transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
auto getI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
return llvm::map_to_vector(values, [&](int64_t value) -> Attribute {
- return builder.getI64IntegerAttr(value);
- });
+ return builder.getI64IntegerAttr(value);
+ });
};
transformResults.setParams(cast<OpResult>(getTileSizes()),
getI64AttrsFromI64(spec->tileSizes));
@@ -2756,9 +2758,9 @@ transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
}
auto getDefiningOps = [&](ArrayRef<Value> values) {
- return llvm::map_to_vector(values, [&](Value value) -> Operation * {
- return value.getDefiningOp();
- });
+ return llvm::map_to_vector(values, [&](Value value) -> Operation * {
+ return value.getDefiningOp();
+ });
};
transformResults.set(cast<OpResult>(getTileSizes()),
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
3fa484c
to
de6b2cb
Compare
Detected with ASAN.
Operation::getLoc()
was called after erasing the operation.Reverts 48cf6b6, which attempted to fix the use-after-free. (But the use-after-free is still there when the
hasFailed
branch is taken.)