Skip to content

Commit cb16424

Browse files
committed
add unit test
1 parent adf9870 commit cb16424

File tree

3 files changed

+100
-0
lines changed

3 files changed

+100
-0
lines changed

mlir/unittests/Dialect/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ target_link_libraries(MLIRDialectTests
99
add_subdirectory(Index)
1010
add_subdirectory(LLVMIR)
1111
add_subdirectory(MemRef)
12+
add_subdirectory(SCF)
1213
add_subdirectory(SparseTensor)
1314
add_subdirectory(SPIRV)
1415
add_subdirectory(Transform)
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
add_mlir_unittest(MLIRSCFTests
2+
LoopLikeSCFOpsTest.cpp
3+
)
4+
target_link_libraries(MLIRSCFTests
5+
PRIVATE
6+
MLIRIR
7+
MLIRSCFDialect
8+
)
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
//===- LoopLikeSCFOpsTest.cpp - SCF LoopLikeOpInterface Tests -------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
//
10+
//===----------------------------------------------------------------------===//
11+
12+
#include "mlir/Dialect/Arith/IR/Arith.h"
13+
#include "mlir/Dialect/SCF/IR/SCF.h"
14+
#include "mlir/IR/Diagnostics.h"
15+
#include "mlir/IR/MLIRContext.h"
16+
#include "gtest/gtest.h"
17+
18+
using namespace mlir;
19+
using namespace mlir::scf;
20+
21+
//===----------------------------------------------------------------------===//
22+
// Test Fixture
23+
//===----------------------------------------------------------------------===//
24+
25+
class SCFLoopLikeTest : public ::testing::Test {
26+
protected:
27+
SCFLoopLikeTest() : b(&context), loc(UnknownLoc::get(&context)) {
28+
context.loadDialect<arith::ArithDialect, scf::SCFDialect>();
29+
}
30+
31+
void checkUnidimensional(LoopLikeOpInterface loopLikeOp) {
32+
std::optional<OpFoldResult> maybeLb = loopLikeOp.getSingleLowerBound();
33+
EXPECT_TRUE(maybeLb.has_value());
34+
std::optional<OpFoldResult> maybeUb = loopLikeOp.getSingleUpperBound();
35+
EXPECT_TRUE(maybeUb.has_value());
36+
std::optional<OpFoldResult> maybeStep = loopLikeOp.getSingleStep();
37+
EXPECT_TRUE(maybeStep.has_value());
38+
std::optional<OpFoldResult> maybeIndVar =
39+
loopLikeOp.getSingleInductionVar();
40+
EXPECT_TRUE(maybeIndVar.has_value());
41+
}
42+
43+
void checkMultidimensional(LoopLikeOpInterface loopLikeOp) {
44+
std::optional<OpFoldResult> maybeLb = loopLikeOp.getSingleLowerBound();
45+
EXPECT_FALSE(maybeLb.has_value());
46+
std::optional<OpFoldResult> maybeUb = loopLikeOp.getSingleUpperBound();
47+
EXPECT_FALSE(maybeUb.has_value());
48+
std::optional<OpFoldResult> maybeStep = loopLikeOp.getSingleStep();
49+
EXPECT_FALSE(maybeStep.has_value());
50+
std::optional<OpFoldResult> maybeIndVar =
51+
loopLikeOp.getSingleInductionVar();
52+
EXPECT_FALSE(maybeIndVar.has_value());
53+
}
54+
55+
MLIRContext context;
56+
OpBuilder b;
57+
Location loc;
58+
};
59+
60+
TEST_F(SCFLoopLikeTest, queryUnidimensionalLooplikes) {
61+
Value lb = b.create<arith::ConstantIndexOp>(loc, 0);
62+
Value ub = b.create<arith::ConstantIndexOp>(loc, 10);
63+
Value step = b.create<arith::ConstantIndexOp>(loc, 2);
64+
65+
auto forOp = b.create<scf::ForOp>(loc, lb, ub, step);
66+
checkUnidimensional(forOp);
67+
68+
auto forallOp = b.create<scf::ForallOp>(
69+
loc, ArrayRef<OpFoldResult>(lb), ArrayRef<OpFoldResult>(ub),
70+
ArrayRef<OpFoldResult>(step), ValueRange(), std::nullopt);
71+
checkUnidimensional(forallOp);
72+
73+
auto parallelOp = b.create<scf::ParallelOp>(
74+
loc, ValueRange(lb), ValueRange(ub), ValueRange(step), ValueRange());
75+
checkUnidimensional(parallelOp);
76+
}
77+
78+
TEST_F(SCFLoopLikeTest, queryMultidimensionalLooplikes) {
79+
Value lb = b.create<arith::ConstantIndexOp>(loc, 0);
80+
Value ub = b.create<arith::ConstantIndexOp>(loc, 10);
81+
Value step = b.create<arith::ConstantIndexOp>(loc, 2);
82+
83+
auto forallOp = b.create<scf::ForallOp>(
84+
loc, ArrayRef<OpFoldResult>({lb, lb}), ArrayRef<OpFoldResult>({ub, ub}),
85+
ArrayRef<OpFoldResult>({step, step}), ValueRange(), std::nullopt);
86+
checkMultidimensional(forallOp);
87+
88+
auto parallelOp = b.create<scf::ParallelOp>(
89+
loc, ValueRange({lb, lb}), ValueRange({ub, ub}), ValueRange({step, step}), ValueRange());
90+
checkMultidimensional(parallelOp);
91+
}

0 commit comments

Comments
 (0)