Skip to content

Commit 3fa484c

Browse files
[mlir][linalg][Transform] Fix use-after-free in SplitOp::apply
1 parent 34d44eb commit 3fa484c

File tree

1 file changed

+12
-10
lines changed

1 file changed

+12
-10
lines changed

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2314,7 +2314,7 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
23142314
}
23152315
} else {
23162316
chunkSizes.resize(payload.size(),
2317-
rewriter.getIndexAttr(getStaticChunkSizes()));
2317+
rewriter.getIndexAttr(getStaticChunkSizes()));
23182318
}
23192319

23202320
auto checkStructuredOpAndDimensions =
@@ -2327,7 +2327,7 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
23272327

23282328
if (getDimension() >= linalgOp.getNumLoops()) {
23292329
auto diag = emitSilenceableError() << "dimension " << getDimension()
2330-
<< " does not exist in target op";
2330+
<< " does not exist in target op";
23312331
diag.attachNote(loc) << "target op";
23322332
return diag;
23332333
}
@@ -2368,6 +2368,7 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
23682368
break;
23692369

23702370
linalgOp = cast<LinalgOp>(target);
2371+
Location loc = target->getLoc();
23712372

23722373
rewriter.setInsertionPoint(linalgOp);
23732374
std::tie(head, tail) = linalg::splitOp(
@@ -2376,7 +2377,7 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
23762377

23772378
// Propagate errors.
23782379
DiagnosedSilenceableFailure diag =
2379-
checkFailureInSplitting(!head && !tail, target->getLoc());
2380+
checkFailureInSplitting(!head && !tail, loc);
23802381
if (diag.isDefiniteFailure())
23812382
return diag;
23822383

@@ -2395,6 +2396,7 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
23952396
Operation *noSecondPart = nullptr;
23962397
for (const auto &pair : llvm::zip(payload, chunkSizes)) {
23972398
Operation *target = std::get<0>(pair);
2399+
Location loc = target->getLoc();
23982400
LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
23992401
DiagnosedSilenceableFailure diag =
24002402
checkStructuredOpAndDimensions(linalgOp, target->getLoc());
@@ -2408,8 +2410,8 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
24082410
getDimension(), std::get<1>(pair));
24092411

24102412
// Propagate errors.
2411-
DiagnosedSilenceableFailure diagSplit = checkFailureInSplitting(
2412-
!first.back() && !second.back(), target->getLoc());
2413+
DiagnosedSilenceableFailure diagSplit =
2414+
checkFailureInSplitting(!first.back() && !second.back(), loc);
24132415
if (diagSplit.isDefiniteFailure())
24142416
return diag;
24152417

@@ -2718,8 +2720,8 @@ transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
27182720

27192721
auto getI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
27202722
return llvm::map_to_vector(values, [&](int64_t value) -> Attribute {
2721-
return builder.getI64IntegerAttr(value);
2722-
});
2723+
return builder.getI64IntegerAttr(value);
2724+
});
27232725
};
27242726
transformResults.setParams(cast<OpResult>(getTileSizes()),
27252727
getI64AttrsFromI64(spec->tileSizes));
@@ -2756,9 +2758,9 @@ transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
27562758
}
27572759

27582760
auto getDefiningOps = [&](ArrayRef<Value> values) {
2759-
return llvm::map_to_vector(values, [&](Value value) -> Operation * {
2760-
return value.getDefiningOp();
2761-
});
2761+
return llvm::map_to_vector(values, [&](Value value) -> Operation * {
2762+
return value.getDefiningOp();
2763+
});
27622764
};
27632765

27642766
transformResults.set(cast<OpResult>(getTileSizes()),

0 commit comments

Comments
 (0)