Fixes in 'tosa.reshape' lowering and folder #85798

Merged: 6 commits, Mar 26, 2024
Changes from all commits
304 changes: 160 additions & 144 deletions mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -19,24 +19,99 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"

#include <numeric>

using namespace mlir;
using namespace tosa;

static bool findIntermediateShape(ArrayRef<int64_t> lhsShape,
ArrayRef<int64_t> rhsShape,
SmallVector<int64_t> &intermediateShape,
bool isDynamic) {
if (isDynamic) {
// TODO (natashaknk): Make dynamic intermediate shape not always be rank-1
intermediateShape = {ShapedType::kDynamic};
return true;
}
namespace {

if (lhsShape.empty() || rhsShape.empty()) {
intermediateShape = {};
return true;
}
// Infer the type to which the input of a 'tosa.reshape' op must be cast when
// lowered.
TensorType inferReshapeInputType(TypedValue<TensorType> input,
ArrayRef<int64_t> newShape) {
// No need to cast input for non-empty target shape
if (!newShape.empty())
return input.getType();

// The input type must be cast into a tensor with the same rank and all static
// dimensions set to 1. This prevents the generation of a tensor.collapse_shape
// op that converts a dynamically shaped tensor into a 0D tensor. While such a
// construct is not incorrect on its own, bufferization cannot properly handle
// it at the moment, so we avoid it.
SmallVector<int64_t> shape(input.getType().getRank(), 1);
return input.getType().clone(shape);
}
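
Illustrative example (assumed IR, not taken from this patch's tests): when the target shape is empty, a dynamically shaped input such as

  %0 = tosa.reshape %arg0 {new_shape = array<i64>} : (tensor<?x?xf32>) -> tensor<f32>

is first cast to tensor<1x1xf32>, so that the tensor.collapse_shape emitted later collapses an all-static, all-ones shape into the 0D result rather than a dynamic one:

  %cast = tensor.cast %arg0 : tensor<?x?xf32> to tensor<1x1xf32>
  %collapsed = tensor.collapse_shape %cast [] : tensor<1x1xf32> into tensor<f32>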

// Infer the result type of 'tensor.expand_shape' in the collapse-expand
// pair emitted for a 'tosa.reshape' op.
TensorType inferReshapeExpandedType(TensorType inputType,
ArrayRef<int64_t> newShape) {
// Special case for 0D output tensor. Note: Watch out when using Type::clone()
// with just '{}', as it will invoke the incorrect overload.
if (newShape.empty())
return inputType.clone(ArrayRef<int64_t>{});

// Check if the input is static, and if so, get its total size
bool inputIsStatic = inputType.hasStaticShape();
int64_t totalSize = inputIsStatic ? inputType.getNumElements() : -1;

// Compute result shape
bool resultIsStatic = true;
auto resultShape = llvm::map_to_vector(newShape, [&](int64_t size) -> int64_t {
// If this is not a placeholder, do not change it
if (size >= 0)
return size;

// If we do not know the total size of the tensor, keep this dimension
// dynamic in the result shape.
if (!inputIsStatic) {
resultIsStatic = false;
return ShapedType::kDynamic;
}

// Calculate the product of all elements in 'newShape' except for the -1
// placeholder, which we discard by negating the result.
int64_t totalSizeNoPlaceholder = -std::accumulate(
newShape.begin(), newShape.end(), 1, std::multiplies());

// If there is a 0 component in 'newShape', resolve the placeholder as 0.
if (totalSizeNoPlaceholder == 0)
return 0;

// Resolve the placeholder as the quotient between the total tensor size and
// the product of all other sizes.
return totalSize / totalSizeNoPlaceholder;
});

// A syntactic restriction in 'tensor.expand_shape' forbids a dynamically
// shaped input from being reshaped into a statically shaped result. We may
// simply turn the first result dimension dynamic to address this.
if (!inputIsStatic && resultIsStatic)
resultShape[0] = ShapedType::kDynamic;

// The 'tensor.expand_shape' op also forbids a statically shaped input from
// being reshaped into a dynamically shaped result, but the placeholder
// inference algorithm above guarantees that this will never be the case.
assert(!inputIsStatic || resultIsStatic);

// Create result type
return inputType.clone(resultShape);
}
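
Illustrative placeholder resolution (shapes are assumptions, not from this patch's tests):

  input tensor<2x3x4xf32> (24 elements), newShape = [4, -1, 2]  ->  -1 resolves to 24 / (4 * 2) = 3, giving tensor<4x3x2xf32>
  input tensor<?x4xf32>, newShape = [4, -1, 2]  ->  the placeholder stays dynamic, giving tensor<4x?x2xf32>
  input tensor<?x4xf32>, newShape = [4, 3, 2]   ->  dim 0 is forced dynamic to satisfy 'tensor.expand_shape', giving tensor<?x3x2xf32>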

// Infer the result type of 'tensor.collapse_shape' in the collapse-expand
// pair emitted for a 'tosa.reshape' op.
TensorType inferReshapeCollapsedType(TensorType lhsType, TensorType rhsType) {
auto lhsShape = lhsType.getShape();
auto rhsShape = rhsType.getShape();

if (lhsShape.empty() || rhsShape.empty())
return lhsType.clone(ArrayRef<int64_t>{});

if (ShapedType::isDynamicShape(lhsShape) || ShapedType::isDynamicShape(rhsShape))
return lhsType.clone({ShapedType::kDynamic});

SmallVector<int64_t> intermediateShape;
unsigned currLhsDim = 0, currRhsDim = 0;
while (currLhsDim < lhsShape.size() && currRhsDim < rhsShape.size()) {
int64_t rhsSize = rhsShape[currRhsDim];
@@ -62,174 +137,113 @@ static bool findIntermediateShape(ArrayRef<int64_t> lhsShape,
currLhsDim++;
}

// If the iterators didn't reach the end and their leftover dimensions are not
// equal to 1 an intermediate shape was not found.
while (currLhsDim < lhsShape.size()) {
if (lhsShape[currLhsDim++] != 1) {
return false;
}
// Static shapes are guaranteed to be compatible by the op verifier, so all
// leftover dimensions should be 1.
for (; currLhsDim < lhsShape.size(); currLhsDim++) {
assert(lhsShape[currLhsDim] == 1);
}

while (currRhsDim < rhsShape.size()) {
if (rhsShape[currRhsDim++] != 1) {
return false;
}
for (; currRhsDim < rhsShape.size(); currRhsDim++) {
assert(rhsShape[currRhsDim] == 1);
}

return true;
return lhsType.clone(intermediateShape);
}
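
Illustrative results (assumed shapes, not from this patch's tests):

  lhs tensor<2x3x4xf32>, rhs tensor<6x4xf32>  ->  collapsed type tensor<6x4xf32> (dims 2 and 3 merge into the 6)
  lhs tensor<?x3xf32>, rhs tensor<3x?xf32>    ->  collapsed type tensor<?xf32> (any dynamic dim collapses to rank 1)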

static bool createReassociationMapsForCollapse(
PatternRewriter &rewriter, ArrayRef<int64_t> srcShape,
ArrayRef<int64_t> dstShape,
SmallVector<ReassociationExprs, 4> &reassociationMap, bool isDynamic) {
SmallVector<ReassociationExprs>
createReassociationMapForCollapse(OpBuilder &builder, Type srcType, Type dstType) {
auto srcShape = cast<TensorType>(srcType).getShape();
auto dstShape = cast<TensorType>(dstType).getShape();

// If the shape is dynamic, create a map for collapsing into one dimension.
if (isDynamic) {
SmallVector<AffineExpr, 2> exprs;
for (int i = 0, s = srcShape.size(); i < s; ++i)
exprs.push_back(rewriter.getAffineDimExpr(i));
reassociationMap = {exprs};
return true;
}
if (srcShape.empty() || dstShape.empty())
return {};

if (dstShape.empty()) {
reassociationMap = {};
return true;
if (ShapedType::isDynamicShape(srcShape) || ShapedType::isDynamicShape(dstShape)) {
assert(dstShape.size() == 1);
SmallVector<AffineExpr, 2> exprs;
for (auto i : llvm::seq<int64_t>(srcShape.size()))
exprs.push_back(builder.getAffineDimExpr(i));
return {exprs};
}

reassociationMap.resize(dstShape.size());
SmallVector<ReassociationExprs> reassociationMap(dstShape.size());
unsigned currSrcDim = 0, currDstDim = 0;
while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
int64_t dstSize = dstShape[currDstDim];
int64_t srcSize = srcShape[currSrcDim];
while (srcSize < dstSize && currSrcDim < srcShape.size()) {
reassociationMap[currDstDim].push_back(
rewriter.getAffineDimExpr(currSrcDim++));
builder.getAffineDimExpr(currSrcDim++));
srcSize *= srcShape[currSrcDim];
}
if (srcSize == dstSize) {
reassociationMap[currDstDim].push_back(
rewriter.getAffineDimExpr(currSrcDim++));
builder.getAffineDimExpr(currSrcDim++));
// If the next dim in collapsedShape is not 1, treat subsequent dims in
// expandedShape which are 1 as belonging to this collapsed dim.
if (currDstDim == dstShape.size() - 1 || dstShape[currDstDim + 1] != 1) {
while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
reassociationMap[currDstDim].push_back(
rewriter.getAffineDimExpr(currSrcDim++));
builder.getAffineDimExpr(currSrcDim++));
}
}
}
currDstDim++;
}

// If both iterators didn't reach the end, we have leftover dimensions which
// implies that we have a mismatch in shape.
return currSrcDim == srcShape.size() && currDstDim == dstShape.size();
// If the source and target shapes are compatible, both iterators must have
// reached the end. This condition is guaranteed by the op verifier for
// static shapes.
assert(currSrcDim == srcShape.size() && currDstDim == dstShape.size());
return reassociationMap;
}
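
Illustrative reassociation maps (assumed shapes): collapsing tensor<2x3x4xf32> into tensor<6x4xf32> produces [[0, 1], [2]], i.e. source dims 0 and 1 merge into result dim 0; for the dynamic case, collapsing tensor<?x?xf32> into tensor<?xf32> produces the single group [[0, 1]].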

namespace {
Value createCollapse(ConversionPatternRewriter &rewriter, Location loc,
ShapedType resultTy, Value operand) {
ShapedType operandTy = cast<ShapedType>(operand.getType());
if (resultTy == operandTy)
return operand;

bool isDynamic = !operandTy.hasStaticShape();

if (isDynamic && resultTy.getRank() != 1) {
(void)rewriter.notifyMatchFailure(
loc, "Cannot collapse dynamic dims to more than one dimension");
return {};
}

SmallVector<ReassociationExprs, 4> reassociationMap;
if (!createReassociationMapsForCollapse(rewriter, operandTy.getShape(),
resultTy.getShape(),
reassociationMap, isDynamic)) {
(void)rewriter.notifyMatchFailure(
loc, "tosa.reshape Attempting to collapse into an incompatible shape");
return {};
}

SmallVector<int64_t> intermediateShape;
if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
intermediateShape, isDynamic)) {
(void)rewriter.notifyMatchFailure(
loc, "tosa.reshape Cannot collapse into given shape");
return {};
}
return rewriter.create<tensor::CollapseShapeOp>(loc, resultTy, operand,
reassociationMap);
// Create a tensor.collapse_shape op that reshapes the input into the given
// result type.
Value createCollapse(OpBuilder &builder, Location loc, TensorType resultType,
Value input) {
auto reassociationMap =
createReassociationMapForCollapse(builder, input.getType(), resultType);
return builder.createOrFold<tensor::CollapseShapeOp>(loc, resultType, input,
reassociationMap);
}

Value createExpand(ConversionPatternRewriter &rewriter, Location loc,
ShapedType resultTy, Value operand) {
ShapedType operandTy = cast<ShapedType>(operand.getType());
if (resultTy == operandTy)
return operand;

bool isDynamic = !operandTy.hasStaticShape();

if (isDynamic && operandTy.getRank() != 1) {
(void)rewriter.notifyMatchFailure(
loc, "Cannot expand dynamic dims from more than one dimension");
return {};
}

SmallVector<ReassociationExprs, 4> reassociationMap;
if (!createReassociationMapsForCollapse(rewriter, resultTy.getShape(),
operandTy.getShape(),
reassociationMap, isDynamic)) {
(void)rewriter.notifyMatchFailure(
loc, "tosa.reshape Attempting to expand into an incompatible shape");
return {};
}

SmallVector<int64_t> intermediateShape;
if (!findIntermediateShape(operandTy.getShape(), resultTy.getShape(),
intermediateShape, isDynamic) ||
intermediateShape != operandTy.getShape()) {
(void)rewriter.notifyMatchFailure(
loc, "tosa.reshape Cannot expand into given shape");
return {};
}
return rewriter.create<tensor::ExpandShapeOp>(loc, resultTy, operand,
reassociationMap);
// Create a tensor.expand_shape op that reshapes the input into the given result
// type.
Value createExpand(OpBuilder &builder, Location loc, TensorType resultType,
Value input) {
auto reassociationMap =
createReassociationMapForCollapse(builder, resultType, input.getType());
return builder.createOrFold<tensor::ExpandShapeOp>(loc, resultType, input,
reassociationMap);
}

class ReshapeConverterCollapseExpand
: public OpConversionPattern<tosa::ReshapeOp> {
class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
public:
using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
ShapedType operandTy = cast<ShapedType>(adaptor.getInput1().getType());
ShapedType resultTy = cast<ShapedType>(reshape.getType());
bool isDynamic = !operandTy.hasStaticShape();

SmallVector<int64_t> intermediateShape;
if (!findIntermediateShape(resultTy.getShape(), operandTy.getShape(),
intermediateShape, isDynamic)) {
return rewriter.notifyMatchFailure(
reshape, "tosa.reshape Cannot identify an intermediate shape between "
"the given two shapes");
}
auto intermediateTy = RankedTensorType::get(
intermediateShape, reshape.getType().getElementType());

Value collapse = createCollapse(rewriter, reshape.getLoc(), intermediateTy,
adaptor.getInput1());
if (!collapse)
return failure();

Value expand = createExpand(rewriter, reshape.getLoc(), resultTy, collapse);
if (!expand)
return failure();

rewriter.replaceOp(reshape, expand);
auto loc = reshape.getLoc();
auto resultType = reshape.getResult().getType();
auto input = reshape.getInput1();
auto newShape = reshape.getNewShape();

// Infer all intermediate types
auto inputType = inferReshapeInputType(input, newShape);
auto expandedType = inferReshapeExpandedType(inputType, newShape);
auto collapsedType = inferReshapeCollapsedType(inputType, expandedType);

// Cast input if needed
auto castInput = rewriter.createOrFold<tensor::CastOp>(loc, inputType, input);

// Emit collapse-expand pair
auto collapsed = createCollapse(rewriter, loc, collapsedType, castInput);
auto expanded = createExpand(rewriter, loc, expandedType, collapsed);

// Cast to final result type if needed
auto result = rewriter.createOrFold<tensor::CastOp>(loc, resultType, expanded);
rewriter.replaceOp(reshape, result);
return success();
}
};
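
As a rough sketch of the overall pattern (assumed IR; the exact folding behavior may differ), a rank-1 dynamic reshape such as

  %0 = tosa.reshape %arg0 {new_shape = array<i64: 2, -1>} : (tensor<?xf32>) -> tensor<2x?xf32>

lowers to a single expand, because the input cast, the collapse, and the final cast all fold away (the input is already rank 1, dynamic, and of the collapsed type):

  %expanded = tensor.expand_shape %arg0 [[0, 1]] : tensor<?xf32> into tensor<2x?xf32>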
@@ -416,8 +430,10 @@ struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {

void mlir::tosa::populateTosaToTensorConversionPatterns(
RewritePatternSet *patterns) {
patterns->add<SliceConverter, PadConverter, ConcatConverter>(
patterns->getContext());

patterns->add<ReshapeConverterCollapseExpand>(patterns->getContext());
patterns->add<
ConcatConverter,
PadConverter,
ReshapeConverter,
SliceConverter
>(patterns->getContext());
}
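
These patterns are exercised through the 'tosa-to-tensor' conversion pass, e.g. (illustrative invocation, assuming the standard pass registration):

  mlir-opt input.mlir --tosa-to-tensor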
5 changes: 4 additions & 1 deletion mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -795,7 +795,10 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
if (!inputTy || !outputTy)
return {};

if (inputTy == outputTy)
// Fold when the input and output types are the same. This is only safe when
// there is at most 1 dynamic dimension. For 2 or more dynamic dimensions,
// there may still be a productive reshape.
if (inputTy == outputTy && inputTy.getNumDynamicDims() < 2)
Reviewer comment (Contributor): As mentioned in the earlier comment, this restriction will be relaxed. This would probably necessitate dynamic cases to be predicated on something like tensor.dim for later resolution.

However, for the purposes of this PR this code looks fine.

return getInput1();

// reshape(reshape(x)) -> reshape(x)
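Illustrative case motivating the restriction (assumed IR; the exact form accepted by the verifier may vary): with

  %0 = tosa.reshape %arg0 {new_shape = array<i64: -1, 5>} : (tensor<?x?xf32>) -> tensor<?x?xf32>

the input and result types are identical, yet the op may still change the runtime shape (e.g. from 10x3 to 6x5), so it must not fold to its input. With at most one dynamic dimension, element-count conservation forces the two shapes to be equal, and folding remains safe.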
5 changes: 5 additions & 0 deletions mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -970,6 +970,11 @@ mlir::LogicalResult tosa::ReshapeOp::verify() {
<< " elements into " << outputElementsNum;
}
}

int missingDims = llvm::count(getNewShape(), -1);
if (missingDims > 1)
return emitOpError() << "At most one target dimension can be -1";

return mlir::success();
}
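
Illustrative IR rejected by the new check (assumed example, not from this patch's tests):

  %0 = tosa.reshape %arg0 {new_shape = array<i64: -1, -1>} : (tensor<12xf32>) -> tensor<?x?xf32>
  // error: 'tosa.reshape' op At most one target dimension can be -1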
