Skip to content

[mlir][linalg][Transform] Fix use-after-free in SplitOp::apply #96390

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

matthias-springer
Copy link
Member

@matthias-springer matthias-springer commented Jun 22, 2024

Detected with ASAN. Operation::getLoc() was called after erasing the operation.

Reverts 48cf6b6, which attempted to fix the use-after-free. (But the use-after-free is still there when the hasFailed branch is taken.)

@llvmbot
Copy link
Member

llvmbot commented Jun 22, 2024

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/96390.diff

1 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+12-10)
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 37467db568c27..4eb334f8bbbfa 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2314,7 +2314,7 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
     }
   } else {
     chunkSizes.resize(payload.size(),
-                       rewriter.getIndexAttr(getStaticChunkSizes()));
+                      rewriter.getIndexAttr(getStaticChunkSizes()));
   }
 
   auto checkStructuredOpAndDimensions =
@@ -2327,7 +2327,7 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
 
     if (getDimension() >= linalgOp.getNumLoops()) {
       auto diag = emitSilenceableError() << "dimension " << getDimension()
-                                          << " does not exist in target op";
+                                         << " does not exist in target op";
       diag.attachNote(loc) << "target op";
       return diag;
     }
@@ -2368,6 +2368,7 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
         break;
 
       linalgOp = cast<LinalgOp>(target);
+      Location loc = target->getLoc();
 
       rewriter.setInsertionPoint(linalgOp);
       std::tie(head, tail) = linalg::splitOp(
@@ -2376,7 +2377,7 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
 
       // Propagate errors.
       DiagnosedSilenceableFailure diag =
-          checkFailureInSplitting(!head && !tail, target->getLoc());
+          checkFailureInSplitting(!head && !tail, loc);
       if (diag.isDefiniteFailure())
         return diag;
 
@@ -2395,6 +2396,7 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
     Operation *noSecondPart = nullptr;
     for (const auto &pair : llvm::zip(payload, chunkSizes)) {
       Operation *target = std::get<0>(pair);
+      Location loc = target->getLoc();
       LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
       DiagnosedSilenceableFailure diag =
           checkStructuredOpAndDimensions(linalgOp, target->getLoc());
@@ -2408,8 +2410,8 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
           getDimension(), std::get<1>(pair));
 
       // Propagate errors.
-      DiagnosedSilenceableFailure diagSplit = checkFailureInSplitting(
-          !first.back() && !second.back(), target->getLoc());
+      DiagnosedSilenceableFailure diagSplit =
+          checkFailureInSplitting(!first.back() && !second.back(), loc);
       if (diagSplit.isDefiniteFailure())
         return diag;
 
@@ -2718,8 +2720,8 @@ transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
 
     auto getI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
       return llvm::map_to_vector(values, [&](int64_t value) -> Attribute {
-            return builder.getI64IntegerAttr(value);
-          });
+        return builder.getI64IntegerAttr(value);
+      });
     };
     transformResults.setParams(cast<OpResult>(getTileSizes()),
                                getI64AttrsFromI64(spec->tileSizes));
@@ -2756,9 +2758,9 @@ transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
   }
 
   auto getDefiningOps = [&](ArrayRef<Value> values) {
-        return llvm::map_to_vector(values, [&](Value value) -> Operation * {
-          return value.getDefiningOp();
-        });
+    return llvm::map_to_vector(values, [&](Value value) -> Operation * {
+      return value.getDefiningOp();
+    });
   };
 
   transformResults.set(cast<OpResult>(getTileSizes()),

@llvmbot
Copy link
Member

llvmbot commented Jun 22, 2024

@llvm/pr-subscribers-mlir-linalg

Author: Matthias Springer (matthias-springer)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/96390.diff

1 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+12-10)
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 37467db568c27..4eb334f8bbbfa 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2314,7 +2314,7 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
     }
   } else {
     chunkSizes.resize(payload.size(),
-                       rewriter.getIndexAttr(getStaticChunkSizes()));
+                      rewriter.getIndexAttr(getStaticChunkSizes()));
   }
 
   auto checkStructuredOpAndDimensions =
