Skip to content

Commit f4aec22

Browse files
[mlir][acc] Fix async only api on data entry operations (#122818)
Data entry operations that are created from constructs with an async clause that has no value (e.g. `acc data copyin(var) async`) end up holding an attribute array named `asyncOnly` to keep track of this information. However, in cases where the `async` clause is not used, calling `hasAsyncOnly` ends up crashing since this attribute is not set. Thus, to fix this issue, ensure that we check for this attribute before trying to walk the attribute array.
1 parent e03c435 commit f4aec22

File tree

2 files changed

+96
-2
lines changed

2 files changed

+96
-2
lines changed

mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -445,7 +445,10 @@ class OpenACC_DataEntryOp<string mnemonic, string clause, string extraDescriptio
445445
}
446446
/// Return true if the op has the async attribute for the given device_type.
447447
bool hasAsyncOnly(mlir::acc::DeviceType deviceType) {
448-
for (auto attr : getAsyncOnlyAttr()) {
448+
mlir::ArrayAttr asyncOnly = getAsyncOnlyAttr();
449+
if (!asyncOnly)
450+
return false;
451+
for (auto attr : asyncOnly) {
449452
auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
450453
if (deviceTypeAttr.getValue() == deviceType)
451454
return true;
@@ -817,7 +820,10 @@ class OpenACC_DataExitOp<string mnemonic, string clause, string extraDescription
817820
}
818821
/// Return true if the op has the async attribute for the given device_type.
819822
bool hasAsyncOnly(mlir::acc::DeviceType deviceType) {
820-
for (auto attr : getAsyncOnlyAttr()) {
823+
mlir::ArrayAttr asyncOnly = getAsyncOnlyAttr();
824+
if (!asyncOnly)
825+
return false;
826+
for (auto attr : asyncOnly) {
821827
auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
822828
if (deviceTypeAttr.getValue() == deviceType)
823829
return true;

mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,54 @@ TEST_F(OpenACCOpsTest, asyncOnlyTest) {
7777
testAsyncOnly<SerialOp>(b, context, loc, dtypes);
7878
}
7979

80+
/// Exercises hasAsyncOnly() on an op of type `Op` built from a single
/// alloca'd variable: first with no asyncOnly attribute set, then with the
/// attribute holding None, Host, and Host + Star device types in turn.
template <typename Op>
81+
void testAsyncOnlyDataEntry(OpBuilder &b, MLIRContext &context, Location loc,
82+
llvm::SmallVector<DeviceType> &dtypes) {
83+
// Empty-shape memref of i32 that serves as the variable handed to the op.
auto memrefTy = MemRefType::get({}, b.getI32Type());
84+
OwningOpRef<memref::AllocaOp> varPtrOp =
85+
b.create<memref::AllocaOp>(loc, memrefTy);
86+
87+
TypedValue<PointerLikeType> varPtr =
88+
cast<TypedValue<PointerLikeType>>(varPtrOp->getResult());
89+
OwningOpRef<Op> op = b.create<Op>(loc, varPtr,
90+
/*structured=*/true, /*implicit=*/true);
91+
92+
// No asyncOnly attribute has been set yet: every query is expected to
// return false, both without a device type and for each entry of `dtypes`.
EXPECT_FALSE(op->hasAsyncOnly());
93+
for (auto d : dtypes)
94+
EXPECT_FALSE(op->hasAsyncOnly(d));
95+
96+
// A single None entry is expected to satisfy both the argument-less query
// and the explicit DeviceType::None query.
auto dtypeNone = DeviceTypeAttr::get(&context, DeviceType::None);
97+
op->setAsyncOnlyAttr(b.getArrayAttr({dtypeNone}));
98+
EXPECT_TRUE(op->hasAsyncOnly());
99+
EXPECT_TRUE(op->hasAsyncOnly(DeviceType::None));
100+
op->removeAsyncOnlyAttr();
101+
102+
// A Host-only entry is expected to match Host but not the argument-less
// query.
auto dtypeHost = DeviceTypeAttr::get(&context, DeviceType::Host);
103+
op->setAsyncOnlyAttr(b.getArrayAttr({dtypeHost}));
104+
EXPECT_TRUE(op->hasAsyncOnly(DeviceType::Host));
105+
EXPECT_FALSE(op->hasAsyncOnly());
106+
op->removeAsyncOnlyAttr();
107+
108+
// With both Host and Star present, each is expected to match individually
// while the argument-less query is still expected to be false.
auto dtypeStar = DeviceTypeAttr::get(&context, DeviceType::Star);
109+
op->setAsyncOnlyAttr(b.getArrayAttr({dtypeHost, dtypeStar}));
110+
EXPECT_TRUE(op->hasAsyncOnly(DeviceType::Star));
111+
EXPECT_TRUE(op->hasAsyncOnly(DeviceType::Host));
112+
EXPECT_FALSE(op->hasAsyncOnly());
113+
114+
op->removeAsyncOnlyAttr();
115+
}
116+
117+
// Runs the hasAsyncOnly() checks of testAsyncOnlyDataEntry once for each of
// the op types listed below, using the fixture's builder, context, location
// and device-type list.
TEST_F(OpenACCOpsTest, asyncOnlyTestDataEntry) {
118+
testAsyncOnlyDataEntry<DevicePtrOp>(b, context, loc, dtypes);
119+
testAsyncOnlyDataEntry<PresentOp>(b, context, loc, dtypes);
120+
testAsyncOnlyDataEntry<CopyinOp>(b, context, loc, dtypes);
121+
testAsyncOnlyDataEntry<CreateOp>(b, context, loc, dtypes);
122+
testAsyncOnlyDataEntry<NoCreateOp>(b, context, loc, dtypes);
123+
testAsyncOnlyDataEntry<AttachOp>(b, context, loc, dtypes);
124+
testAsyncOnlyDataEntry<UpdateDeviceOp>(b, context, loc, dtypes);
125+
testAsyncOnlyDataEntry<UseDeviceOp>(b, context, loc, dtypes);
126+
}
127+
80128
template <typename Op>
81129
void testAsyncValue(OpBuilder &b, MLIRContext &context, Location loc,
82130
llvm::SmallVector<DeviceType> &dtypes) {
@@ -105,6 +153,46 @@ TEST_F(OpenACCOpsTest, asyncValueTest) {
105153
testAsyncValue<SerialOp>(b, context, loc, dtypes);
106154
}
107155

156+
/// Exercises getAsyncValue() on an op of type `Op` built from a single
/// alloca'd variable: expects an empty Value while no async operand is set,
/// then attaches one operand tagged with DeviceType::Nvidia and expects it to
/// be returned only for that device type.
template <typename Op>
157+
void testAsyncValueDataEntry(OpBuilder &b, MLIRContext &context, Location loc,
158+
llvm::SmallVector<DeviceType> &dtypes) {
159+
// Empty-shape memref of i32 that serves as the variable handed to the op.
auto memrefTy = MemRefType::get({}, b.getI32Type());
160+
OwningOpRef<memref::AllocaOp> varPtrOp =
161+
b.create<memref::AllocaOp>(loc, memrefTy);
162+
163+
TypedValue<PointerLikeType> varPtr =
164+
cast<TypedValue<PointerLikeType>>(varPtrOp->getResult());
165+
OwningOpRef<Op> op = b.create<Op>(loc, varPtr,
166+
/*structured=*/true, /*implicit=*/true);
167+
168+
// Default-constructed Value used as the expected result whenever no async
// operand applies.
mlir::Value empty;
169+
EXPECT_EQ(op->getAsyncValue(), empty);
170+
for (auto d : dtypes)
171+
EXPECT_EQ(op->getAsyncValue(d), empty);
172+
173+
// Attach a constant index operand associated with the Nvidia device type;
// the argument-less query is still expected to yield the empty Value.
OwningOpRef<arith::ConstantIndexOp> val =
174+
b.create<arith::ConstantIndexOp>(loc, 1);
175+
auto dtypeNvidia = DeviceTypeAttr::get(&context, DeviceType::Nvidia);
176+
op->setAsyncOperandsDeviceTypeAttr(b.getArrayAttr({dtypeNvidia}));
177+
op->getAsyncOperandsMutable().assign(val->getResult());
178+
EXPECT_EQ(op->getAsyncValue(), empty);
179+
EXPECT_EQ(op->getAsyncValue(DeviceType::Nvidia), val->getResult());
180+
181+
// Clear the async operand and remove its device-type attribute again.
op->getAsyncOperandsMutable().clear();
182+
op->removeAsyncOperandsDeviceTypeAttr();
183+
}
184+
185+
// Runs the getAsyncValue() checks of testAsyncValueDataEntry once for each of
// the op types listed below, using the fixture's builder, context, location
// and device-type list.
TEST_F(OpenACCOpsTest, asyncValueTestDataEntry) {
186+
testAsyncValueDataEntry<DevicePtrOp>(b, context, loc, dtypes);
187+
testAsyncValueDataEntry<PresentOp>(b, context, loc, dtypes);
188+
testAsyncValueDataEntry<CopyinOp>(b, context, loc, dtypes);
189+
testAsyncValueDataEntry<CreateOp>(b, context, loc, dtypes);
190+
testAsyncValueDataEntry<NoCreateOp>(b, context, loc, dtypes);
191+
testAsyncValueDataEntry<AttachOp>(b, context, loc, dtypes);
192+
testAsyncValueDataEntry<UpdateDeviceOp>(b, context, loc, dtypes);
193+
testAsyncValueDataEntry<UseDeviceOp>(b, context, loc, dtypes);
194+
}
195+
108196
template <typename Op>
109197
void testNumGangsValues(OpBuilder &b, MLIRContext &context, Location loc,
110198
llvm::SmallVector<DeviceType> &dtypes,

0 commit comments

Comments
 (0)