[mlir][acc] Fix async only api on data entry operations #122818

Merged: razvanlupusoru merged 1 commit into llvm:main on Jan 14, 2025

razvanlupusoru
Contributor

Data entry operations created from constructs whose `async` clause has no value (e.g. `acc data copyin(var) async`) hold an attribute array, read through `getAsyncOnlyAttr()`, that records this information. However, when no `async` clause is used, this attribute is not set, so calling `hasAsyncOnly` crashes.

Thus, to fix this issue, ensure that we check for this attribute before trying to walk the attribute array.
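
As a rough illustration of the guard described above, here is a minimal standalone sketch that does not use MLIR. `DataEntryOpModel`, its `asyncOnly` field, and `demoAsyncOnly` are hypothetical names invented for this example, and `std::optional` stands in for an attribute that may be absent. The no-argument overload querying `DeviceType::None` is an assumption inferred from the unit tests in this PR, not something the description states.

```cpp
#include <cassert>
#include <optional>
#include <vector>

// Hypothetical stand-in for mlir::acc::DeviceType; only the values that
// appear in this PR's tests are listed.
enum class DeviceType { None, Star, Host, Nvidia };

// Hypothetical model of a data entry op. The optional models an attribute
// that may be absent (in MLIR, an absent attribute is a null ArrayAttr).
struct DataEntryOpModel {
  std::optional<std::vector<DeviceType>> asyncOnly;

  // Mirrors the fixed logic: return early when the attribute is absent,
  // and only then walk its elements.
  bool hasAsyncOnly(DeviceType deviceType) const {
    if (!asyncOnly)
      return false;
    for (DeviceType d : *asyncOnly)
      if (d == deviceType)
        return true;
    return false;
  }

  // Assumption inferred from the tests: no argument means DeviceType::None.
  bool hasAsyncOnly() const { return hasAsyncOnly(DeviceType::None); }
};

// Replays part of the new unit test against the model.
void demoAsyncOnly() {
  DataEntryOpModel op; // no async clause: the attribute is absent
  assert(!op.hasAsyncOnly());
  assert(!op.hasAsyncOnly(DeviceType::Host));

  op.asyncOnly = std::vector<DeviceType>{DeviceType::Host};
  assert(op.hasAsyncOnly(DeviceType::Host));
  assert(!op.hasAsyncOnly()); // Host-only does not imply None
}
```

Without the early return, the loop in this model would dereference an empty optional, which is the analogue of walking a null attribute array in the original code.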

@llvmbot
Member

llvmbot commented Jan 13, 2025

@llvm/pr-subscribers-openacc
@llvm/pr-subscribers-mlir-openacc

@llvm/pr-subscribers-mlir

Author: Razvan Lupusoru (razvanlupusoru)

Changes

Data entry operations created from constructs whose `async` clause has no value (e.g. `acc data copyin(var) async`) hold an attribute array, read through `getAsyncOnlyAttr()`, that records this information. However, when no `async` clause is used, this attribute is not set, so calling `hasAsyncOnly` crashes.

Thus, to fix this issue, ensure that we check for this attribute before trying to walk the attribute array.


Full diff: https://github.com/llvm/llvm-project/pull/122818.diff

2 Files Affected:

  • (modified) mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td (+8-2)
  • (modified) mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp (+88)
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index a47f70b168066e..c60eb5cc620a7d 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -445,7 +445,10 @@ class OpenACC_DataEntryOp<string mnemonic, string clause, string extraDescriptio
     }
     /// Return true if the op has the async attribute for the given device_type.
     bool hasAsyncOnly(mlir::acc::DeviceType deviceType) {
-      for (auto attr : getAsyncOnlyAttr()) {
+      mlir::ArrayAttr asyncOnly = getAsyncOnlyAttr();
+      if (!asyncOnly)
+        return false;
+      for (auto attr : asyncOnly) {
         auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
         if (deviceTypeAttr.getValue() == deviceType)
           return true;
@@ -817,7 +820,10 @@ class OpenACC_DataExitOp<string mnemonic, string clause, string extraDescription
     }
     /// Return true if the op has the async attribute for the given device_type.
     bool hasAsyncOnly(mlir::acc::DeviceType deviceType) {
-      for (auto attr : getAsyncOnlyAttr()) {
+      mlir::ArrayAttr asyncOnly = getAsyncOnlyAttr();
+      if (!asyncOnly)
+        return false;
+      for (auto attr : asyncOnly) {
         auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
         if (deviceTypeAttr.getValue() == deviceType)
           return true;
diff --git a/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp b/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp
index cfb8aa767b6f86..aa16421cbec512 100644
--- a/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp
+++ b/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp
@@ -77,6 +77,54 @@ TEST_F(OpenACCOpsTest, asyncOnlyTest) {
   testAsyncOnly<SerialOp>(b, context, loc, dtypes);
 }
 
+template <typename Op>
+void testAsyncOnlyDataEntry(OpBuilder &b, MLIRContext &context, Location loc,
+                            llvm::SmallVector<DeviceType> &dtypes) {
+  auto memrefTy = MemRefType::get({}, b.getI32Type());
+  OwningOpRef<memref::AllocaOp> varPtrOp =
+      b.create<memref::AllocaOp>(loc, memrefTy);
+
+  TypedValue<PointerLikeType> varPtr =
+      cast<TypedValue<PointerLikeType>>(varPtrOp->getResult());
+  OwningOpRef<Op> op = b.create<Op>(loc, varPtr,
+                                    /*structured=*/true, /*implicit=*/true);
+
+  EXPECT_FALSE(op->hasAsyncOnly());
+  for (auto d : dtypes)
+    EXPECT_FALSE(op->hasAsyncOnly(d));
+
+  auto dtypeNone = DeviceTypeAttr::get(&context, DeviceType::None);
+  op->setAsyncOnlyAttr(b.getArrayAttr({dtypeNone}));
+  EXPECT_TRUE(op->hasAsyncOnly());
+  EXPECT_TRUE(op->hasAsyncOnly(DeviceType::None));
+  op->removeAsyncOnlyAttr();
+
+  auto dtypeHost = DeviceTypeAttr::get(&context, DeviceType::Host);
+  op->setAsyncOnlyAttr(b.getArrayAttr({dtypeHost}));
+  EXPECT_TRUE(op->hasAsyncOnly(DeviceType::Host));
+  EXPECT_FALSE(op->hasAsyncOnly());
+  op->removeAsyncOnlyAttr();
+
+  auto dtypeStar = DeviceTypeAttr::get(&context, DeviceType::Star);
+  op->setAsyncOnlyAttr(b.getArrayAttr({dtypeHost, dtypeStar}));
+  EXPECT_TRUE(op->hasAsyncOnly(DeviceType::Star));
+  EXPECT_TRUE(op->hasAsyncOnly(DeviceType::Host));
+  EXPECT_FALSE(op->hasAsyncOnly());
+
+  op->removeAsyncOnlyAttr();
+}
+
+TEST_F(OpenACCOpsTest, asyncOnlyTestDataEntry) {
+  testAsyncOnlyDataEntry<DevicePtrOp>(b, context, loc, dtypes);
+  testAsyncOnlyDataEntry<PresentOp>(b, context, loc, dtypes);
+  testAsyncOnlyDataEntry<CopyinOp>(b, context, loc, dtypes);
+  testAsyncOnlyDataEntry<CreateOp>(b, context, loc, dtypes);
+  testAsyncOnlyDataEntry<NoCreateOp>(b, context, loc, dtypes);
+  testAsyncOnlyDataEntry<AttachOp>(b, context, loc, dtypes);
+  testAsyncOnlyDataEntry<UpdateDeviceOp>(b, context, loc, dtypes);
+  testAsyncOnlyDataEntry<UseDeviceOp>(b, context, loc, dtypes);
+}
+
 template <typename Op>
 void testAsyncValue(OpBuilder &b, MLIRContext &context, Location loc,
                     llvm::SmallVector<DeviceType> &dtypes) {
@@ -105,6 +153,46 @@ TEST_F(OpenACCOpsTest, asyncValueTest) {
   testAsyncValue<SerialOp>(b, context, loc, dtypes);
 }
 
+template <typename Op>
+void testAsyncValueDataEntry(OpBuilder &b, MLIRContext &context, Location loc,
+                             llvm::SmallVector<DeviceType> &dtypes) {
+  auto memrefTy = MemRefType::get({}, b.getI32Type());
+  OwningOpRef<memref::AllocaOp> varPtrOp =
+      b.create<memref::AllocaOp>(loc, memrefTy);
+
+  TypedValue<PointerLikeType> varPtr =
+      cast<TypedValue<PointerLikeType>>(varPtrOp->getResult());
+  OwningOpRef<Op> op = b.create<Op>(loc, varPtr,
+                                    /*structured=*/true, /*implicit=*/true);
+
+  mlir::Value empty;
+  EXPECT_EQ(op->getAsyncValue(), empty);
+  for (auto d : dtypes)
+    EXPECT_EQ(op->getAsyncValue(d), empty);
+
+  OwningOpRef<arith::ConstantIndexOp> val =
+      b.create<arith::ConstantIndexOp>(loc, 1);
+  auto dtypeNvidia = DeviceTypeAttr::get(&context, DeviceType::Nvidia);
+  op->setAsyncOperandsDeviceTypeAttr(b.getArrayAttr({dtypeNvidia}));
+  op->getAsyncOperandsMutable().assign(val->getResult());
+  EXPECT_EQ(op->getAsyncValue(), empty);
+  EXPECT_EQ(op->getAsyncValue(DeviceType::Nvidia), val->getResult());
+
+  op->getAsyncOperandsMutable().clear();
+  op->removeAsyncOperandsDeviceTypeAttr();
+}
+
+TEST_F(OpenACCOpsTest, asyncValueTestDataEntry) {
+  testAsyncValueDataEntry<DevicePtrOp>(b, context, loc, dtypes);
+  testAsyncValueDataEntry<PresentOp>(b, context, loc, dtypes);
+  testAsyncValueDataEntry<CopyinOp>(b, context, loc, dtypes);
+  testAsyncValueDataEntry<CreateOp>(b, context, loc, dtypes);
+  testAsyncValueDataEntry<NoCreateOp>(b, context, loc, dtypes);
+  testAsyncValueDataEntry<AttachOp>(b, context, loc, dtypes);
+  testAsyncValueDataEntry<UpdateDeviceOp>(b, context, loc, dtypes);
+  testAsyncValueDataEntry<UseDeviceOp>(b, context, loc, dtypes);
+}
+
 template <typename Op>
 void testNumGangsValues(OpBuilder &b, MLIRContext &context, Location loc,
                         llvm::SmallVector<DeviceType> &dtypes,
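
The `testAsyncValueDataEntry` expectations in the diff above can be read as a per-device-type lookup. The following standalone sketch models that reading without MLIR. `AsyncOperandsModel`, its fields, and `demoAsyncValue` are hypothetical names; an `int` stands in for an `mlir::Value`; and the pairing of operands with device types by position is an assumption, since the accessor's implementation is not part of this diff.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <vector>

// Hypothetical stand-in for mlir::acc::DeviceType.
enum class DeviceType { None, Star, Host, Nvidia };

// Hypothetical model: async operands and their device types are kept in two
// arrays and paired by position (an assumption made for this sketch).
struct AsyncOperandsModel {
  std::vector<int> asyncOperands; // int stands in for mlir::Value
  std::optional<std::vector<DeviceType>> asyncOperandsDeviceType;

  // Returns the operand paired with deviceType, or an empty optional
  // (the analogue of a null mlir::Value) when there is none.
  std::optional<int> getAsyncValue(DeviceType deviceType) const {
    if (!asyncOperandsDeviceType)
      return std::nullopt;
    const std::vector<DeviceType> &types = *asyncOperandsDeviceType;
    for (std::size_t i = 0; i < types.size() && i < asyncOperands.size(); ++i)
      if (types[i] == deviceType)
        return asyncOperands[i];
    return std::nullopt;
  }

  // Assumption inferred from the tests: no argument means DeviceType::None.
  std::optional<int> getAsyncValue() const {
    return getAsyncValue(DeviceType::None);
  }
};

// Replays the shape of the unit test against the model.
void demoAsyncValue() {
  AsyncOperandsModel op;
  assert(!op.getAsyncValue().has_value());

  op.asyncOperandsDeviceType = std::vector<DeviceType>{DeviceType::Nvidia};
  op.asyncOperands = {1};
  assert(!op.getAsyncValue().has_value());
  assert(op.getAsyncValue(DeviceType::Nvidia) == std::optional<int>(1));
}
```

As in the test, a value registered only for `Nvidia` is not visible to the no-argument query.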

Contributor

@clementval clementval left a comment

LGTM! Thanks for the fix.

@nvptm
Contributor

nvptm commented Jan 13, 2025

The hasAsyncOnly logic makes sense. Thanks for the fix.

@razvanlupusoru razvanlupusoru merged commit f4aec22 into llvm:main Jan 14, 2025
10 of 12 checks passed
4 participants