Skip to content

Commit a7c439d

Browse files
committed
[mlir][test] Add basic unit test for ScalableVectorType
1 parent bd8720d commit a7c439d

File tree

2 files changed

+78
-1
lines changed

2 files changed

+78
-1
lines changed

mlir/unittests/Support/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
add_mlir_unittest(MLIRSupportTests
22
IndentedOstreamTest.cpp
33
StorageUniquerTest.cpp
4+
ScalableVectorTypeTest.cpp
45
)
56

67
target_link_libraries(MLIRSupportTests
7-
PRIVATE MLIRSupport)
8+
PRIVATE MLIRSupport MLIRIR)
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
//===- ScalableVectorTypeTest.cpp - ScalableVectorType Tests --------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Support/ScalableVectorType.h"
10+
#include "mlir/IR/Dialect.h"
11+
#include "gtest/gtest.h"
12+
13+
using namespace mlir;
14+
15+
TEST(ScalableVectorTypeTest, TestVectorDim) {
16+
auto fixedDim = VectorDim::getFixed(4);
17+
ASSERT_FALSE(fixedDim.isScalable());
18+
ASSERT_TRUE(fixedDim.isFixed());
19+
ASSERT_EQ(fixedDim.getFixedSize(), 4);
20+
21+
auto scalableDim = VectorDim::getScalable(8);
22+
ASSERT_TRUE(scalableDim.isScalable());
23+
ASSERT_FALSE(scalableDim.isFixed());
24+
ASSERT_EQ(scalableDim.getMinSize(), 8);
25+
}
26+
27+
TEST(ScalableVectorTypeTest, BasicFunctionality) {
28+
MLIRContext context;
29+
30+
Type f32 = FloatType::getF32(&context);
31+
32+
// Construct n-D scalable vector.
33+
VectorType scalableVector = ScalableVectorType::get(
34+
{VectorDim::getFixed(1), VectorDim::getFixed(2),
35+
VectorDim::getScalable(3), VectorDim::getFixed(4),
36+
VectorDim::getScalable(5)},
37+
f32);
38+
// Construct fixed vector.
39+
VectorType fixedVector = ScalableVectorType::get(VectorDim::getFixed(1), f32);
40+
41+
// Check casts.
42+
ASSERT_TRUE(isa<ScalableVectorType>(scalableVector));
43+
ASSERT_FALSE(isa<ScalableVectorType>(fixedVector));
44+
ASSERT_FALSE(VectorDimList::from(fixedVector).hasScalableDims());
45+
46+
// Check rank/size.
47+
auto vType = cast<ScalableVectorType>(scalableVector);
48+
ASSERT_EQ(vType.getDims().size(), unsigned(scalableVector.getRank()));
49+
ASSERT_TRUE(vType.getDims().hasScalableDims());
50+
51+
// Check iterating over dimensions.
52+
std::array expectedDims{VectorDim::getFixed(1), VectorDim::getFixed(2),
53+
VectorDim::getScalable(3), VectorDim::getFixed(4),
54+
VectorDim::getScalable(5)};
55+
unsigned i = 0;
56+
for (VectorDim dim : vType.getDims()) {
57+
ASSERT_EQ(dim, expectedDims[i]);
58+
i++;
59+
}
60+
}
61+
62+
TEST(ScalableVectorTypeTest, VectorDimListHelpers) {
63+
std::array<int64_t, 4> sizes{42, 10, 3, 1};
64+
std::array<bool, 4> scalableFlags{false, true, false, true};
65+
66+
// Manually construct from sizes + flags.
67+
VectorDimList dimList(sizes, scalableFlags);
68+
69+
ASSERT_EQ(dimList.size(), 4U);
70+
71+
ASSERT_EQ(dimList.front(), VectorDim::getFixed(42));
72+
ASSERT_EQ(dimList.back(), VectorDim::getScalable(1));
73+
74+
std::array innerDims{VectorDim::getScalable(10), VectorDim::getFixed(3)};
75+
ASSERT_EQ(dimList.slice(1, 2), innerDims);
76+
}

0 commit comments

Comments
 (0)