-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][acc] Fix async only api on data entry operations #122818
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Data entry operations which are created from constructs with an async clause that has no value (aka `acc data copyin(var) async`) end up holding an attribute array named `asyncOnly` to keep track of this information. However, in cases where the `async` clause is not used, calling `hasAsyncOnly` ends up crashing since this attribute is not set. Thus, to fix this issue, ensure that we check for this attribute before trying to walk the attribute array.
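@llvm/pr-subscribers-openacc @llvm/pr-subscribers-mlir Author: Razvan Lupusoru (razvanlupusoru) Changes: Data entry operations which are created from constructs with an async clause that has no value (aka `acc data copyin(var) async`) end up holding an attribute array named `asyncOnly` to keep track of this information. However, in cases where the `async` clause is not used, calling `hasAsyncOnly` ends up crashing since this attribute is not set. Thus, to fix this issue, ensure that we check for this attribute before trying to walk the attribute array. Full diff: https://github.com/llvm/llvm-project/pull/122818.diff 2 Files Affected: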
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
index a47f70b168066e..c60eb5cc620a7d 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -445,7 +445,10 @@ class OpenACC_DataEntryOp<string mnemonic, string clause, string extraDescriptio
}
/// Return true if the op has the async attribute for the given device_type.
bool hasAsyncOnly(mlir::acc::DeviceType deviceType) {
- for (auto attr : getAsyncOnlyAttr()) {
+ mlir::ArrayAttr asyncOnly = getAsyncOnlyAttr();
+ if (!asyncOnly)
+ return false;
+ for (auto attr : asyncOnly) {
auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
if (deviceTypeAttr.getValue() == deviceType)
return true;
@@ -817,7 +820,10 @@ class OpenACC_DataExitOp<string mnemonic, string clause, string extraDescription
}
/// Return true if the op has the async attribute for the given device_type.
bool hasAsyncOnly(mlir::acc::DeviceType deviceType) {
- for (auto attr : getAsyncOnlyAttr()) {
+ mlir::ArrayAttr asyncOnly = getAsyncOnlyAttr();
+ if (!asyncOnly)
+ return false;
+ for (auto attr : asyncOnly) {
auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
if (deviceTypeAttr.getValue() == deviceType)
return true;
diff --git a/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp b/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp
index cfb8aa767b6f86..aa16421cbec512 100644
--- a/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp
+++ b/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp
@@ -77,6 +77,54 @@ TEST_F(OpenACCOpsTest, asyncOnlyTest) {
testAsyncOnly<SerialOp>(b, context, loc, dtypes);
}
+template <typename Op>
+void testAsyncOnlyDataEntry(OpBuilder &b, MLIRContext &context, Location loc,
+ llvm::SmallVector<DeviceType> &dtypes) {
+ auto memrefTy = MemRefType::get({}, b.getI32Type());
+ OwningOpRef<memref::AllocaOp> varPtrOp =
+ b.create<memref::AllocaOp>(loc, memrefTy);
+
+ TypedValue<PointerLikeType> varPtr =
+ cast<TypedValue<PointerLikeType>>(varPtrOp->getResult());
+ OwningOpRef<Op> op = b.create<Op>(loc, varPtr,
+ /*structured=*/true, /*implicit=*/true);
+
+ EXPECT_FALSE(op->hasAsyncOnly());
+ for (auto d : dtypes)
+ EXPECT_FALSE(op->hasAsyncOnly(d));
+
+ auto dtypeNone = DeviceTypeAttr::get(&context, DeviceType::None);
+ op->setAsyncOnlyAttr(b.getArrayAttr({dtypeNone}));
+ EXPECT_TRUE(op->hasAsyncOnly());
+ EXPECT_TRUE(op->hasAsyncOnly(DeviceType::None));
+ op->removeAsyncOnlyAttr();
+
+ auto dtypeHost = DeviceTypeAttr::get(&context, DeviceType::Host);
+ op->setAsyncOnlyAttr(b.getArrayAttr({dtypeHost}));
+ EXPECT_TRUE(op->hasAsyncOnly(DeviceType::Host));
+ EXPECT_FALSE(op->hasAsyncOnly());
+ op->removeAsyncOnlyAttr();
+
+ auto dtypeStar = DeviceTypeAttr::get(&context, DeviceType::Star);
+ op->setAsyncOnlyAttr(b.getArrayAttr({dtypeHost, dtypeStar}));
+ EXPECT_TRUE(op->hasAsyncOnly(DeviceType::Star));
+ EXPECT_TRUE(op->hasAsyncOnly(DeviceType::Host));
+ EXPECT_FALSE(op->hasAsyncOnly());
+
+ op->removeAsyncOnlyAttr();
+}
+
+TEST_F(OpenACCOpsTest, asyncOnlyTestDataEntry) {
+ testAsyncOnlyDataEntry<DevicePtrOp>(b, context, loc, dtypes);
+ testAsyncOnlyDataEntry<PresentOp>(b, context, loc, dtypes);
+ testAsyncOnlyDataEntry<CopyinOp>(b, context, loc, dtypes);
+ testAsyncOnlyDataEntry<CreateOp>(b, context, loc, dtypes);
+ testAsyncOnlyDataEntry<NoCreateOp>(b, context, loc, dtypes);
+ testAsyncOnlyDataEntry<AttachOp>(b, context, loc, dtypes);
+ testAsyncOnlyDataEntry<UpdateDeviceOp>(b, context, loc, dtypes);
+ testAsyncOnlyDataEntry<UseDeviceOp>(b, context, loc, dtypes);
+}
+
template <typename Op>
void testAsyncValue(OpBuilder &b, MLIRContext &context, Location loc,
llvm::SmallVector<DeviceType> &dtypes) {
@@ -105,6 +153,46 @@ TEST_F(OpenACCOpsTest, asyncValueTest) {
testAsyncValue<SerialOp>(b, context, loc, dtypes);
}
+template <typename Op>
+void testAsyncValueDataEntry(OpBuilder &b, MLIRContext &context, Location loc,
+ llvm::SmallVector<DeviceType> &dtypes) {
+ auto memrefTy = MemRefType::get({}, b.getI32Type());
+ OwningOpRef<memref::AllocaOp> varPtrOp =
+ b.create<memref::AllocaOp>(loc, memrefTy);
+
+ TypedValue<PointerLikeType> varPtr =
+ cast<TypedValue<PointerLikeType>>(varPtrOp->getResult());
+ OwningOpRef<Op> op = b.create<Op>(loc, varPtr,
+ /*structured=*/true, /*implicit=*/true);
+
+ mlir::Value empty;
+ EXPECT_EQ(op->getAsyncValue(), empty);
+ for (auto d : dtypes)
+ EXPECT_EQ(op->getAsyncValue(d), empty);
+
+ OwningOpRef<arith::ConstantIndexOp> val =
+ b.create<arith::ConstantIndexOp>(loc, 1);
+ auto dtypeNvidia = DeviceTypeAttr::get(&context, DeviceType::Nvidia);
+ op->setAsyncOperandsDeviceTypeAttr(b.getArrayAttr({dtypeNvidia}));
+ op->getAsyncOperandsMutable().assign(val->getResult());
+ EXPECT_EQ(op->getAsyncValue(), empty);
+ EXPECT_EQ(op->getAsyncValue(DeviceType::Nvidia), val->getResult());
+
+ op->getAsyncOperandsMutable().clear();
+ op->removeAsyncOperandsDeviceTypeAttr();
+}
+
+TEST_F(OpenACCOpsTest, asyncValueTestDataEntry) {
+ testAsyncValueDataEntry<DevicePtrOp>(b, context, loc, dtypes);
+ testAsyncValueDataEntry<PresentOp>(b, context, loc, dtypes);
+ testAsyncValueDataEntry<CopyinOp>(b, context, loc, dtypes);
+ testAsyncValueDataEntry<CreateOp>(b, context, loc, dtypes);
+ testAsyncValueDataEntry<NoCreateOp>(b, context, loc, dtypes);
+ testAsyncValueDataEntry<AttachOp>(b, context, loc, dtypes);
+ testAsyncValueDataEntry<UpdateDeviceOp>(b, context, loc, dtypes);
+ testAsyncValueDataEntry<UseDeviceOp>(b, context, loc, dtypes);
+}
+
template <typename Op>
void testNumGangsValues(OpBuilder &b, MLIRContext &context, Location loc,
llvm::SmallVector<DeviceType> &dtypes,
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! Thanks for the fix
The |
Data entry operations which are created from constructs with an async clause that has no value (aka `acc data copyin(var) async`) end up holding an attribute array named `asyncOnly` to keep track of this information. However, in cases where the `async` clause is not used, calling `hasAsyncOnly` ends up crashing since this attribute is not set. Thus, to fix this issue, ensure that we check for this attribute before trying to walk the attribute array.