Skip to content

Commit 02fa434

Browse files
authored
[mlir][openacc] Restore unit tests for device_type functions (#77122)
These tests were initially pushed together with #75864 but they were triggering some buildbot failure (sanitizers). They now make use of the `OwningOpRef` so all the resources are correctly destroyed at the end of each tests. They will be extended to includes all the extra getter functions added with device_type support.
1 parent 9052512 commit 02fa434

File tree

3 files changed

+358
-0
lines changed

3 files changed

+358
-0
lines changed

mlir/unittests/Dialect/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ add_subdirectory(ArmSME)
1010
add_subdirectory(Index)
1111
add_subdirectory(LLVMIR)
1212
add_subdirectory(MemRef)
13+
add_subdirectory(OpenACC)
1314
add_subdirectory(SCF)
1415
add_subdirectory(SparseTensor)
1516
add_subdirectory(SPIRV)
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
add_mlir_unittest(MLIROpenACCTests
2+
OpenACCOpsTest.cpp
3+
)
4+
target_link_libraries(MLIROpenACCTests
5+
PRIVATE
6+
MLIRIR
7+
MLIROpenACCDialect
8+
)
Lines changed: 349 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,349 @@
1+
//===- OpenACCOpsTest.cpp - OpenACC ops extra functiosn Tests -------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/Arith/IR/Arith.h"
10+
#include "mlir/Dialect/OpenACC/OpenACC.h"
11+
#include "mlir/IR/Diagnostics.h"
12+
#include "mlir/IR/MLIRContext.h"
13+
#include "mlir/IR/OwningOpRef.h"
14+
#include "gtest/gtest.h"
15+
16+
using namespace mlir;
17+
using namespace mlir::acc;
18+
19+
//===----------------------------------------------------------------------===//
20+
// Test Fixture
21+
//===----------------------------------------------------------------------===//
22+
23+
class OpenACCOpsTest : public ::testing::Test {
24+
protected:
25+
OpenACCOpsTest() : b(&context), loc(UnknownLoc::get(&context)) {
26+
context.loadDialect<acc::OpenACCDialect, arith::ArithDialect>();
27+
}
28+
29+
MLIRContext context;
30+
OpBuilder b;
31+
Location loc;
32+
llvm::SmallVector<DeviceType> dtypes = {
33+
DeviceType::None, DeviceType::Star, DeviceType::Multicore,
34+
DeviceType::Default, DeviceType::Host, DeviceType::Nvidia,
35+
DeviceType::Radeon};
36+
llvm::SmallVector<DeviceType> dtypesWithoutNone = {
37+
DeviceType::Star, DeviceType::Multicore, DeviceType::Default,
38+
DeviceType::Host, DeviceType::Nvidia, DeviceType::Radeon};
39+
};
40+
41+
template <typename Op>
42+
void testAsyncOnly(OpBuilder &b, MLIRContext &context, Location loc,
43+
llvm::SmallVector<DeviceType> &dtypes) {
44+
OwningOpRef<Op> op = b.create<Op>(loc, TypeRange{}, ValueRange{});
45+
EXPECT_FALSE(op->hasAsyncOnly());
46+
for (auto d : dtypes)
47+
EXPECT_FALSE(op->hasAsyncOnly(d));
48+
49+
auto dtypeNone = DeviceTypeAttr::get(&context, DeviceType::None);
50+
op->setAsyncOnlyAttr(b.getArrayAttr({dtypeNone}));
51+
EXPECT_TRUE(op->hasAsyncOnly());
52+
EXPECT_TRUE(op->hasAsyncOnly(DeviceType::None));
53+
op->removeAsyncOnlyAttr();
54+
55+
auto dtypeHost = DeviceTypeAttr::get(&context, DeviceType::Host);
56+
op->setAsyncOnlyAttr(b.getArrayAttr({dtypeHost}));
57+
EXPECT_TRUE(op->hasAsyncOnly(DeviceType::Host));
58+
EXPECT_FALSE(op->hasAsyncOnly());
59+
op->removeAsyncOnlyAttr();
60+
61+
auto dtypeStar = DeviceTypeAttr::get(&context, DeviceType::Star);
62+
op->setAsyncOnlyAttr(b.getArrayAttr({dtypeHost, dtypeStar}));
63+
EXPECT_TRUE(op->hasAsyncOnly(DeviceType::Star));
64+
EXPECT_TRUE(op->hasAsyncOnly(DeviceType::Host));
65+
EXPECT_FALSE(op->hasAsyncOnly());
66+
67+
op->removeAsyncOnlyAttr();
68+
}
69+
70+
TEST_F(OpenACCOpsTest, asyncOnlyTest) {
71+
testAsyncOnly<ParallelOp>(b, context, loc, dtypes);
72+
testAsyncOnly<KernelsOp>(b, context, loc, dtypes);
73+
testAsyncOnly<SerialOp>(b, context, loc, dtypes);
74+
}
75+
76+
template <typename Op>
77+
void testAsyncValue(OpBuilder &b, MLIRContext &context, Location loc,
78+
llvm::SmallVector<DeviceType> &dtypes) {
79+
OwningOpRef<Op> op = b.create<Op>(loc, TypeRange{}, ValueRange{});
80+
81+
mlir::Value empty;
82+
EXPECT_EQ(op->getAsyncValue(), empty);
83+
for (auto d : dtypes)
84+
EXPECT_EQ(op->getAsyncValue(d), empty);
85+
86+
OwningOpRef<arith::ConstantIndexOp> val =
87+
b.create<arith::ConstantIndexOp>(loc, 1);
88+
auto dtypeNvidia = DeviceTypeAttr::get(&context, DeviceType::Nvidia);
89+
op->setAsyncDeviceTypeAttr(b.getArrayAttr({dtypeNvidia}));
90+
op->getAsyncMutable().assign(val->getResult());
91+
EXPECT_EQ(op->getAsyncValue(), empty);
92+
EXPECT_EQ(op->getAsyncValue(DeviceType::Nvidia), val->getResult());
93+
94+
op->getAsyncMutable().clear();
95+
op->removeAsyncDeviceTypeAttr();
96+
}
97+
98+
TEST_F(OpenACCOpsTest, asyncValueTest) {
99+
testAsyncValue<ParallelOp>(b, context, loc, dtypes);
100+
testAsyncValue<KernelsOp>(b, context, loc, dtypes);
101+
testAsyncValue<SerialOp>(b, context, loc, dtypes);
102+
}
103+
104+
template <typename Op>
105+
void testNumGangsValues(OpBuilder &b, MLIRContext &context, Location loc,
106+
llvm::SmallVector<DeviceType> &dtypes,
107+
llvm::SmallVector<DeviceType> &dtypesWithoutNone) {
108+
OwningOpRef<Op> op = b.create<Op>(loc, TypeRange{}, ValueRange{});
109+
EXPECT_EQ(op->getNumGangsValues().begin(), op->getNumGangsValues().end());
110+
111+
OwningOpRef<arith::ConstantIndexOp> val1 =
112+
b.create<arith::ConstantIndexOp>(loc, 1);
113+
OwningOpRef<arith::ConstantIndexOp> val2 =
114+
b.create<arith::ConstantIndexOp>(loc, 4);
115+
auto dtypeNone = DeviceTypeAttr::get(&context, DeviceType::None);
116+
op->getNumGangsMutable().assign(val1->getResult());
117+
op->setNumGangsDeviceTypeAttr(b.getArrayAttr({dtypeNone}));
118+
op->setNumGangsSegments(b.getDenseI32ArrayAttr({1}));
119+
EXPECT_EQ(op->getNumGangsValues().front(), val1->getResult());
120+
for (auto d : dtypesWithoutNone)
121+
EXPECT_EQ(op->getNumGangsValues(d).begin(), op->getNumGangsValues(d).end());
122+
123+
op->getNumGangsMutable().clear();
124+
op->removeNumGangsDeviceTypeAttr();
125+
op->removeNumGangsSegmentsAttr();
126+
for (auto d : dtypes)
127+
EXPECT_EQ(op->getNumGangsValues(d).begin(), op->getNumGangsValues(d).end());
128+
129+
op->getNumGangsMutable().append(val1->getResult());
130+
op->getNumGangsMutable().append(val2->getResult());
131+
op->setNumGangsDeviceTypeAttr(
132+
b.getArrayAttr({DeviceTypeAttr::get(&context, DeviceType::Host),
133+
DeviceTypeAttr::get(&context, DeviceType::Star)}));
134+
op->setNumGangsSegments(b.getDenseI32ArrayAttr({1, 1}));
135+
EXPECT_EQ(op->getNumGangsValues(DeviceType::None).begin(),
136+
op->getNumGangsValues(DeviceType::None).end());
137+
EXPECT_EQ(op->getNumGangsValues(DeviceType::Host).front(), val1->getResult());
138+
EXPECT_EQ(op->getNumGangsValues(DeviceType::Star).front(), val2->getResult());
139+
140+
op->getNumGangsMutable().clear();
141+
op->removeNumGangsDeviceTypeAttr();
142+
op->removeNumGangsSegmentsAttr();
143+
for (auto d : dtypes)
144+
EXPECT_EQ(op->getNumGangsValues(d).begin(), op->getNumGangsValues(d).end());
145+
146+
op->getNumGangsMutable().append(val1->getResult());
147+
op->getNumGangsMutable().append(val2->getResult());
148+
op->getNumGangsMutable().append(val1->getResult());
149+
op->setNumGangsDeviceTypeAttr(
150+
b.getArrayAttr({DeviceTypeAttr::get(&context, DeviceType::Default),
151+
DeviceTypeAttr::get(&context, DeviceType::Multicore)}));
152+
op->setNumGangsSegments(b.getDenseI32ArrayAttr({2, 1}));
153+
EXPECT_EQ(op->getNumGangsValues(DeviceType::None).begin(),
154+
op->getNumGangsValues(DeviceType::None).end());
155+
EXPECT_EQ(op->getNumGangsValues(DeviceType::Default).front(),
156+
val1->getResult());
157+
EXPECT_EQ(op->getNumGangsValues(DeviceType::Default).drop_front().front(),
158+
val2->getResult());
159+
EXPECT_EQ(op->getNumGangsValues(DeviceType::Multicore).front(),
160+
val1->getResult());
161+
162+
op->getNumGangsMutable().clear();
163+
op->removeNumGangsDeviceTypeAttr();
164+
op->removeNumGangsSegmentsAttr();
165+
}
166+
167+
TEST_F(OpenACCOpsTest, numGangsValuesTest) {
168+
testNumGangsValues<ParallelOp>(b, context, loc, dtypes, dtypesWithoutNone);
169+
testNumGangsValues<KernelsOp>(b, context, loc, dtypes, dtypesWithoutNone);
170+
}
171+
172+
template <typename Op>
173+
void testVectorLength(OpBuilder &b, MLIRContext &context, Location loc,
174+
llvm::SmallVector<DeviceType> &dtypes) {
175+
OwningOpRef<Op> op = b.create<Op>(loc, TypeRange{}, ValueRange{});
176+
177+
mlir::Value empty;
178+
EXPECT_EQ(op->getVectorLengthValue(), empty);
179+
for (auto d : dtypes)
180+
EXPECT_EQ(op->getVectorLengthValue(d), empty);
181+
182+
OwningOpRef<arith::ConstantIndexOp> val =
183+
b.create<arith::ConstantIndexOp>(loc, 1);
184+
auto dtypeNvidia = DeviceTypeAttr::get(&context, DeviceType::Nvidia);
185+
op->setVectorLengthDeviceTypeAttr(b.getArrayAttr({dtypeNvidia}));
186+
op->getVectorLengthMutable().assign(val->getResult());
187+
EXPECT_EQ(op->getVectorLengthValue(), empty);
188+
EXPECT_EQ(op->getVectorLengthValue(DeviceType::Nvidia), val->getResult());
189+
190+
op->getVectorLengthMutable().clear();
191+
op->removeVectorLengthDeviceTypeAttr();
192+
}
193+
194+
TEST_F(OpenACCOpsTest, vectorLengthTest) {
195+
testVectorLength<ParallelOp>(b, context, loc, dtypes);
196+
testVectorLength<KernelsOp>(b, context, loc, dtypes);
197+
}
198+
199+
template <typename Op>
200+
void testWaitOnly(OpBuilder &b, MLIRContext &context, Location loc,
201+
llvm::SmallVector<DeviceType> &dtypes,
202+
llvm::SmallVector<DeviceType> &dtypesWithoutNone) {
203+
OwningOpRef<Op> op = b.create<Op>(loc, TypeRange{}, ValueRange{});
204+
EXPECT_FALSE(op->hasWaitOnly());
205+
for (auto d : dtypes)
206+
EXPECT_FALSE(op->hasWaitOnly(d));
207+
208+
auto dtypeNone = DeviceTypeAttr::get(&context, DeviceType::None);
209+
op->setWaitOnlyAttr(b.getArrayAttr({dtypeNone}));
210+
EXPECT_TRUE(op->hasWaitOnly());
211+
EXPECT_TRUE(op->hasWaitOnly(DeviceType::None));
212+
for (auto d : dtypesWithoutNone)
213+
EXPECT_FALSE(op->hasWaitOnly(d));
214+
op->removeWaitOnlyAttr();
215+
216+
auto dtypeHost = DeviceTypeAttr::get(&context, DeviceType::Host);
217+
op->setWaitOnlyAttr(b.getArrayAttr({dtypeHost}));
218+
EXPECT_TRUE(op->hasWaitOnly(DeviceType::Host));
219+
EXPECT_FALSE(op->hasWaitOnly());
220+
op->removeWaitOnlyAttr();
221+
222+
auto dtypeStar = DeviceTypeAttr::get(&context, DeviceType::Star);
223+
op->setWaitOnlyAttr(b.getArrayAttr({dtypeHost, dtypeStar}));
224+
EXPECT_TRUE(op->hasWaitOnly(DeviceType::Star));
225+
EXPECT_TRUE(op->hasWaitOnly(DeviceType::Host));
226+
EXPECT_FALSE(op->hasWaitOnly());
227+
228+
op->removeWaitOnlyAttr();
229+
}
230+
231+
TEST_F(OpenACCOpsTest, waitOnlyTest) {
232+
testWaitOnly<ParallelOp>(b, context, loc, dtypes, dtypesWithoutNone);
233+
testWaitOnly<KernelsOp>(b, context, loc, dtypes, dtypesWithoutNone);
234+
testWaitOnly<SerialOp>(b, context, loc, dtypes, dtypesWithoutNone);
235+
}
236+
237+
template <typename Op>
238+
void testWaitValues(OpBuilder &b, MLIRContext &context, Location loc,
239+
llvm::SmallVector<DeviceType> &dtypes,
240+
llvm::SmallVector<DeviceType> &dtypesWithoutNone) {
241+
OwningOpRef<Op> op = b.create<Op>(loc, TypeRange{}, ValueRange{});
242+
EXPECT_EQ(op->getWaitValues().begin(), op->getWaitValues().end());
243+
244+
OwningOpRef<arith::ConstantIndexOp> val1 =
245+
b.create<arith::ConstantIndexOp>(loc, 1);
246+
OwningOpRef<arith::ConstantIndexOp> val2 =
247+
b.create<arith::ConstantIndexOp>(loc, 4);
248+
auto dtypeNone = DeviceTypeAttr::get(&context, DeviceType::None);
249+
op->getWaitOperandsMutable().assign(val1->getResult());
250+
op->setWaitOperandsDeviceTypeAttr(b.getArrayAttr({dtypeNone}));
251+
op->setWaitOperandsSegments(b.getDenseI32ArrayAttr({1}));
252+
EXPECT_EQ(op->getWaitValues().front(), val1->getResult());
253+
for (auto d : dtypesWithoutNone)
254+
EXPECT_EQ(op->getWaitValues(d).begin(), op->getWaitValues(d).end());
255+
256+
op->getWaitOperandsMutable().clear();
257+
op->removeWaitOperandsDeviceTypeAttr();
258+
op->removeWaitOperandsSegmentsAttr();
259+
for (auto d : dtypes)
260+
EXPECT_EQ(op->getWaitValues(d).begin(), op->getWaitValues(d).end());
261+
262+
op->getWaitOperandsMutable().append(val1->getResult());
263+
op->getWaitOperandsMutable().append(val2->getResult());
264+
op->setWaitOperandsDeviceTypeAttr(
265+
b.getArrayAttr({DeviceTypeAttr::get(&context, DeviceType::Host),
266+
DeviceTypeAttr::get(&context, DeviceType::Star)}));
267+
op->setWaitOperandsSegments(b.getDenseI32ArrayAttr({1, 1}));
268+
EXPECT_EQ(op->getWaitValues(DeviceType::None).begin(),
269+
op->getWaitValues(DeviceType::None).end());
270+
EXPECT_EQ(op->getWaitValues(DeviceType::Host).front(), val1->getResult());
271+
EXPECT_EQ(op->getWaitValues(DeviceType::Star).front(), val2->getResult());
272+
273+
op->getWaitOperandsMutable().clear();
274+
op->removeWaitOperandsDeviceTypeAttr();
275+
op->removeWaitOperandsSegmentsAttr();
276+
for (auto d : dtypes)
277+
EXPECT_EQ(op->getWaitValues(d).begin(), op->getWaitValues(d).end());
278+
279+
op->getWaitOperandsMutable().append(val1->getResult());
280+
op->getWaitOperandsMutable().append(val2->getResult());
281+
op->getWaitOperandsMutable().append(val1->getResult());
282+
op->setWaitOperandsDeviceTypeAttr(
283+
b.getArrayAttr({DeviceTypeAttr::get(&context, DeviceType::Default),
284+
DeviceTypeAttr::get(&context, DeviceType::Multicore)}));
285+
op->setWaitOperandsSegments(b.getDenseI32ArrayAttr({2, 1}));
286+
EXPECT_EQ(op->getWaitValues(DeviceType::None).begin(),
287+
op->getWaitValues(DeviceType::None).end());
288+
EXPECT_EQ(op->getWaitValues(DeviceType::Default).front(), val1->getResult());
289+
EXPECT_EQ(op->getWaitValues(DeviceType::Default).drop_front().front(),
290+
val2->getResult());
291+
EXPECT_EQ(op->getWaitValues(DeviceType::Multicore).front(),
292+
val1->getResult());
293+
294+
op->getWaitOperandsMutable().clear();
295+
op->removeWaitOperandsDeviceTypeAttr();
296+
op->removeWaitOperandsSegmentsAttr();
297+
}
298+
299+
TEST_F(OpenACCOpsTest, waitValuesTest) {
300+
testWaitValues<KernelsOp>(b, context, loc, dtypes, dtypesWithoutNone);
301+
testWaitValues<ParallelOp>(b, context, loc, dtypes, dtypesWithoutNone);
302+
testWaitValues<SerialOp>(b, context, loc, dtypes, dtypesWithoutNone);
303+
}
304+
305+
TEST_F(OpenACCOpsTest, loopOpGangVectorWorkerTest) {
306+
OwningOpRef<LoopOp> op = b.create<LoopOp>(loc, TypeRange{}, ValueRange{});
307+
EXPECT_FALSE(op->hasGang());
308+
EXPECT_FALSE(op->hasVector());
309+
EXPECT_FALSE(op->hasWorker());
310+
for (auto d : dtypes) {
311+
EXPECT_FALSE(op->hasGang(d));
312+
EXPECT_FALSE(op->hasVector(d));
313+
EXPECT_FALSE(op->hasWorker(d));
314+
}
315+
316+
auto dtypeNone = DeviceTypeAttr::get(&context, DeviceType::None);
317+
op->setGangAttr(b.getArrayAttr({dtypeNone}));
318+
EXPECT_TRUE(op->hasGang());
319+
EXPECT_TRUE(op->hasGang(DeviceType::None));
320+
for (auto d : dtypesWithoutNone)
321+
EXPECT_FALSE(op->hasGang(d));
322+
for (auto d : dtypes) {
323+
EXPECT_FALSE(op->hasVector(d));
324+
EXPECT_FALSE(op->hasWorker(d));
325+
}
326+
op->removeGangAttr();
327+
328+
op->setWorkerAttr(b.getArrayAttr({dtypeNone}));
329+
EXPECT_TRUE(op->hasWorker());
330+
EXPECT_TRUE(op->hasWorker(DeviceType::None));
331+
for (auto d : dtypesWithoutNone)
332+
EXPECT_FALSE(op->hasWorker(d));
333+
for (auto d : dtypes) {
334+
EXPECT_FALSE(op->hasGang(d));
335+
EXPECT_FALSE(op->hasVector(d));
336+
}
337+
op->removeWorkerAttr();
338+
339+
op->setVectorAttr(b.getArrayAttr({dtypeNone}));
340+
EXPECT_TRUE(op->hasVector());
341+
EXPECT_TRUE(op->hasVector(DeviceType::None));
342+
for (auto d : dtypesWithoutNone)
343+
EXPECT_FALSE(op->hasVector(d));
344+
for (auto d : dtypes) {
345+
EXPECT_FALSE(op->hasGang(d));
346+
EXPECT_FALSE(op->hasWorker(d));
347+
}
348+
op->removeVectorAttr();
349+
}

0 commit comments

Comments
 (0)