Skip to content

[MLIR][OpenMP] Prevent loop wrapper translation crashes #115475

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 24, 2025

Conversation

skatrak
Copy link
Member

@skatrak skatrak commented Nov 8, 2024

This patch updates the convertOmpOpRegions translation function to prevent calling it for a loop wrapper region from causing a compiler crash due to a lack of terminator operations.

This problem is currently not triggered because there are no cases for which the region of a loop wrapper is passed to that function. This might have to change in order to support composite construct translation to LLVM IR.

This patch updates the `convertOmpOpRegions` translation function to prevent
calling it for a loop wrapper region from causing a compiler crash due to a
lack of terminator operations.

This problem is currently not triggered because there are no cases for which
the region of a loop wrapper is passed to that function. This might have to
change in order to support composite construct translation to LLVM IR.
@llvmbot
Copy link
Member

llvmbot commented Nov 8, 2024

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-flang-openmp

Author: Sergio Afonso (skatrak)

Changes

This patch updates the convertOmpOpRegions translation function to prevent calling it for a loop wrapper region from causing a compiler crash due to a lack of terminator operations.

This problem is currently not triggered because there are no cases for which the region of a loop wrapper is passed to that function. This might have to change in order to support composite construct translation to LLVM IR.


Full diff: https://github.com/llvm/llvm-project/pull/115475.diff

1 Files Affected:

  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+33-20)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index da11ee9960e1f9..b507fa656d601f 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -391,6 +391,8 @@ static llvm::Expected<llvm::BasicBlock *> convertOmpOpRegions(
     Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
     LLVM::ModuleTranslation &moduleTranslation,
     SmallVectorImpl<llvm::PHINode *> *continuationBlockPHIs = nullptr) {
+  bool isLoopWrapper = isa<omp::LoopWrapperInterface>(region.getParentOp());
+
   llvm::BasicBlock *continuationBlock =
       splitBB(builder, true, "omp.region.cont");
   llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
@@ -407,30 +409,34 @@ static llvm::Expected<llvm::BasicBlock *> convertOmpOpRegions(
 
   // Terminators (namely YieldOp) may be forwarding values to the region that
   // need to be available in the continuation block. Collect the types of these
-  // operands in preparation of creating PHI nodes.
+  // operands in preparation of creating PHI nodes. This is skipped for loop
+  // wrapper operations, for which we know in advance they have no terminators.
   SmallVector<llvm::Type *> continuationBlockPHITypes;
-  bool operandsProcessed = false;
   unsigned numYields = 0;
-  for (Block &bb : region.getBlocks()) {
-    if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
-      if (!operandsProcessed) {
-        for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
-          continuationBlockPHITypes.push_back(
-              moduleTranslation.convertType(yield->getOperand(i).getType()));
-        }
-        operandsProcessed = true;
-      } else {
-        assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
-               "mismatching number of values yielded from the region");
-        for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
-          llvm::Type *operandType =
-              moduleTranslation.convertType(yield->getOperand(i).getType());
-          (void)operandType;
-          assert(continuationBlockPHITypes[i] == operandType &&
-                 "values of mismatching types yielded from the region");
+
+  if (!isLoopWrapper) {
+    bool operandsProcessed = false;
+    for (Block &bb : region.getBlocks()) {
+      if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
+        if (!operandsProcessed) {
+          for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
+            continuationBlockPHITypes.push_back(
+                moduleTranslation.convertType(yield->getOperand(i).getType()));
+          }
+          operandsProcessed = true;
+        } else {
+          assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
+                 "mismatching number of values yielded from the region");
+          for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
+            llvm::Type *operandType =
+                moduleTranslation.convertType(yield->getOperand(i).getType());
+            (void)operandType;
+            assert(continuationBlockPHITypes[i] == operandType &&
+                   "values of mismatching types yielded from the region");
+          }
         }
+        numYields++;
       }
-      numYields++;
     }
   }
 
