Skip to content

Commit 0943058

Browse files
AGindinsontomtor
authored andcommitted
[mlir][tensor] Fix getReassociationForCollapse for tensor/scalar re… (llvm#144118)
…shapes Commit 6e5a142 changed the behavior of the function when computing reassociations between tensors (consisting of unit/dynamic dimensions) and scalars/0d vectors. The IR representation for such reshapes actually expects an empty reassociation, like so: ``` func.func @example(%arg0 : tensor<?x?x?xf32>) -> tensor<f32> { %0 = tensor.collapse_shape %arg0 [] : tensor<?x?x?xf32> into tensor<f32> } ``` Restore the original behavior - the routine should resort to reporting failures when compile time-known non-unit dimensions are part of the attempted reassociation. Signed-off-by: Artem Gindinson <[email protected]>
1 parent 8a2754e commit 0943058

File tree

2 files changed

+8
-10
lines changed

2 files changed

+8
-10
lines changed

mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -299,19 +299,17 @@ mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
299299
// this utility).
300300
if (numSourceDims <= numTargetDims)
301301
return std::nullopt;
302-
// Early handling for scalar target types.
302+
// Early handling for scalar target types. We should report an invalid
303+
// reassociation for non-unit static dimensions - no chance to collapse these
304+
// into a scalar.
303305
if (numTargetDims == 0) {
304-
ReassociationIndices allSourceIndices;
305-
allSourceIndices.reserve(numSourceDims);
306306
for (unsigned sourceDimIdx = 0; sourceDimIdx < numSourceDims;
307307
++sourceDimIdx) {
308308
int64_t sourceSize = sourceShape[sourceDimIdx];
309-
// All source dimensions must be unit or dynamic.
310309
if (sourceSize != 1 && sourceSize != ShapedType::kDynamic)
311310
return std::nullopt;
312-
allSourceIndices.push_back(sourceDimIdx);
313311
}
314-
return SmallVector<ReassociationIndices>{allSourceIndices};
312+
return SmallVector<ReassociationIndices>{};
315313
}
316314

317315
// Collect source ranges by iterating over the target shape left-to-right.

mlir/unittests/Dialect/Utils/ReshapeOpsUtilsTest.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,16 @@ makeOptionalIndices(std::initializer_list<ReassociationIndices> list) {
2323

2424
TEST(ReassociationIndicesForCollapse, ScalarTest) {
2525
EXPECT_EQ(getReassociationIndicesForCollapse({1}, {}),
26-
makeOptionalIndices({{0}}));
26+
makeOptionalIndices({}));
2727
EXPECT_EQ(getReassociationIndicesForCollapse({1, 1}, {}),
28-
makeOptionalIndices({{0, 1}}));
28+
makeOptionalIndices({}));
2929
EXPECT_EQ(getReassociationIndicesForCollapse({ShapedType::kDynamic}, {}),
30-
makeOptionalIndices({{0}}));
30+
makeOptionalIndices({}));
3131
EXPECT_EQ(getReassociationIndicesForCollapse({1, ShapedType::kDynamic,
3232
ShapedType::kDynamic, 1,
3333
ShapedType::kDynamic},
3434
{}),
35-
makeOptionalIndices({{0, 1, 2, 3, 4}}));
35+
makeOptionalIndices({}));
3636
}
3737

3838
TEST(ReassociationIndicesForCollapse, ScalarTestFailure) {

0 commit comments

Comments
 (0)