Skip to content

Commit 4dd5cc0

Browse files
committed
[fixup] Add scalar target tests & fix em
Signed-off-by: Artem Gindinson <[email protected]>
1 parent 9a40046 commit 4dd5cc0

File tree

2 files changed

+26
-1
lines changed

2 files changed

+26
-1
lines changed

mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,8 @@ mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
292292
return std::nullopt;
293293
// Early handling for scalar target types.
294294
if (numTargetDims == 0) {
295-
ReassociationIndices allSourceIndices(numSourceDims);
295+
ReassociationIndices allSourceIndices;
296+
allSourceIndices.reserve(numSourceDims);
296297
for (unsigned sourceDimIdx = 0; sourceDimIdx < numSourceDims;
297298
++sourceDimIdx) {
298299
int64_t sourceSize = sourceShape[sourceDimIdx];

mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
//===----------------------------------------------------------------------===//
88

99
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
10+
#include "mlir/IR/BuiltinTypeInterfaces.h"
1011
#include "llvm/ADT/STLExtras.h"
1112
#include "gtest/gtest.h"
1213
#include <optional>
@@ -20,6 +21,29 @@ makeOptionalIndices(std::initializer_list<ReassociationIndices> list) {
2021
return std::optional<SmallVector<ReassociationIndices>>(list);
2122
}
2223

24+
TEST(ReassociationIndicesForCollapse, ScalarTest) {
25+
EXPECT_EQ(getReassociationIndicesForCollapse({1}, {}),
26+
makeOptionalIndices({{0}}));
27+
EXPECT_EQ(getReassociationIndicesForCollapse({1, 1}, {}),
28+
makeOptionalIndices({{0, 1}}));
29+
EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic}, {}),
30+
makeOptionalIndices({{0}}));
31+
EXPECT_EQ(getReassociationIndicesForCollapse({1, ShapedType::kDynamic,
32+
ShapedType::kDynamic, 1,
33+
ShapedType::kDynamic},
34+
{}),
35+
makeOptionalIndices({{0, 1, 2, 3, 4}}));
36+
}
37+
38+
TEST(ReassociationIndicesForCollapse, ScalarTestFailure) {
39+
EXPECT_EQ(getReassociationIndicesForCollapse({}, {}), std::nullopt);
40+
EXPECT_EQ(getReassociationIndicesForCollapse({}, {1}), std::nullopt);
41+
EXPECT_EQ(getReassociationIndicesForCollapse({2}, {}), std::nullopt);
42+
EXPECT_EQ(
43+
getReassociationIndicesForCollapse({1, 2, ShapedType::kDynamic, 1}, {}),
44+
std::nullopt);
45+
}
46+
2347
TEST(ReassociationIndicesForCollapse, StaticTest) {
2448
EXPECT_EQ(getReassociationIndicesForCollapse({10, 20}, {200}),
2549
makeOptionalIndices({{0, 1}}));

0 commit comments

Comments
 (0)