@@ -468,6 +474,13 @@ static llvm::Expected<llvm::BasicBlock *> convertOmpOpRegions(
             moduleTranslation.convertBlock(*bb, bb->isEntryBlock(), builder)))
       return llvm::make_error<PreviouslyReportedError>();
 
+    // Create a direct branch here for loop wrappers to prevent their lack of a
+    // terminator from causing a crash below.
+    if (isLoopWrapper) {
+      builder.CreateBr(continuationBlock);
+      continue;
+    }
+
     // Special handling for `omp.yield` and `omp.terminator` (we may have more
     // than one): they return the control to the parent OpenMP dialect operation
     // so replace them with the branch to the continuation block. We handle this

@llvmbot
Copy link
Member

llvmbot commented Nov 8, 2024

@llvm/pr-subscribers-mlir-openmp

Author: Sergio Afonso (skatrak)

Changes

This patch updates the convertOmpOpRegions translation function to prevent calling it for a loop wrapper region from causing a compiler crash due to a lack of terminator operations.

This problem is currently not triggered because there are no cases for which the region of a loop wrapper is passed to that function. This might have to change in order to support composite construct translation to LLVM IR.


Full diff: https://github.com/llvm/llvm-project/pull/115475.diff

1 Files Affected:

  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+33-20)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index da11ee9960e1f9..b507fa656d601f 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -391,6 +391,8 @@ static llvm::Expected<llvm::BasicBlock *> convertOmpOpRegions(
     Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
     LLVM::ModuleTranslation &moduleTranslation,
     SmallVectorImpl<llvm::PHINode *> *continuationBlockPHIs = nullptr) {
+  bool isLoopWrapper = isa<omp::LoopWrapperInterface>(region.getParentOp());
+
   llvm::BasicBlock *continuationBlock =
       splitBB(builder, true, "omp.region.cont");
   llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
@@ -407,30 +409,34 @@ static llvm::Expected<llvm::BasicBlock *> convertOmpOpRegions(
 
   // Terminators (namely YieldOp) may be forwarding values to the region that
   // need to be available in the continuation block. Collect the types of these
-  // operands in preparation of creating PHI nodes.
+  // operands in preparation of creating PHI nodes. This is skipped for loop
+  // wrapper operations, for which we know in advance they have no terminators.
   SmallVector<llvm::Type *> continuationBlockPHITypes;
-  bool operandsProcessed = false;
   unsigned numYields = 0;
-  for (Block &bb : region.getBlocks()) {
-    if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
-      if (!operandsProcessed) {
-        for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
-          continuationBlockPHITypes.push_back(
-              moduleTranslation.convertType(yield->getOperand(i).getType()));
-        }
-        operandsProcessed = true;
-      } else {
-        assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
-               "mismatching number of values yielded from the region");
-        for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
-          llvm::Type *operandType =
-              moduleTranslation.convertType(yield->getOperand(i).getType());
-          (void)operandType;
-          assert(continuationBlockPHITypes[i] == operandType &&
-                 "values of mismatching types yielded from the region");
+
+  if (!isLoopWrapper) {
+    bool operandsProcessed = false;
+    for (Block &bb : region.getBlocks()) {
+      if (omp::YieldOp yield = dyn_cast<omp::YieldOp>(bb.getTerminator())) {
+        if (!operandsProcessed) {
+          for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
+            continuationBlockPHITypes.push_back(
+                moduleTranslation.convertType(yield->getOperand(i).getType()));
+          }
+          operandsProcessed = true;
+        } else {
+          assert(continuationBlockPHITypes.size() == yield->getNumOperands() &&
+                 "mismatching number of values yielded from the region");
+          for (unsigned i = 0, e = yield->getNumOperands(); i < e; ++i) {
+            llvm::Type *operandType =
+                moduleTranslation.convertType(yield->getOperand(i).getType());
+            (void)operandType;
+            assert(continuationBlockPHITypes[i] == operandType &&
+                   "values of mismatching types yielded from the region");
+          }
         }
+        numYields++;
       }
-      numYields++;
     }
   }
 
@@ -468,6 +474,13 @@ static llvm::Expected<llvm::BasicBlock *> convertOmpOpRegions(
             moduleTranslation.convertBlock(*bb, bb->isEntryBlock(), builder)))
       return llvm::make_error<PreviouslyReportedError>();
 
