
Commit eef4f66

AGindinson authored and chrsmcgrr committed
fix(mlir): loosen restrictions on folding dynamic reshapes (#8)
The main idea behind the change is to allow expand-of-collapse folds for reshapes like `?x?xk` -> `?` (k>1). The rationale here is that the expand op must have a coherent index/affine expression specified in its `output_shape` argument (see example below), and if it doesn't, the IR has already been invalidated at an earlier stage:

```
%c32 = arith.constant 32 : index
%div = arith.divsi %<some_index>, %c32 : index
%collapsed = tensor.collapse_shape %41#1 [[0], [1, 2], [3, 4]] : tensor<9x?x32x?x32xf32> into tensor<9x?x?xf32>
%affine = affine.apply affine_map<()[s0] -> (s0 * 32)> ()[%div]
%expanded = tensor.expand_shape %collapsed [[0], [1, 2], [3]] output_shape [9, %div, 32, %affine] : tensor<9x?x?xf32> into tensor<9x?x32x?xf32>
```

On the above assumption, adjust the routine in `getReassociationIndicesForCollapse()` to allow dynamic reshapes beyond `?x...x?x1x...x1` -> `?`.

Moreover, the reassociation util was refactored to clearly distinguish between dynamic and static subshapes. A few known caveats were noted as a comment; I don't think it's possible to fold all qualifying dynamic shape patterns in a deterministic way without looking into affine expressions simultaneously (which would be difficult to maintain in a single general utility for all reliant passes, and would therefore require a larger refactor).
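For a concrete picture of what the relaxed routine accepts, here is a minimal standalone sketch (not part of the commit) that calls `getReassociationIndicesForCollapse()` on the shapes from the example above. The includes, the `main` wrapper, and the printing are illustrative assumptions; the expected `{{0}, {1}, {2}, {3, 4}}` grouping follows from the commit message and the tests below.

```cpp
// Illustrative sketch only: exercises the relaxed utility on the shapes from
// the commit-message example. Requires building/linking against MLIR.
#include <optional>

#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;

int main() {
  const int64_t dyn = ShapedType::kDynamic;

  // Source 9x?x32x?x32 collapsing into target 9x?x32x?, i.e. the shape pair a
  // folded expand-of-collapse would produce. Previously rejected; with this
  // change it is expected to yield the grouping {{0}, {1}, {2}, {3, 4}}.
  std::optional<SmallVector<ReassociationIndices>> groups =
      getReassociationIndicesForCollapse(/*sourceShape=*/{9, dyn, 32, dyn, 32},
                                         /*targetShape=*/{9, dyn, 32, dyn});

  if (!groups) {
    llvm::outs() << "no reassociation found\n";
    return 1;
  }
  // Print each reassociation group on its own line.
  for (const ReassociationIndices &group : *groups) {
    for (int64_t dim : group)
      llvm::outs() << dim << ' ';
    llvm::outs() << '\n';
  }
  return 0;
}
```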
1 parent 2cbd5a3 commit eef4f66

File tree

3 files changed: +79 -52 lines changed


mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp

Lines changed: 57 additions & 46 deletions
```diff
@@ -31,59 +31,70 @@ mlir::getReassociationIndicesForReshape(ShapedType sourceType,
 std::optional<SmallVector<ReassociationIndices>>
 mlir::getReassociationIndicesForCollapse(ArrayRef<int64_t> sourceShape,
                                          ArrayRef<int64_t> targetShape) {
-  if (sourceShape.size() <= targetShape.size())
+  unsigned numSourceDims = sourceShape.size(),
+           numTargetDims = targetShape.size();
+  if (numSourceDims <= numTargetDims)
     return std::nullopt;
-  unsigned sourceDim = 0;
-  SmallVector<ReassociationIndices> reassociationMap;
-  reassociationMap.reserve(targetShape.size());
-
-  ReassociationIndices currIndices;
-  int64_t prodOfCollapsedDims = 1;
-  while (sourceDim < sourceShape.size()) {
-    unsigned targetDim = reassociationMap.size();
-    // If we have mapped all the target dimensions stop and handle the remaining
-    // tail of size-1 dimensions explicitly.
-    if (targetDim == targetShape.size())
-      break;
+  SmallVector<ReassociationIndices, 4> reassociationMap;
+  reassociationMap.reserve(numTargetDims);
 
+  unsigned sourceDim = 0, targetDim = 0;
+  for (; targetDim < numTargetDims; ++targetDim) {
     int64_t currTargetShape = targetShape[targetDim];
-    while (sourceDim < (sourceShape.size() - 1) &&
-           sourceShape[sourceDim] != ShapedType::kDynamic &&
-           prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape) {
+    ReassociationIndices currIndices;
+    // 1. Target dimension is dynamic. Source shape should contain at least
+    // one dynamic dimension.
+    if (currTargetShape == ShapedType::kDynamic) {
+      // FIXME: We stop the search with the first dynamic dimension, while in
+      // fact, we can have a valid pattern like 2x?x?x4x8 -> ?x4x8. It becomes
+      // indeterministic altogether when we have neighboring dynamic dimensions
+      // in the target shape. Most of these patterns will be safely rejected,
+      // however we might achieve more correct folds by taking affine
+      // expressions into account, if these can be passed on by the call sites.
+      bool foundDynamic = false;
+      while (sourceDim < numSourceDims) {
+        currIndices.push_back(sourceDim);
+        if (sourceShape[sourceDim++] == ShapedType::kDynamic) {
+          foundDynamic = true;
+          break;
+        }
+      }
+      if (!foundDynamic)
+        return std::nullopt;
+
+      reassociationMap.push_back(currIndices);
+      continue;
+    }
+    // 2. Target dimension is static. The product of dimensions of the expanded
+    // shape should match the collapsed dimension shape.
+    int64_t prodOfCollapsedDims = 1;
+    bool reachedTargetDimSize = false;
+    while (sourceDim < numSourceDims) {
+      // Source shape cannot be dynamic if the target dim is static.
+      if (sourceShape[sourceDim] == ShapedType::kDynamic)
+        return std::nullopt;
       prodOfCollapsedDims *= sourceShape[sourceDim];
-      currIndices.push_back(sourceDim++);
+      if (prodOfCollapsedDims > currTargetShape)
+        break;
+      else if (prodOfCollapsedDims == currTargetShape) {
+        currIndices.push_back(sourceDim++);
+        reachedTargetDimSize = true;
+        break;
+      } else // prodOfCollapsedDims < currTargetShape
+        currIndices.push_back(sourceDim++);
     }
-
-    // If the current expanded dimension is dynamic, then the collapsed
-    // dimensions should also be dynamic and product of all previous unprocessed
-    // dimensions of the expanded shape should be 1.
-    if (sourceShape[sourceDim] == ShapedType::kDynamic &&
-        (currTargetShape != ShapedType::kDynamic || prodOfCollapsedDims != 1))
+    if (!reachedTargetDimSize)
       return std::nullopt;
-
-    // If the collapsed dim is dynamic, the current expanded dim should also
-    // be dynamic.
-    if (currTargetShape == ShapedType::kDynamic &&
-        sourceShape[sourceDim] != ShapedType::kDynamic)
-      return std::nullopt;
-
-    // For static shapes, if the product of dimensions of the expanded shape
-    // should match the collapsed dimension shape.
-    if (prodOfCollapsedDims * sourceShape[sourceDim] != currTargetShape)
-      return std::nullopt;
-
-    currIndices.push_back(sourceDim++);
-    reassociationMap.emplace_back(ReassociationIndices{});
-    std::swap(reassociationMap.back(), currIndices);
-    prodOfCollapsedDims = 1;
+    reassociationMap.push_back(currIndices);
   }
-  // All the dimensions in the target must have been processed.
-  if (reassociationMap.size() != targetShape.size())
-    return std::nullopt;
-  // Process any remaining entries in the source shape. They all need to be
-  // 1 or dynamic.
-  for (; sourceDim < sourceShape.size(); sourceDim++) {
-    if (sourceShape[sourceDim] != ShapedType::kDynamic &&
+  // Now that we've mapped all the target dimensions, process any remaining
+  // entries in the source shape explicitly. Either the last target dimension
+  // is dynamic, or all remaining source entries need to be 1 or dynamic. Same
+  // applies when target shape is empty (can be the case for subshape
+  // reassociations).
+  for (; sourceDim < numSourceDims; sourceDim++) {
+    if ((targetShape.empty() || targetShape.back() != ShapedType::kDynamic) &&
+        sourceShape[sourceDim] != ShapedType::kDynamic &&
         sourceShape[sourceDim] != 1)
       return std::nullopt;
     // The map is empty when the target type is a scalar.
```
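As a companion to the comments in the new code path, the following hedged sketch (the wrapper function `illustrateRelaxedFolding` and the includes are illustrative assumptions, not part of the commit) spells out one shape pair the relaxed walk now accepts and the FIXME pattern it still rejects:

```cpp
// Illustrative sketch only: one newly-accepted and one still-rejected case for
// the relaxed getReassociationIndicesForCollapse().
#include <cassert>

#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

void illustrateRelaxedFolding() {
  const int64_t dyn = ShapedType::kDynamic;

  // Newly accepted: ?x?x32 -> ?. The dynamic target dim consumes source dims
  // up to the first dynamic one, and the trailing static 32 is absorbed into
  // the last (dynamic) group, giving a single group {0, 1, 2}.
  assert(getReassociationIndicesForCollapse({dyn, dyn, 32}, {dyn}).has_value());

  // Still rejected (see the FIXME): 2x?x?x4x8 -> ?x4x8. The search for the
  // first dynamic target dim stops at source dim 1, so the second dynamic
  // source dim collides with the static target dim 4 and the routine bails.
  assert(!getReassociationIndicesForCollapse({2, dyn, dyn, 4, 8}, {dyn, 4, 8})
              .has_value());
}
```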

mlir/test/Dialect/Linalg/simplify-pack-unpack.mlir

Lines changed: 2 additions & 2 deletions
```diff
@@ -158,8 +158,8 @@ func.func @unpack_to_partial_slice(%arg0: tensor<8x32xf32>) -> tensor<255xf32> {
 // -----
 
 // CHECK-LABEL: func.func @unpack_dynamic
-// CHECK-NOT: tensor.collapse
-// CHECK: linalg.unpack
+// CHECK: tensor.collapse
+// CHECK-NOT: linalg.unpack
 func.func @unpack_dynamic(%arg0: tensor<?x32xf32>) -> tensor<?xf32> {
   %c32 = arith.constant 32 : index
   %c0 = arith.constant 0 : index
```

mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 20 additions & 4 deletions
```diff
@@ -1068,28 +1068,44 @@ func.func @fold_expand_of_collapse(%arg0 : tensor<3x4x4xf32>) -> tensor<3x4x4xf32>
 
 // -----
 
-func.func @fold_expand_of_collapse_dynamic(%arg0 : tensor<?x4x?xf32>, %arg1: index, %arg2: index)
+func.func @fold_expand_of_collapse_mixed_subshape(%arg0 : tensor<?x4x?xf32>, %arg1: index, %arg2: index)
     -> tensor<?x4x?xf32> {
   %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
       : tensor<?x4x?xf32> into tensor<?x?xf32>
   %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, 4, %arg2]
       : tensor<?x?xf32> into tensor<?x4x?xf32>
   return %1 : tensor<?x4x?xf32>
 }
-// CHECK-LABEL: @fold_expand_of_collapse_dynamic
+// CHECK-LABEL: @fold_expand_of_collapse_mixed_subshape
 // CHECK-NOT: tensor.{{.*}}_shape
 
 // -----
 
-func.func @no_fold_expand_of_collapse_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index)
+func.func @fold_expand_of_collapse_mixed_target_subshape(%arg0 : tensor<?x4x?x2xf32>, %arg1: index, %arg2: index)
+    -> tensor<?x4x?xf32> {
+  %0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3]]
+      : tensor<?x4x?x2xf32> into tensor<?x?xf32>
+  %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, 4, %arg2]
+      : tensor<?x?xf32> into tensor<?x4x?xf32>
+  return %1 : tensor<?x4x?xf32>
+}
+// CHECK-LABEL: @fold_expand_of_collapse_mixed_target_subshape
+// CHECK-NOT: tensor.expand_shape
+// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %arg0 {{\[}}[0], [1], [2, 3]]
+// CHECK-SAME: : tensor<?x4x?x2xf32> into tensor<?x4x?xf32>
+// CHECK-NEXT: return %[[COLLAPSE]]
+
+// -----
+
+func.func @no_fold_expand_of_collapse_fully_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index)
     -> tensor<?x?x?xf32> {
   %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
       : tensor<?x?x?xf32> into tensor<?x?xf32>
   %1 = tensor.expand_shape %0 [[0, 1], [2]] output_shape [%arg1, %arg2, %arg3]
       : tensor<?x?xf32> into tensor<?x?x?xf32>
   return %1 : tensor<?x?x?xf32>
 }
-// CHECK-LABEL: @no_fold_expand_of_collapse_dynamic
+// CHECK-LABEL: @no_fold_expand_of_collapse_fully_dynamic
 // CHECK: tensor.collapse_shape
 // CHECK: %[[EXPAND:.+]] = tensor.expand_shape
 // CHECK: return %[[EXPAND]]
```
