Fixes in 'tosa.reshape' lowering and folder (#85798)
- Revamped the lowering conversion pattern for `tosa.reshape` to handle previously unsupported combinations of dynamic dimensions in the input and output tensors. The lowering strategy continues to rely on `tensor.collapse_shape` + `tensor.expand_shape` pairs, which allow for downstream fusion with surrounding `linalg.generic` ops (a hypothetical before/after sketch follows the file statistics below).

- Fixed a bug in the canonicalization pattern `ReshapeOp::fold()` in `TosaCanonicalizations.cpp`. The input and result types being equal is not a sufficient condition for folding: if there is more than 1 dynamic dimension in the input and result types, a productive reshape could still occur.

- This work exposed the fact that bufferization does not properly handle a `tensor.collapse_shape` op producing a 0D tensor from a dynamically shaped one, due to a limitation in `memref.collapse_shape`. While the proper way to address this would involve relaxing the `memref.collapse_shape` restriction and verifying correct bufferization, that is left as possible future work. For now, this scenario is avoided by casting the `tosa.reshape` input tensor to a static shape if necessary (see `inferReshapeInputType()`).

- An extended set of tests is intended to cover relevant conversion paths. Tests are named using the pattern `test_reshape_<rank>_{up|down|same}_{s2s|s2d|d2s|d2d}_{explicit|auto}[_empty][_identity]`, where:
  - `<rank>` is the input rank (e.g., 3d, 6d).
  - `{up|down|same}` indicates whether the reshape increases, decreases, or retains the input rank.
  - `{s2s|s2d|d2s|d2d}` indicates whether the reshape converts a statically shaped input to a statically shaped result (`s2s`), a statically shaped input to a dynamically shaped result (`s2d`), etc.
  - `{explicit|auto}` indicates whether all values in the `new_shape` attribute are >= 0 (`explicit`) or a -1 placeholder value is used (`auto`).
  - `empty` indicates that `new_shape` includes a component set to 0.
  - `identity` is used when the input and result shapes are the same.
1 parent 507e59a commit 26d896f

5 files changed: +495 -195 lines
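Before the per-file diffs, a minimal sketch of the collapse/expand lowering described above. The function name follows the test naming convention from the commit message, but the IR shown here is an illustrative approximation rather than output copied from this commit's test files:

```mlir
// Hypothetical rank-reducing reshape with dynamic dimensions and a -1 placeholder.
func.func @test_reshape_3d_down_d2d_auto(%arg0: tensor<2x?x?xf32>) -> tensor<?x4xf32> {
  %0 = tosa.reshape %arg0 {new_shape = array<i64: -1, 4>} : (tensor<2x?x?xf32>) -> tensor<?x4xf32>
  return %0 : tensor<?x4xf32>
}

// Approximate shape of the lowering: the input is first collapsed into a
// rank-1 tensor, then expanded into the target shape, so surrounding
// linalg.generic ops can still fuse across the pair.
//   %collapsed = tensor.collapse_shape %arg0 [[0, 1, 2]]
//                    : tensor<2x?x?xf32> into tensor<?xf32>
//   %expanded  = tensor.expand_shape %collapsed [[0, 1]]
//                    : tensor<?xf32> into tensor<?x4xf32>
```

When the input is fully static, the -1 placeholder is instead resolved arithmetically: for a `2x3x4` input with `new_shape = [6, -1]`, the placeholder becomes 24 / 6 = 4 and both the collapsed and expanded types stay static.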

mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp

Lines changed: 160 additions & 144 deletions
```diff
@@ -19,24 +19,99 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/DialectConversion.h"
 
+#include <numeric>
+
 using namespace mlir;
 using namespace tosa;
 
-static bool findIntermediateShape(ArrayRef<int64_t> lhsShape,
-                                  ArrayRef<int64_t> rhsShape,
-                                  SmallVector<int64_t> &intermediateShape,
-                                  bool isDynamic) {
-  if (isDynamic) {
-    // TODO (natashaknk): Make dynamic intermediate shape not always be rank-1
-    intermediateShape = {ShapedType::kDynamic};
-    return true;
-  }
+namespace {
 
-  if (lhsShape.empty() || rhsShape.empty()) {
-    intermediateShape = {};
-    return true;
-  }
+// Infer the type to which the input of a 'tosa.reshape' op must be cast when
+// lowered.
+TensorType inferReshapeInputType(TypedValue<TensorType> input,
+                                 ArrayRef<int64_t> newShape) {
+  // No need to cast input for non-empty target shape
+  if (!newShape.empty())
+    return input.getType();
+
+  // The input type must be cast into a tensor with the same rank and all static
+  // dimensions set to 1. This prevents the generation of a tensor.collapse_shape
+  // op that converts a dynamically shaped tensor into a 0D tensor. While such
+  // construct is not incorrect on its own, bufferization cannot properly handle
+  // it at the moment, so we avoid it.
+  SmallVector<int64_t> shape(input.getType().getRank(), 1);
+  return input.getType().clone(shape);
+}
+
+// Infer the result type of 'tensor.expand_shape' in the collapse-expand
+// pair emitted for a 'tosa.reshape' op.
+TensorType inferReshapeExpandedType(TensorType inputType,
+                                    ArrayRef<int64_t> newShape) {
+  // Special case for 0D output tensor. Note: Watch out when using Type::clone()
+  // with just '{}', as it will invoke the incorrect overload.
+  if (newShape.empty())
+    return inputType.clone(ArrayRef<int64_t>{});
+
+  // Check if the input is static, and if so, get its total size
+  bool inputIsStatic = inputType.hasStaticShape();
+  int64_t totalSize = inputIsStatic ? inputType.getNumElements() : -1;
+
+  // Compute result shape
+  bool resultIsStatic = true;
+  auto resultShape = llvm::map_to_vector(newShape, [&](int64_t size) -> int64_t {
+    // If this is not a placeholder, do not change it
+    if (size >= 0)
+      return size;
+
+    // If we do not know the total size of the tensor, keep this dimension
+    // dynamic in the result shape.
+    if (!inputIsStatic) {
+      resultIsStatic = false;
+      return ShapedType::kDynamic;
+    }
 
+    // Calculate the product of all elements in 'newShape' except for the -1
+    // placeholder, which we discard by negating the result.
+    int64_t totalSizeNoPlaceholder = -std::accumulate(
+        newShape.begin(), newShape.end(), 1, std::multiplies());
+
+    // If there is a 0 component in 'newShape', resolve the placeholder as 0.
+    if (totalSizeNoPlaceholder == 0)
+      return 0;
+
+    // Resolve the placeholder as the quotient between the total tensor size and
+    // the product of all other sizes.
+    return totalSize / totalSizeNoPlaceholder;
+  });
+
+  // A syntactic restriction in 'tensor.expand_shape' forbids a dynamically
+  // shaped input from being reshaped into a statically shaped result. We may
+  // simply turn the first result dimension dynamic to address this.
+  if (!inputIsStatic && resultIsStatic)
+    resultShape[0] = ShapedType::kDynamic;
+
+  // The 'tensor.expand_shape' op also forbids a statically shaped input from
+  // being reshaped into a dynamically shaped result, but the placeholder
+  // inference algorithm above guarantees that this will never be the case.
+  assert(!inputIsStatic || resultIsStatic);
+
+  // Create result type
+  return inputType.clone(resultShape);
+}
+
+// Infer the result type of 'tensor.collapse_shape' in the collapse-expand
+// pair emitted for a 'tosa.reshape' op.
+TensorType inferReshapeCollapsedType(TensorType lhsType, TensorType rhsType) {
+  auto lhsShape = lhsType.getShape();
+  auto rhsShape = rhsType.getShape();
+
+  if (lhsShape.empty() || rhsShape.empty())
+    return lhsType.clone(ArrayRef<int64_t>{});
+
+  if (ShapedType::isDynamicShape(lhsShape) || ShapedType::isDynamicShape(rhsShape))
+    return lhsType.clone({ShapedType::kDynamic});
+
+  SmallVector<int64_t> intermediateShape;
   unsigned currLhsDim = 0, currRhsDim = 0;
   while (currLhsDim < lhsShape.size() && currRhsDim < rhsShape.size()) {
     int64_t rhsSize = rhsShape[currRhsDim];
@@ -62,174 +137,113 @@ static bool findIntermediateShape(ArrayRef<int64_t> lhsShape,
     currLhsDim++;
   }
 
-  // If the iterators didn't reach the end and their leftover dimensions are not
-  // equal to 1 an intermediate shape was not found.
-  while (currLhsDim < lhsShape.size()) {
-    if (lhsShape[currLhsDim++] != 1) {
-      return false;
-    }
+  // Static shapes are guaranteed to be compatible by the op verifier, so all
+  // leftover dimensions should be 1.
+  for (; currLhsDim < lhsShape.size(); currLhsDim++) {
+    assert(lhsShape[currLhsDim] == 1);
   }
-
-  while (currRhsDim < rhsShape.size()) {
-    if (rhsShape[currRhsDim++] != 1) {
-      return false;
-    }
+  for (; currRhsDim < rhsShape.size(); currRhsDim++) {
+    assert(rhsShape[currRhsDim] == 1);
   }
-
-  return true;
+
+  return lhsType.clone(intermediateShape);
 }
 
-static bool createReassociationMapsForCollapse(
-    PatternRewriter &rewriter, ArrayRef<int64_t> srcShape,
-    ArrayRef<int64_t> dstShape,
-    SmallVector<ReassociationExprs, 4> &reassociationMap, bool isDynamic) {
+SmallVector<ReassociationExprs>
+createReassociationMapForCollapse(OpBuilder &builder, Type srcType, Type dstType) {
+  auto srcShape = cast<TensorType>(srcType).getShape();
+  auto dstShape = cast<TensorType>(dstType).getShape();
 
-  // If the shape is dynamic, create a map for collapsing into one dimension.
-  if (isDynamic) {
-    SmallVector<AffineExpr, 2> exprs;
-    for (int i = 0, s = srcShape.size(); i < s; ++i)
-      exprs.push_back(rewriter.getAffineDimExpr(i));
-    reassociationMap = {exprs};
-    return true;
-  }
+  if (srcShape.empty() || dstShape.empty())
+    return {};
 
-  if (dstShape.empty()) {
-    reassociationMap = {};
-    return true;
+  if (ShapedType::isDynamicShape(srcShape) || ShapedType::isDynamicShape(dstShape)) {
+    assert(dstShape.size() == 1);
+    SmallVector<AffineExpr, 2> exprs;
+    for (auto i : llvm::seq<int64_t>(srcShape.size()))
+      exprs.push_back(builder.getAffineDimExpr(i));
+    return {exprs};
   }
 
-  reassociationMap.resize(dstShape.size());
+  SmallVector<ReassociationExprs> reassociationMap(dstShape.size());
   unsigned currSrcDim = 0, currDstDim = 0;
   while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
     int64_t dstSize = dstShape[currDstDim];
     int64_t srcSize = srcShape[currSrcDim];
     while (srcSize < dstSize && currSrcDim < srcShape.size()) {
       reassociationMap[currDstDim].push_back(
-          rewriter.getAffineDimExpr(currSrcDim++));
+          builder.getAffineDimExpr(currSrcDim++));
       srcSize *= srcShape[currSrcDim];
     }
     if (srcSize == dstSize) {
       reassociationMap[currDstDim].push_back(
-          rewriter.getAffineDimExpr(currSrcDim++));
+          builder.getAffineDimExpr(currSrcDim++));
       // If the next dim in collapsedShape is not 1, treat subsequent dims in
       // expandedShape which are 1 to be collapsed.
       if (currDstDim == dstShape.size() - 1 || dstShape[currDstDim + 1] != 1) {
        while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
          reassociationMap[currDstDim].push_back(
-              rewriter.getAffineDimExpr(currSrcDim++));
+              builder.getAffineDimExpr(currSrcDim++));
        }
      }
    }
    currDstDim++;
  }
 
-  // If both iterators didn't reach the end, we have leftover dimentions which
-  // implies that we have a mismatch in shape.
-  return currSrcDim == srcShape.size() && currDstDim == dstShape.size();
+  // If the source and target shapes are compatible, both iterators must have
+  // reached the end. This condition is guaranteed by the op verifier for
+  // static shapes.
+  assert(currSrcDim == srcShape.size() && currDstDim == dstShape.size());
+  return reassociationMap;
 }
 
-namespace {
-Value createCollapse(ConversionPatternRewriter &rewriter, Location loc,
-                     ShapedType resultTy, Value operand) {
-  ShapedType operandTy = cast<ShapedType>(operand.getType());
-  if (resultTy == operandTy)
-    return operand;
-
-  bool isDynamic = !operandTy.hasStaticShape();
-
-  if (isDynamic && resultTy.getRank() != 1) {
-    (void)rewriter.notifyMatchFailure(
-        loc, "Cannot collapse dynamic dims to more than one dimension");
-    return {};
-  }
-
-  SmallVector<ReassociationExprs, 4> reassociationMap;
-  if (!createReassociationMapsForCollapse(rewriter, operandTy.getShape(),
-                                          resultTy.getShape(),
-                                          reassociationMap, isDynamic)) {
-    (void)rewriter.notifyMatchFailure(
-        loc, "tosa.reshape Attempting to collapse into an incompatible shape");
-    return {};
-  }
-
-  SmallVector<int64_t> intermediateShape;
-  if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
-                             intermediateShape, isDynamic)) {
-    (void)rewriter.notifyMatchFailure(
-        loc, "tosa.reshape Cannot collapse into given shape");
-    return {};
-  }
-  return rewriter.create<tensor::CollapseShapeOp>(loc, resultTy, operand,
-                                                  reassociationMap);
+// Create a tensor.collapse_shape op that reshapes the input into the given
+// result type.
+Value createCollapse(OpBuilder &builder, Location loc, TensorType resultType,
+                     Value input) {
+  auto reassociationMap =
+      createReassociationMapForCollapse(builder, input.getType(), resultType);
+  return builder.createOrFold<tensor::CollapseShapeOp>(loc, resultType, input,
+                                                       reassociationMap);
 }
 
-Value createExpand(ConversionPatternRewriter &rewriter, Location loc,
-                   ShapedType resultTy, Value operand) {
-  ShapedType operandTy = cast<ShapedType>(operand.getType());
-  if (resultTy == operandTy)
-    return operand;
-
-  bool isDynamic = !operandTy.hasStaticShape();
-
-  if (isDynamic && operandTy.getRank() != 1) {
-    (void)rewriter.notifyMatchFailure(
-        loc, "Cannot expand dynamic dims from more than one dimension");
-    return {};
-  }
-
-  SmallVector<ReassociationExprs, 4> reassociationMap;
-  if (!createReassociationMapsForCollapse(rewriter, resultTy.getShape(),
-                                          operandTy.getShape(),
-                                          reassociationMap, isDynamic)) {
-    (void)rewriter.notifyMatchFailure(
-        loc, "tosa.reshape Attempting to expand into an incompatible shape");
-    return {};
-  }
-
-  SmallVector<int64_t> intermediateShape;
-  if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
-                             intermediateShape, isDynamic) ||
-      intermediateShape != operandTy.getShape()) {
-    (void)rewriter.notifyMatchFailure(
-        loc, "tosa.reshape Cannot expand into given shape");
-    return {};
-  }
-  return rewriter.create<tensor::ExpandShapeOp>(loc, resultTy, operand,
-                                                reassociationMap);
+// Create a tensor.expand_shape op that reshapes the input into the given result
+// type.
+Value createExpand(OpBuilder &builder, Location loc, TensorType resultType,
+                   Value input) {
+  auto reassociationMap =
+      createReassociationMapForCollapse(builder, resultType, input.getType());
+  return builder.createOrFold<tensor::ExpandShapeOp>(loc, resultType, input,
+                                                     reassociationMap);
 }
 
-class ReshapeConverterCollapseExpand
-    : public OpConversionPattern<tosa::ReshapeOp> {
+class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
 public:
   using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
-    ShapedType operandTy = cast<ShapedType>(adaptor.getInput1().getType());
-    ShapedType resultTy = cast<ShapedType>(reshape.getType());
-    bool isDynamic = !operandTy.hasStaticShape();
-
-    SmallVector<int64_t> intermediateShape;
-    if (!findIntermediateShape(resultTy.getShape(), operandTy.getShape(),
-                               intermediateShape, isDynamic)) {
-      return rewriter.notifyMatchFailure(
-          reshape, "tosa.reshape Cannot identify an intermediate shape between "
-                   "the given two shapes");
-    }
-    auto intermediateTy = RankedTensorType::get(
-        intermediateShape, reshape.getType().getElementType());
-
-    Value collapse = createCollapse(rewriter, reshape.getLoc(), intermediateTy,
-                                    adaptor.getInput1());
-    if (!collapse)
-      return failure();
-
-    Value expand = createExpand(rewriter, reshape.getLoc(), resultTy, collapse);
-    if (!expand)
-      return failure();
-
-    rewriter.replaceOp(reshape, expand);
+    auto loc = reshape.getLoc();
+    auto resultType = reshape.getResult().getType();
+    auto input = reshape.getInput1();
+    auto newShape = reshape.getNewShape();
+
+    // Infer all intermediate types
+    auto inputType = inferReshapeInputType(input, newShape);
+    auto expandedType = inferReshapeExpandedType(inputType, newShape);
+    auto collapsedType = inferReshapeCollapsedType(inputType, expandedType);
+
+    // Cast input if needed
+    auto castInput = rewriter.createOrFold<tensor::CastOp>(loc, inputType, input);
+
+    // Emit collaspe-expand pair
+    auto collapsed = createCollapse(rewriter, loc, collapsedType, castInput);
+    auto expanded = createExpand(rewriter, loc, expandedType, collapsed);
+
+    // Cast to final result type if needed
+    auto result = rewriter.createOrFold<tensor::CastOp>(loc, resultType, expanded);
+    rewriter.replaceOp(reshape, result);
     return success();
   }
 };
@@ -416,8 +430,10 @@ struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
 
 void mlir::tosa::populateTosaToTensorConversionPatterns(
     RewritePatternSet *patterns) {
-  patterns->add<SliceConverter, PadConverter, ConcatConverter>(
-      patterns->getContext());
-
-  patterns->add<ReshapeConverterCollapseExpand>(patterns->getContext());
+  patterns->add<
+      ConcatConverter,
+      PadConverter,
+      ReshapeConverter,
+      SliceConverter
+  >(patterns->getContext());
 }
```
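As a companion illustration of the `inferReshapeInputType()` workaround above (again a sketch, not verbatim test output): reshaping a dynamically shaped tensor into a 0D tensor first casts the input to an all-ones static shape, so `tensor.collapse_shape` never has to produce a 0D tensor directly from a dynamic one.

```mlir
// Hypothetical reshape of a dynamic 1-D tensor into a 0-D tensor.
func.func @test_reshape_1d_down_d2s_explicit(%arg0: tensor<?xf32>) -> tensor<f32> {
  %0 = tosa.reshape %arg0 {new_shape = array<i64>} : (tensor<?xf32>) -> tensor<f32>
  return %0 : tensor<f32>
}

// Approximate lowering: the cast supplies a static shape, the collapse then
// produces the 0-D result, and the trailing expand/cast fold away.
//   %cast      = tensor.cast %arg0 : tensor<?xf32> to tensor<1xf32>
//   %collapsed = tensor.collapse_shape %cast [] : tensor<1xf32> into tensor<f32>
```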

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 4 additions & 1 deletion
```diff
@@ -795,7 +795,10 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
   if (!inputTy || !outputTy)
     return {};
 
-  if (inputTy == outputTy)
+  // Fold when the input and output types are the same. This is only safe when
+  // there is at most 1 dynamic dimension. For 2 or more dynamic dimensions,
+  // there may still be a productive reshape.
+  if (inputTy == outputTy && inputTy.getNumDynamicDims() < 2)
     return getInput1();
 
   // reshape(reshape(x)) -> reshape(x)
```
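To see why the extra `getNumDynamicDims() < 2` guard is needed, consider a hypothetical op (name and shapes are illustrative) where the input and result types are both `tensor<?x?xf32>`: the types match, yet the reshape can still redistribute elements at runtime, for example turning a 5x4 tensor into a 2x10 one, so folding it to its input would be incorrect.

```mlir
// Must not fold: identical ?x? types, but the op may still change the runtime shape.
func.func @reshape_dynamic_nofold(%arg0: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = tosa.reshape %arg0 {new_shape = array<i64: -1, 10>} : (tensor<?x?xf32>) -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}
```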

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 5 additions & 0 deletions
```diff
@@ -970,6 +970,11 @@ mlir::LogicalResult tosa::ReshapeOp::verify() {
                            << " elements into " << outputElementsNum;
     }
   }
+
+  int missingDims = llvm::count(getNewShape(), -1);
+  if (missingDims > 1)
+    return emitOpError() << "At most one target dimension can be -1";
+
   return mlir::success();
 }
```
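Finally, a hypothetical example of an op that the new verifier check rejects; the diagnostic text is taken from the diff above, while the function itself is illustrative:

```mlir
func.func @reshape_two_placeholders(%arg0: tensor<2x3x4xf32>) -> tensor<?x?xf32> {
  // error: 'tosa.reshape' op At most one target dimension can be -1
  %0 = tosa.reshape %arg0 {new_shape = array<i64: -1, -1>} : (tensor<2x3x4xf32>) -> tensor<?x?xf32>
  return %0 : tensor<?x?xf32>
}
```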