7
7
// ===----------------------------------------------------------------------===//
8
8
9
9
#include " mlir/Dialect/Utils/ReshapeOpsUtils.h"
10
+ #include " mlir/IR/BuiltinTypeInterfaces.h"
10
11
#include " llvm/ADT/STLExtras.h"
11
12
#include " gtest/gtest.h"
12
13
#include < optional>
@@ -20,6 +21,29 @@ makeOptionalIndices(std::initializer_list<ReassociationIndices> list) {
20
21
return std::optional<SmallVector<ReassociationIndices>>(list);
21
22
}
22
23
24
+ TEST (ReassociationIndicesForCollapse, ScalarTest) {
25
+ EXPECT_EQ (getReassociationIndicesForCollapse ({1 }, {}),
26
+ makeOptionalIndices ({{0 }}));
27
+ EXPECT_EQ (getReassociationIndicesForCollapse ({1 , 1 }, {}),
28
+ makeOptionalIndices ({{0 , 1 }}));
29
+ EXPECT_EQ (getReassociationIndicesForCollapse ({ShapedType::kDynamic }, {}),
30
+ makeOptionalIndices ({{0 }}));
31
+ EXPECT_EQ (getReassociationIndicesForCollapse ({1 , ShapedType::kDynamic ,
32
+ ShapedType::kDynamic , 1 ,
33
+ ShapedType::kDynamic },
34
+ {}),
35
+ makeOptionalIndices ({{0 , 1 , 2 , 3 , 4 }}));
36
+ }
37
+
38
+ TEST (ReassociationIndicesForCollapse, ScalarTestFailure) {
39
+ EXPECT_EQ (getReassociationIndicesForCollapse ({}, {}), std::nullopt);
40
+ EXPECT_EQ (getReassociationIndicesForCollapse ({}, {1 }), std::nullopt);
41
+ EXPECT_EQ (getReassociationIndicesForCollapse ({2 }, {}), std::nullopt);
42
+ EXPECT_EQ (
43
+ getReassociationIndicesForCollapse ({1 , 2 , ShapedType::kDynamic , 1 }, {}),
44
+ std::nullopt);
45
+ }
46
+
23
47
TEST (ReassociationIndicesForCollapse, StaticTest) {
24
48
EXPECT_EQ (getReassociationIndicesForCollapse ({10 , 20 }, {200 }),
25
49
makeOptionalIndices ({{0 , 1 }}));
0 commit comments