@@ -2327,7 +2327,7 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
 
     if (getDimension() >= linalgOp.getNumLoops()) {
       auto diag = emitSilenceableError() << "dimension " << getDimension()
-                                          << " does not exist in target op";
+                                         << " does not exist in target op";
       diag.attachNote(loc) << "target op";
       return diag;
     }
@@ -2368,6 +2368,7 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
         break;
 
       linalgOp = cast<LinalgOp>(target);
+      Location loc = target->getLoc();
 
       rewriter.setInsertionPoint(linalgOp);
       std::tie(head, tail) = linalg::splitOp(
@@ -2376,7 +2377,7 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
 
       // Propagate errors.
       DiagnosedSilenceableFailure diag =
-          checkFailureInSplitting(!head && !tail, target->getLoc());
+          checkFailureInSplitting(!head && !tail, loc);
       if (diag.isDefiniteFailure())
         return diag;
 
@@ -2395,6 +2396,7 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
     Operation *noSecondPart = nullptr;
     for (const auto &pair : llvm::zip(payload, chunkSizes)) {
       Operation *target = std::get<0>(pair);
+      Location loc = target->getLoc();
       LinalgOp linalgOp = dyn_cast<LinalgOp>(target);
       DiagnosedSilenceableFailure diag =
           checkStructuredOpAndDimensions(linalgOp, target->getLoc());
@@ -2408,8 +2410,8 @@ SplitOp::apply(transform::TransformRewriter &rewriter,
           getDimension(), std::get<1>(pair));
 
       // Propagate errors.
-      DiagnosedSilenceableFailure diagSplit = checkFailureInSplitting(
-          !first.back() && !second.back(), target->getLoc());
+      DiagnosedSilenceableFailure diagSplit =
+          checkFailureInSplitting(!first.back() && !second.back(), loc);
       if (diagSplit.isDefiniteFailure())
         return diag;
 
@@ -2718,8 +2720,8 @@ transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
 
     auto getI64AttrsFromI64 = [&](ArrayRef<int64_t> values) {
       return llvm::map_to_vector(values, [&](int64_t value) -> Attribute {
-            return builder.getI64IntegerAttr(value);
-          });
+        return builder.getI64IntegerAttr(value);
+      });
     };
     transformResults.setParams(cast<OpResult>(getTileSizes()),
                                getI64AttrsFromI64(spec->tileSizes));
@@ -2756,9 +2758,9 @@ transform::ContinuousTileSizesOp::apply(transform::TransformRewriter &rewriter,
   }
 
   auto getDefiningOps = [&](ArrayRef<Value> values) {
-        return llvm::map_to_vector(values, [&](Value value) -> Operation * {
-          return value.getDefiningOp();
-        });
+    return llvm::map_to_vector(values, [&](Value value) -> Operation * {
+      return value.getDefiningOp();
+    });
   };
 
   transformResults.set(cast<OpResult>(getTileSizes()),

Copy link
Member

@ftynse ftynse left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@matthias-springer matthias-springer force-pushed the users/matthias-springer/split_op_use_after_free branch from 3fa484c to de6b2cb Compare June 24, 2024 19:26
@matthias-springer matthias-springer merged commit f2d3d82 into main Jun 24, 2024
5 of 6 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/split_op_use_after_free branch June 24, 2024 19:35
AlexisPerry pushed a commit to llvm-project-tlp/llvm-project that referenced this pull request Jul 9, 2024
…m#96390)

Detected with ASAN. `Operation::getLoc()` was called after erasing the
operation.

Reverts 48cf6b6, which attempted to fix
the use-after-free. (But the use-after-free is still there when the
`hasFailed` branch is taken.)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants