Skip to content

[mlir][linalg][Transform] Fix use-after-free in SplitOp::apply #96390

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 13 additions & 11 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2314,7 +2314,7 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
}
} else {
chunkSizes.resize(payload.size(),
rewriter.getIndexAttr(getStaticChunkSizes()));
rewriter.getIndexAttr(getStaticChunkSizes()));
}

auto checkStructuredOpAndDimensions =
Expand All @@ -2327,18 +2327,18 @@ SplitOp::apply(transform::TransformRewriter &rewriter,

if (getDimension() >= linalgOp.getNumLoops()) {
auto diag = emitSilenceableError() << "dimension " << getDimension()
<< " does not exist in target op";
<< " does not exist in target op";
diag.attachNote(loc) << "target op";
return diag;
}
return DiagnosedSilenceableFailure::success();
};

auto checkFailureInSplitting =
[&](bool hasFailed, Operation *op) -> DiagnosedSilenceableFailure {
[&](bool hasFailed, Location loc) -> DiagnosedSilenceableFailure {
if (hasFailed) {
auto diag = emitDefiniteFailure() << "internal failure in splitting";
diag.attachNote(op->getLoc()) << "target op";
diag.attachNote(loc) << "target op";
return diag;
}
return DiagnosedSilenceableFailure::success();
Expand Down Expand Up @@ -2368,6 +2368,7 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
break;

linalgOp = cast<LinalgOp>(target);
Location loc = target->getLoc();

rewriter.setInsertionPoint(linalgOp);
std::tie(head, tail) = linalg::splitOp(
Expand All @@ -2376,7 +2377,7 @@ SplitOp::apply(transform::TransformRewriter &rewriter,

// Propagate errors.
DiagnosedSilenceableFailure diag =
checkFailureInSplitting(!head && !tail, target);
checkFailureInSplitting(!head && !tail, loc);
if (diag.isDefiniteFailure())
return diag;

Expand All @@ -2395,6 +2396,7 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
Operation *noSecondPart = nullptr;
for (const auto &pair : llvm::zip(payload, chunkSizes)) {
Operation *target = std::get<0>(pair);
Location loc = target->getLoc();
LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
DiagnosedSilenceableFailure diag =
checkStructuredOpAndDimensions(linalgOp, target->getLoc());
Expand All @@ -2409,7 +2411,7 @@ SplitOp::apply(transform::TransformRewriter &rewriter,

// Propagate errors.
DiagnosedSilenceableFailure diagSplit =
checkFailureInSplitting(!first.back() && !second.back(), target);
checkFailureInSplitting(!first.back() && !second.back(), loc);
if (diagSplit.isDefiniteFailure())
return diag;

Expand Down Expand Up @@ -2718,8 +2720,8 @@ transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,

auto getI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
return llvm::map_to_vector(values, [&](int64_t value) -> Attribute {
return builder.getI64IntegerAttr(value);
});
return builder.getI64IntegerAttr(value);
});
};
transformResults.setParams(cast<OpResult>(getTileSizes()),
getI64AttrsFromI64(spec->tileSizes));
Expand Down Expand Up @@ -2756,9 +2758,9 @@ transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
}

auto getDefiningOps = [&](ArrayRef<Value> values) {
return llvm::map_to_vector(values, [&](Value value) -> Operation * {
return value.getDefiningOp();
});
return llvm::map_to_vector(values, [&](Value value) -> Operation * {
return value.getDefiningOp();
});
};

transformResults.set(cast<OpResult>(getTileSizes()),
Expand Down
Loading