+    // Create a direct branch here for loop wrappers to prevent their lack of a
+    // terminator from causing a crash below.
+    if (isLoopWrapper) {
+      builder.CreateBr(continuationBlock);
+      continue;
+    }
+
     // Special handling for `omp.yield` and `omp.terminator` (we may have more
     // than one): they return the control to the parent OpenMP dialect operation
     // so replace them with the branch to the continuation block. We handle this

Copy link
Contributor

@kiranchandramohan kiranchandramohan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reviewable status: 0 of 1 files reviewed, 1 unresolved discussion (waiting on @agozillon, @bhandarkar-pranav, @DominikAdamski, @ergawy, @jsjodin, @kparzysz, @Meinersbur, @mjklemm, @raghavendhra, @skatrak, @tblah, and @TIFitis)


mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp line 482 at r1 (raw file):

      builder.CreateBr(continuationBlock);
      continue;
    }

Can we write a test for this?

@skatrak
Copy link
Member Author

skatrak commented Nov 8, 2024

Thanks @kiranchandramohan for the quick review!

Can we write a test for this?

At the moment we can't because this code path is not triggered by anything. Translation functions for loop wrappers currently call this function passing the region of the nested omp.loop_nest, so a loop wrapper region is never passed here. However, if you check the code directly below you can tell that this is doing the same thing as if an omp.terminator was present.

I think it's preferable to have this implemented before needing it rather than having it crash later on, since I think it was an oversight from when I implemented #112229. But let me know if you don't agree and I can keep it around until some composite construct lowering ends up triggering the issue.

Copy link
Contributor

@tblah tblah left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This LGTM.

If there are currently no paths using this lowering, how did you find this bug? If you expect to upstream something in the near future that makes this testable (and you will make sure it is covered by tests for that thing) then I think it is okay.

@kiranchandramohan
Copy link
Contributor

Thanks @kiranchandramohan for the quick review!

Can we write a test for this?

At the moment we can't because this code path is not triggered by anything. Translation functions for loop wrappers currently call this function passing the region of the nested omp.loop_nest, so a loop wrapper region is never passed here. However, if you check the code directly below you can tell that this is doing the same thing as if an omp.terminator was present.

I think it's preferable to have this implemented before needing it rather than having it crash later on, since I think it was an oversight from when I implemented #112229. But let me know if you don't agree and I can keep it around until some composite construct lowering ends up triggering the issue.

The danger with this is that someone can remove this piece of code in refactoring and think it was accidentally introduced since no test fails. Ideally, such code can be introduced in the first patch that needs it. You can organize such patches as two commits, where the first commit is this patch and the second commit is the one that needs it

Note: I did not intend to request changes. I was trying reviewable.io and did not realise that it requested changes.

@skatrak
Copy link
Member Author

skatrak commented Nov 11, 2024

If there are currently no paths using this lowering, how did you find this bug?

The danger with this is that someone can remove this piece of code in refactoring and think it was accidentally introduced since no test fails. Ideally, such code can be introduced in the first patch that needs it. You can organize such patches as two commits, where the first commit is this patch and the second commit is the one that needs it

Thank you both for your reviews. I found this bug through our current downstream support for 'distribute parallel do'. However, it will take a bit of time to get to upstreaming it, since other pieces must make it first, and there's a chance that we might want to rethink how we do it before then (potentially circumventing this code path). So I'll keep this PR around until we're ready to upstream the code that would trigger this failure, and I'll close it if we end up not needing it at that point.

@skatrak skatrak merged commit 9b52d9e into llvm:main Feb 24, 2025
13 checks passed
@skatrak skatrak deleted the wrapper-terminator-fix branch February 24, 2025 11:32
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants