Skip to content

Commit 98a8ac5

Browse files
committed
Revert "Fixes in 'tosa.reshape' lowering and folder (llvm#85798)"
This reverts commit 26d896f.
1 parent 9ec3f8a commit 98a8ac5

File tree

5 files changed

+195
-495
lines changed

5 files changed

+195
-495
lines changed

mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp

Lines changed: 144 additions & 160 deletions
Original file line numberDiff line numberDiff line change
@@ -19,99 +19,24 @@
1919
#include "mlir/IR/PatternMatch.h"
2020
#include "mlir/Transforms/DialectConversion.h"
2121

22-
#include <numeric>
23-
2422
using namespace mlir;
2523
using namespace tosa;
2624

27-
namespace {
28-
29-
// Infer the type to which the input of a 'tosa.reshape' op must be cast when
30-
// lowered.
31-
TensorType inferReshapeInputType(TypedValue<TensorType> input,
32-
ArrayRef<int64_t> newShape) {
33-
// No need to cast input for non-empty target shape
34-
if (!newShape.empty())
35-
return input.getType();
36-
37-
// The input type must be cast into a tensor with the same rank and all static
38-
// dimensions set to 1. This prevents the generation of a tensor.collapse_shape
39-
// op that converts a dynamically shaped tensor into a 0D tensor. While such
40-
// construct is not incorrect on its own, bufferization cannot properly handle
41-
// it at the moment, so we avoid it.
42-
SmallVector<int64_t> shape(input.getType().getRank(), 1);
43-
return input.getType().clone(shape);
44-
}
45-
46-
// Infer the result type of 'tensor.expand_shape' in the collapse-expand
47-
// pair emitted for a 'tosa.reshape' op.
48-
TensorType inferReshapeExpandedType(TensorType inputType,
49-
ArrayRef<int64_t> newShape) {
50-
// Special case for 0D output tensor. Note: Watch out when using Type::clone()
51-
// with just '{}', as it will invoke the incorrect overload.
52-
if (newShape.empty())
53-
return inputType.clone(ArrayRef<int64_t>{});
54-
55-
// Check if the input is static, and if so, get its total size
56-
bool inputIsStatic = inputType.hasStaticShape();
57-
int64_t totalSize = inputIsStatic ? inputType.getNumElements() : -1;
58-
59-
// Compute result shape
60-
bool resultIsStatic = true;
61-
auto resultShape = llvm::map_to_vector(newShape, [&](int64_t size) -> int64_t {
62-
// If this is not a placeholder, do not change it
63-
if (size >= 0)
64-
return size;
65-
66-
// If we do not know the total size of the tensor, keep this dimension
67-
// dynamic in the result shape.
68-
if (!inputIsStatic) {
69-
resultIsStatic = false;
70-
return ShapedType::kDynamic;
71-
}
72-
73-
// Calculate the product of all elements in 'newShape' except for the -1
74-
// placeholder, which we discard by negating the result.
75-
int64_t totalSizeNoPlaceholder = -std::accumulate(
76-
newShape.begin(), newShape.end(), 1, std::multiplies());
77-
78-
// If there is a 0 component in 'newShape', resolve the placeholder as 0.
79-
if (totalSizeNoPlaceholder == 0)
80-
return 0;
81-
82-
// Resolve the placeholder as the quotient between the total tensor size and
83-
// the product of all other sizes.
84-
return totalSize / totalSizeNoPlaceholder;
85-
});
86-
87-
// A syntactic restriction in 'tensor.expand_shape' forbids a dynamically
88-
// shaped input from being reshaped into a statically shaped result. We may
89-
// simply turn the first result dimension dynamic to address this.
90-
if (!inputIsStatic && resultIsStatic)
91-
resultShape[0] = ShapedType::kDynamic;
92-
93-
// The 'tensor.expand_shape' op also forbids a statically shaped input from
94-
// being reshaped into a dynamically shaped result, but the placeholder
95-
// inference algorithm above guarantees that this will never be the case.
96-
assert(!inputIsStatic || resultIsStatic);
97-
98-
// Create result type
99-
return inputType.clone(resultShape);
100-
}
101-
102-
// Infer the result type of 'tensor.collapse_shape' in the collapse-expand
103-
// pair emitted for a 'tosa.reshape' op.
104-
TensorType inferReshapeCollapsedType(TensorType lhsType, TensorType rhsType) {
105-
auto lhsShape = lhsType.getShape();
106-
auto rhsShape = rhsType.getShape();
107-
108-
if (lhsShape.empty() || rhsShape.empty())
109-
return lhsType.clone(ArrayRef<int64_t>{});
25+
static bool findIntermediateShape(ArrayRef<int64_t> lhsShape,
26+
ArrayRef<int64_t> rhsShape,
27+
SmallVector<int64_t> &intermediateShape,
28+
bool isDynamic) {
29+
if (isDynamic) {
30+
// TODO (natashaknk): Make dynamic intermediate shape not always be rank-1
31+
intermediateShape = {ShapedType::kDynamic};
32+
return true;
33+
}
11034

111-
if (ShapedType::isDynamicShape(lhsShape) || ShapedType::isDynamicShape(rhsShape))
112-
return lhsType.clone({ShapedType::kDynamic});
35+
if (lhsShape.empty() || rhsShape.empty()) {
36+
intermediateShape = {};
37+
return true;
38+
}
11339

114-
SmallVector<int64_t> intermediateShape;
11540
unsigned currLhsDim = 0, currRhsDim = 0;
11641
while (currLhsDim < lhsShape.size() && currRhsDim < rhsShape.size()) {
11742
int64_t rhsSize = rhsShape[currRhsDim];
@@ -137,113 +62,174 @@ TensorType inferReshapeCollapsedType(TensorType lhsType, TensorType rhsType) {
13762
currLhsDim++;
13863
}
13964

140-
// Static shapes are guaranteed to be compatible by the op verifier, so all
141-
// leftover dimensions should be 1.
142-
for (; currLhsDim < lhsShape.size(); currLhsDim++) {
143-
assert(lhsShape[currLhsDim] == 1);
65+
// If the iterators didn't reach the end and their leftover dimensions are not
66+
// equal to 1 an intermediate shape was not found.
67+
while (currLhsDim < lhsShape.size()) {
68+
if (lhsShape[currLhsDim++] != 1) {
69+
return false;
70+
}
14471
}
145-
for (; currRhsDim < rhsShape.size(); currRhsDim++) {
146-
assert(rhsShape[currRhsDim] == 1);
72+
73+
while (currRhsDim < rhsShape.size()) {
74+
if (rhsShape[currRhsDim++] != 1) {
75+
return false;
76+
}
14777
}
148-
149-
return lhsType.clone(intermediateShape);
78+
79+
return true;
15080
}
15181

152-
SmallVector<ReassociationExprs>
153-
createReassociationMapForCollapse(OpBuilder &builder, Type srcType, Type dstType) {
154-
auto srcShape = cast<TensorType>(srcType).getShape();
155-
auto dstShape = cast<TensorType>(dstType).getShape();
82+
static bool createReassociationMapsForCollapse(
83+
PatternRewriter &rewriter, ArrayRef<int64_t> srcShape,
84+
ArrayRef<int64_t> dstShape,
85+
SmallVector<ReassociationExprs, 4> &reassociationMap, bool isDynamic) {
15686

157-
if (srcShape.empty() || dstShape.empty())
158-
return {};
159-
160-
if (ShapedType::isDynamicShape(srcShape) || ShapedType::isDynamicShape(dstShape)) {
161-
assert(dstShape.size() == 1);
87+
// If the shape is dynamic, create a map for collapsing into one dimension.
88+
if (isDynamic) {
16289
SmallVector<AffineExpr, 2> exprs;
163-
for (auto i : llvm::seq<int64_t>(srcShape.size()))
164-
exprs.push_back(builder.getAffineDimExpr(i));
165-
return {exprs};
90+
for (int i = 0, s = srcShape.size(); i < s; ++i)
91+
exprs.push_back(rewriter.getAffineDimExpr(i));
92+
reassociationMap = {exprs};
93+
return true;
94+
}
95+
96+
if (dstShape.empty()) {
97+
reassociationMap = {};
98+
return true;
16699
}
167100

168-
SmallVector<ReassociationExprs> reassociationMap(dstShape.size());
101+
reassociationMap.resize(dstShape.size());
169102
unsigned currSrcDim = 0, currDstDim = 0;
170103
while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
171104
int64_t dstSize = dstShape[currDstDim];
172105
int64_t srcSize = srcShape[currSrcDim];
173106
while (srcSize < dstSize && currSrcDim < srcShape.size()) {
174107
reassociationMap[currDstDim].push_back(
175-
builder.getAffineDimExpr(currSrcDim++));
108+
rewriter.getAffineDimExpr(currSrcDim++));
176109
srcSize *= srcShape[currSrcDim];
177110
}
178111
if (srcSize == dstSize) {
179112
reassociationMap[currDstDim].push_back(
180-
builder.getAffineDimExpr(currSrcDim++));
113+
rewriter.getAffineDimExpr(currSrcDim++));
181114
// If the next dim in collapsedShape is not 1, treat subsequent dims in
182115
// expandedShape which are 1 to be collapsed.
183116
if (currDstDim == dstShape.size() - 1 || dstShape[currDstDim + 1] != 1) {
184117
while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
185118
reassociationMap[currDstDim].push_back(
186-
builder.getAffineDimExpr(currSrcDim++));
119+
rewriter.getAffineDimExpr(currSrcDim++));
187120
}
188121
}
189122
}
190123
currDstDim++;
191124
}
192125

193-
// If the source and target shapes are compatible, both iterators must have
194-
// reached the end. This condition is guaranteed by the op verifier for
195-
// static shapes.
196-
assert(currSrcDim == srcShape.size() && currDstDim == dstShape.size());
197-
return reassociationMap;
126+
// If both iterators didn't reach the end, we have leftover dimentions which
127+
// implies that we have a mismatch in shape.
128+
return currSrcDim == srcShape.size() && currDstDim == dstShape.size();
198129
}
199130

200-
// Create a tensor.collapse_shape op that reshapes the input into the given
201-
// result type.
202-
Value createCollapse(OpBuilder &builder, Location loc, TensorType resultType,
203-
Value input) {
204-
auto reassociationMap =
205-
createReassociationMapForCollapse(builder, input.getType(), resultType);
206-
return builder.createOrFold<tensor::CollapseShapeOp>(loc, resultType, input,
207-
reassociationMap);
131+
namespace {
132+
Value createCollapse(ConversionPatternRewriter &rewriter, Location loc,
133+
ShapedType resultTy, Value operand) {
134+
ShapedType operandTy = cast<ShapedType>(operand.getType());
135+
if (resultTy == operandTy)
136+
return operand;
137+
138+
bool isDynamic = !operandTy.hasStaticShape();
139+
140+
if (isDynamic && resultTy.getRank() != 1) {
141+
(void)rewriter.notifyMatchFailure(
142+
loc, "Cannot collapse dynamic dims to more than one dimension");
143+
return {};
144+
}
145+
146+
SmallVector<ReassociationExprs, 4> reassociationMap;
147+
if (!createReassociationMapsForCollapse(rewriter, operandTy.getShape(),
148+
resultTy.getShape(),
149+
reassociationMap, isDynamic)) {
150+
(void)rewriter.notifyMatchFailure(
151+
loc, "tosa.reshape Attempting to collapse into an incompatible shape");
152+
return {};
153+
}
154+
155+
SmallVector<int64_t> intermediateShape;
156+
if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
157+
intermediateShape, isDynamic)) {
158+
(void)rewriter.notifyMatchFailure(
159+
loc, "tosa.reshape Cannot collapse into given shape");
160+
return {};
161+
}
162+
return rewriter.create<tensor::CollapseShapeOp>(loc, resultTy, operand,
163+
reassociationMap);
208164
}
209165

210-
// Create a tensor.expand_shape op that reshapes the input into the given result
211-
// type.
212-
Value createExpand(OpBuilder &builder, Location loc, TensorType resultType,
213-
Value input) {
214-
auto reassociationMap =
215-
createReassociationMapForCollapse(builder, resultType, input.getType());
216-
return builder.createOrFold<tensor::ExpandShapeOp>(loc, resultType, input,
217-
reassociationMap);
166+
Value createExpand(ConversionPatternRewriter &rewriter, Location loc,
167+
ShapedType resultTy, Value operand) {
168+
ShapedType operandTy = cast<ShapedType>(operand.getType());
169+
if (resultTy == operandTy)
170+
return operand;
171+
172+
bool isDynamic = !operandTy.hasStaticShape();
173+
174+
if (isDynamic && operandTy.getRank() != 1) {
175+
(void)rewriter.notifyMatchFailure(
176+
loc, "Cannot expand dynamic dims from more than one dimension");
177+
return {};
178+
}
179+
180+
SmallVector<ReassociationExprs, 4> reassociationMap;
181+
if (!createReassociationMapsForCollapse(rewriter, resultTy.getShape(),
182+
operandTy.getShape(),
183+
reassociationMap, isDynamic)) {
184+
(void)rewriter.notifyMatchFailure(
185+
loc, "tosa.reshape Attempting to expand into an incompatible shape");
186+
return {};
187+
}
188+
189+
SmallVector<int64_t> intermediateShape;
190+
if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
191+
intermediateShape, isDynamic) ||
192+
intermediateShape != operandTy.getShape()) {
193+
(void)rewriter.notifyMatchFailure(
194+
loc, "tosa.reshape Cannot expand into given shape");
195+
return {};
196+
}
197+
return rewriter.create<tensor::ExpandShapeOp>(loc, resultTy, operand,
198+
reassociationMap);
218199
}
219200

220-
class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
201+
class ReshapeConverterCollapseExpand
202+
: public OpConversionPattern<tosa::ReshapeOp> {
221203
public:
222204
using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;
223205

224206
LogicalResult
225207
matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
226208
ConversionPatternRewriter &rewriter) const final {
227-
auto loc = reshape.getLoc();
228-
auto resultType = reshape.getResult().getType();
229-
auto input = reshape.getInput1();
230-
auto newShape = reshape.getNewShape();
231-
232-
// Infer all intermediate types
233-
auto inputType = inferReshapeInputType(input, newShape);
234-
auto expandedType = inferReshapeExpandedType(inputType, newShape);
235-
auto collapsedType = inferReshapeCollapsedType(inputType, expandedType);
236-
237-
// Cast input if needed
238-
auto castInput = rewriter.createOrFold<tensor::CastOp>(loc, inputType, input);
239-
240-
// Emit collaspe-expand pair
241-
auto collapsed = createCollapse(rewriter, loc, collapsedType, castInput);
242-
auto expanded = createExpand(rewriter, loc, expandedType, collapsed);
243-
244-
// Cast to final result type if needed
245-
auto result = rewriter.createOrFold<tensor::CastOp>(loc, resultType, expanded);
246-
rewriter.replaceOp(reshape, result);
209+
ShapedType operandTy = cast<ShapedType>(adaptor.getInput1().getType());
210+
ShapedType resultTy = cast<ShapedType>(reshape.getType());
211+
bool isDynamic = !operandTy.hasStaticShape();
212+
213+
SmallVector<int64_t> intermediateShape;
214+
if (!findIntermediateShape(resultTy.getShape(), operandTy.getShape(),
215+
intermediateShape, isDynamic)) {
216+
return rewriter.notifyMatchFailure(
217+
reshape, "tosa.reshape Cannot identify an intermediate shape between "
218+
"the given two shapes");
219+
}
220+
auto intermediateTy = RankedTensorType::get(
221+
intermediateShape, reshape.getType().getElementType());
222+
223+
Value collapse = createCollapse(rewriter, reshape.getLoc(), intermediateTy,
224+
adaptor.getInput1());
225+
if (!collapse)
226+
return failure();
227+
228+
Value expand = createExpand(rewriter, reshape.getLoc(), resultTy, collapse);
229+
if (!expand)
230+
return failure();
231+
232+
rewriter.replaceOp(reshape, expand);
247233
return success();
248234
}
249235
};
@@ -430,10 +416,8 @@ struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
430416

431417
void mlir::tosa::populateTosaToTensorConversionPatterns(
432418
RewritePatternSet *patterns) {
433-
patterns->add<
434-
ConcatConverter,
435-
PadConverter,
436-
ReshapeConverter,
437-
SliceConverter
438-
>(patterns->getContext());
419+
patterns->add<SliceConverter, PadConverter, ConcatConverter>(
420+
patterns->getContext());
421+
422+
patterns->add<ReshapeConverterCollapseExpand>(patterns->getContext());
439423
}

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -795,10 +795,7 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
795795
if (!inputTy || !outputTy)
796796
return {};
797797

798-
// Fold when the input and output types are the same. This is only safe when
799-
// there is at most 1 dynamic dimension. For 2 or more dynamic dimensions,
800-
// there may still be a productive reshape.
801-
if (inputTy == outputTy && inputTy.getNumDynamicDims() < 2)
798+
if (inputTy == outputTy)
802799
return getInput1();
803800

804801
// reshape(reshape(x)) -> reshape(x)

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -970,11 +970,6 @@ mlir::LogicalResult tosa::ReshapeOp::verify() {
970970
<< " elements into " << outputElementsNum;
971971
}
972972
}
973-
974-
int missingDims = llvm::count(getNewShape(), -1);
975-
if (missingDims > 1)
976-
return emitOpError() << "At most one target dimension can be -1";
977-
978973
return mlir::success();
979974
}
980975

0 commit comments

Comments
 (0)