Skip to content

[TOSA] Move CreateOpAndInfer into ConversionUtils.h #106122

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 137 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/PatternMatch.h"
#include <optional>

Expand Down Expand Up @@ -79,6 +81,141 @@ checkHasDynamicBatchDims(PatternRewriter &rewriter, Op op,
LogicalResult EqualizeRanks(PatternRewriter &rewriter, Location loc,
Value &input1, Value &input2);

LogicalResult EqualizeRanks(ImplicitLocOpBuilder &builder, Value &input1,
Value &input2);

// Creates a TOSA operation and performs shape inference on the individual
// op. This allows shape inference when lowering down to TOSA.
//
// If the created op implements InferShapedTypeOpInterface, the inferred
// shape is joined with the knowledge already carried by `resultTy` and the
// result type is refined in place; otherwise the op is returned unchanged.
//
// NOTE: this is a function template defined in a header, so it must NOT be
// wrapped in an anonymous namespace — that gives every including TU its own
// internal-linkage copy and is flagged by the LLVM coding standards.
// Templates are already ODR-safe without it.
template <typename TosaOp, typename... Args>
TosaOp createOpAndInferShape(ImplicitLocOpBuilder &builder, Type resultTy,
                             Args &&...args) {
  auto op = builder.create<TosaOp>(resultTy, std::forward<Args>(args)...);

  // Ops that do not implement the interface cannot be refined.
  InferShapedTypeOpInterface shapeInterface =
      dyn_cast<InferShapedTypeOpInterface>(op.getOperation());
  if (!shapeInterface)
    return op;

  SmallVector<ShapedTypeComponents> returnedShapes;
  if (shapeInterface
          .inferReturnTypeComponents(op.getContext(), builder.getLoc(),
                                     op->getOperands(), op->getAttrDictionary(),
                                     op->getPropertiesStorage(),
                                     op->getRegions(), returnedShapes)
          .failed())
    return op;

  // We need to use the element type of the existing result type to generate
  // the new result shaped type. This is because rescale can include a cast to
  // different bit-width types and does not have a TypeAttr to define the
  // target type.
  auto result = op->getResult(0);
  auto predictedShape = returnedShapes[0];
  auto currentKnowledge = ValueKnowledge::getKnowledgeFromType(resultTy);

  // Compute the knowledge based on the inferred type.
  auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
  inferredKnowledge.dtype = mlir::cast<ShapedType>(resultTy).getElementType();
  inferredKnowledge.hasRank = predictedShape.hasRank();
  if (predictedShape.hasRank()) {
    for (auto dim : predictedShape.getDims()) {
      inferredKnowledge.sizes.push_back(dim);
    }
  }

  // Join the existing knowledge with the inferred one and rebuild the result
  // type from the combined rank/shape/dtype information.
  auto newKnowledge = ValueKnowledge::join(currentKnowledge, inferredKnowledge);
  Type newTy =
      newKnowledge.hasRank
          ? Type{mlir::RankedTensorType::get(llvm::ArrayRef(newKnowledge.sizes),
                                             newKnowledge.dtype)}
          : Type{mlir::UnrankedTensorType::get(newKnowledge.dtype)};
  result.setType(newTy);
  return op;
}

// Creates a TOSA operation by:
// - first equalizing ranks for ops with the SameOperandsAndResultRank trait
//   (tosa.reshape ops may be inserted on the lower-ranked operands),
// - creating the operator,
// - performing shape inference on this operator.
//
// Rank equalization is only attempted for the recognized Value-operand
// signatures below ((Value, Value), (Value, Value, bool) for
// ArithmeticRightShiftOp, and (Value, Value, Value) for SelectOp); every
// other signature falls through to plain creation plus shape inference.
template <typename TosaOp, typename... Args>
TosaOp CreateOpAndInferShape(ImplicitLocOpBuilder &builder, Type resultTy,
                             Args &&...args) {
  if (TosaOp::template hasTrait<OpTrait::SameOperandsAndResultRank>()) {
    // The op requires the same rank for its tensor operands.
    if constexpr (sizeof...(Args) == 2) {
      if constexpr (std::is_same_v<std::tuple<std::decay_t<Args>...>,
                                   std::tuple<Value, Value>>) {
        auto operands = std::tie(args...);
        Value x = std::get<0>(operands);
        Value y = std::get<1>(operands);
        if (EqualizeRanks(builder, x, y).failed()) {
          // Incompatible broadcast shapes: no reshape is inserted and the
          // ResultsBroadcastableShape verifier reports the error later, so
          // the failure is deliberately ignored here.
        }
        return createOpAndInferShape<TosaOp>(builder, resultTy, x, y);
      }
    }
    if constexpr (sizeof...(Args) == 3) {
      using ArgTuple = std::tuple<std::decay_t<Args>...>;
      if constexpr (std::is_same_v<ArgTuple, std::tuple<Value, Value, bool>>) {
        // Special case for ArithmeticRightShiftOp: (input1, input2, round).
        auto operands = std::tie(args...);
        Value x = std::get<0>(operands);
        Value y = std::get<1>(operands);
        bool round = std::get<2>(operands);
        if (EqualizeRanks(builder, x, y).failed()) {
          // Incompatible broadcast shapes; see the note above — the verifier
          // handles this case.
        }
        return createOpAndInferShape<TosaOp>(builder, resultTy, x, y, round);
      }
      if constexpr (std::is_same_v<ArgTuple, std::tuple<Value, Value, Value>>) {
        // Special case for SelectOp: all three operands must agree in rank.
        auto operands = std::tie(args...);
        Value x = std::get<0>(operands);
        Value y = std::get<1>(operands);
        Value z = std::get<2>(operands);

        // Short-circuits on the first failure, matching the verifier-backed
        // best-effort policy described above.
        if (EqualizeRanks(builder, x, y).failed() ||
            EqualizeRanks(builder, x, z).failed() ||
            EqualizeRanks(builder, y, z).failed()) {
          // Incompatible broadcast shapes; no reshape is inserted and
          // ResultsBroadcastableShape verification handles it.
        }

        return createOpAndInferShape<TosaOp>(builder, resultTy, x, y, z);
      }
    }
  }

  // Unrecognized signature or no rank constraint: create and infer directly,
  // perfectly forwarding the arguments.
  return createOpAndInferShape<TosaOp>(builder, resultTy,
                                       std::forward<Args>(args)...);
}

// Creates a TOSA operation by:
// - first equalizing ranks for ops with the SameOperandsAndResultRank trait,
// - creating the operator,
// - performing shape inference on this operator.
//
// Convenience overload taking a PatternRewriter plus Location: wraps them in
// an ImplicitLocOpBuilder and delegates to the builder-based overload,
// perfectly forwarding the remaining arguments.
template <typename TosaOp, typename... Args>
TosaOp CreateOpAndInferShape(PatternRewriter &rewriter, Location loc,
                             Type resultTy, Args &&...args) {
  ImplicitLocOpBuilder builder(loc, rewriter);
  return CreateOpAndInferShape<TosaOp>(builder, resultTy,
                                       std::forward<Args>(args)...);
}

} // namespace tosa
} // namespace mlir

Expand Down
93 changes: 23 additions & 70 deletions mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,53 +26,6 @@ using namespace mlir::tosa;

namespace {

// Builds a TOSA op of type `TosaOp`, then refines its result type by running
// the op's own shape inference and joining the outcome with the knowledge
// already carried by `resultTy`.
template <typename TosaOp, typename... Args>
TosaOp createOpAndInfer(PatternRewriter &rewriter, Location loc, Type resultTy,
                        Args &&...args) {
  auto tosaOp = rewriter.create<TosaOp>(loc, resultTy, args...);

  // Ops that do not implement InferShapedTypeOpInterface cannot be refined;
  // hand them back untouched.
  auto inferIface = dyn_cast<InferShapedTypeOpInterface>(tosaOp.getOperation());
  if (!inferIface)
    return tosaOp;

  SmallVector<ShapedTypeComponents> inferredComponents;
  LogicalResult inferStatus = inferIface.inferReturnTypeComponents(
      tosaOp.getContext(), tosaOp.getLoc(), tosaOp->getOperands(),
      tosaOp->getDiscardableAttrDictionary(), tosaOp->getPropertiesStorage(),
      tosaOp->getRegions(), inferredComponents);
  if (failed(inferStatus))
    return tosaOp;

  // Keep the element type from the existing result type: ops such as rescale
  // can cast to a different bit width and carry no TypeAttr describing the
  // target type, so the inferred components only contribute the shape.
  auto resultVal = tosaOp->getResult(0);
  auto components = inferredComponents[0];
  auto priorKnowledge =
      mlir::tosa::ValueKnowledge::getKnowledgeFromType(resultTy);

  // Translate the inferred components into a ValueKnowledge record.
  auto shapeKnowledge =
      mlir::tosa::ValueKnowledge::getPessimisticValueState();
  shapeKnowledge.dtype = cast<ShapedType>(resultTy).getElementType();
  shapeKnowledge.hasRank = components.hasRank();
  if (components.hasRank()) {
    for (auto dim : components.getDims()) {
      shapeKnowledge.sizes.push_back(dim);
    }
  }

  // Join both sources of information and narrow the result type accordingly.
  auto joined =
      mlir::tosa::ValueKnowledge::join(priorKnowledge, shapeKnowledge);
  auto refinedTy = joined.getType();
  resultVal.setType(refinedTy);
  return tosaOp;
}

class TransposeConvNonStridedConverter
: public OpRewritePattern<tosa::TransposeConv2DOp> {
public:
Expand Down Expand Up @@ -187,20 +140,20 @@ class TransposeConvStridedConverter
(weightWidth % stride[1]) ? (stride[1] - weightWidth % stride[1]) : 0;
DenseElementsAttr weightPaddingAttr = DenseIntElementsAttr::get(
RankedTensorType::get({4, 2}, rewriter.getI32Type()), weightPadding);
Value weightPaddingVal = createOpAndInfer<tosa::ConstOp>(
Value weightPaddingVal = CreateOpAndInferShape<tosa::ConstOp>(
rewriter, loc, weightPaddingAttr.getType(), weightPaddingAttr);

if (op.getQuantizationInfo().has_value()) {
auto quantInfo = op.getQuantizationInfo().value();
weight = createOpAndInfer<tosa::PadOp>(
weight = CreateOpAndInferShape<tosa::PadOp>(
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
weightPaddingVal, nullptr,
rewriter.getAttr<PadOpQuantizationAttr>(quantInfo.getWeightZp()));

} else {
weight = createOpAndInfer<tosa::PadOp>(rewriter, loc,
UnrankedTensorType::get(weightETy),
weight, weightPaddingVal);
weight = CreateOpAndInferShape<tosa::PadOp>(
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
weightPaddingVal);
}

weightTy = cast<ShapedType>(weight.getType());
Expand All @@ -212,7 +165,7 @@ class TransposeConvStridedConverter
outputChannels, weightHeight / stride[0],
stride[0], weightWidth / stride[1],
stride[1], inputChannels};
weight = createOpAndInfer<tosa::ReshapeOp>(
weight = CreateOpAndInferShape<tosa::ReshapeOp>(
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
rewriter.getDenseI64ArrayAttr(weightReshapeDims0));

Expand All @@ -221,23 +174,23 @@ class TransposeConvStridedConverter
loc, RankedTensorType::get({6}, rewriter.getI32Type()),
rewriter.getI32TensorAttr({2, 4, 0, 1, 3, 5}));

weight = createOpAndInfer<tosa::TransposeOp>(
weight = CreateOpAndInferShape<tosa::TransposeOp>(
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
transposeWeightVal);

// Collapse the strides and output channels into a single dimension.
llvm::SmallVector<int64_t, 6> weightReshapeDims1 = {
outputChannels * stride[0] * stride[1], weightHeight / stride[0],
weightWidth / stride[1], inputChannels};
weight = createOpAndInfer<tosa::ReshapeOp>(
weight = CreateOpAndInferShape<tosa::ReshapeOp>(
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
rewriter.getDenseI64ArrayAttr(weightReshapeDims1));
ShapedType restridedWeightTy = cast<ShapedType>(weight.getType());

weight = createOpAndInfer<tosa::ReverseOp>(
weight = CreateOpAndInferShape<tosa::ReverseOp>(
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
/* axis = */ rewriter.getI32IntegerAttr(1));
weight = createOpAndInfer<tosa::ReverseOp>(
weight = CreateOpAndInferShape<tosa::ReverseOp>(
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
/* axis = */ rewriter.getI32IntegerAttr(2));

Expand All @@ -251,19 +204,19 @@ class TransposeConvStridedConverter
DenseElementsAttr inputPaddingAttr = DenseIntElementsAttr::get(
RankedTensorType::get({4, 2}, rewriter.getI32Type()), inputPadding);

Value inputPaddingVal = createOpAndInfer<tosa::ConstOp>(
Value inputPaddingVal = CreateOpAndInferShape<tosa::ConstOp>(
rewriter, loc, inputPaddingAttr.getType(), inputPaddingAttr);

if (op.getQuantizationInfo().has_value()) {
auto quantInfo = op.getQuantizationInfo().value();
input = createOpAndInfer<tosa::PadOp>(
input = CreateOpAndInferShape<tosa::PadOp>(
rewriter, loc, UnrankedTensorType::get(inputETy), input,
inputPaddingVal, nullptr,
rewriter.getAttr<PadOpQuantizationAttr>(quantInfo.getInputZp()));
} else {
input = createOpAndInfer<tosa::PadOp>(rewriter, loc,
UnrankedTensorType::get(inputETy),
input, inputPaddingVal);
input = CreateOpAndInferShape<tosa::PadOp>(
rewriter, loc, UnrankedTensorType::get(inputETy), input,
inputPaddingVal);
}

// We use a zero bias as we need to broadcast the bias.
Expand All @@ -279,7 +232,7 @@ class TransposeConvStridedConverter
// Perform the convolution using the zero bias.
Value conv2d;
if (op.getQuantizationInfo()) {
conv2d = createOpAndInfer<tosa::Conv2DOp>(
conv2d = CreateOpAndInferShape<tosa::Conv2DOp>(
rewriter, loc, UnrankedTensorType::get(resultETy), input,
weight, zeroBias,
/*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}),
Expand All @@ -288,7 +241,7 @@ class TransposeConvStridedConverter
*op.getQuantizationInfo())
.getResult();
} else {
conv2d = createOpAndInfer<tosa::Conv2DOp>(
conv2d = CreateOpAndInferShape<tosa::Conv2DOp>(
rewriter, loc, UnrankedTensorType::get(resultETy), input,
weight, zeroBias,
/*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}),
Expand All @@ -307,7 +260,7 @@ class TransposeConvStridedConverter
// Factor striding out of the convolution result.
llvm::SmallVector<int64_t, 6> convReshapeDims0 = {
batch, convHeight, convWidth, stride[0], stride[1], outputChannels};
conv2d = createOpAndInfer<tosa::ReshapeOp>(
conv2d = CreateOpAndInferShape<tosa::ReshapeOp>(
rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
rewriter.getDenseI64ArrayAttr(convReshapeDims0));

Expand All @@ -316,14 +269,14 @@ class TransposeConvStridedConverter
loc, RankedTensorType::get({6}, rewriter.getI32Type()),
rewriter.getI32TensorAttr({0, 1, 3, 2, 4, 5}));

conv2d = createOpAndInfer<tosa::TransposeOp>(
conv2d = CreateOpAndInferShape<tosa::TransposeOp>(
rewriter, loc, UnrankedTensorType::get(convETy), conv2d,
transposeConvVal);

// Fuse striding behavior back into width / height.
llvm::SmallVector<int64_t, 6> convReshapeDims1 = {
batch, convHeight * stride[0], convWidth * stride[1], outputChannels};
conv2d = createOpAndInfer<tosa::ReshapeOp>(
conv2d = CreateOpAndInferShape<tosa::ReshapeOp>(
rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
rewriter.getDenseI64ArrayAttr(convReshapeDims1));

Expand All @@ -348,7 +301,7 @@ class TransposeConvStridedConverter
sliceSize[1] = resultSliceHeight;
sliceSize[2] = resultSliceWidth;

auto slice = createOpAndInfer<tosa::SliceOp>(
auto slice = CreateOpAndInferShape<tosa::SliceOp>(
rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
rewriter.getDenseI64ArrayAttr(sliceBegin),
rewriter.getDenseI64ArrayAttr(sliceSize))
Expand All @@ -363,10 +316,10 @@ class TransposeConvStridedConverter
DenseElementsAttr resultPaddingAttr = DenseIntElementsAttr::get(
RankedTensorType::get({4, 2}, rewriter.getI32Type()), resultPadding);

Value resultPaddingVal = createOpAndInfer<tosa::ConstOp>(
Value resultPaddingVal = CreateOpAndInferShape<tosa::ConstOp>(
rewriter, loc, resultPaddingAttr.getType(), resultPaddingAttr);

Value resultPad = createOpAndInfer<tosa::PadOp>(
Value resultPad = CreateOpAndInferShape<tosa::PadOp>(
rewriter, loc, UnrankedTensorType::get(resultETy), slice,
resultPaddingVal);

Expand Down
12 changes: 9 additions & 3 deletions mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,12 @@ computeReshapeOutput(ArrayRef<int64_t> higherRankShape,

// Convenience overload: wraps the rewriter/location pair in an
// ImplicitLocOpBuilder and delegates to the builder-based EqualizeRanks.
// `input1`/`input2` are updated in place by the delegate when reshapes are
// inserted.
LogicalResult mlir::tosa::EqualizeRanks(PatternRewriter &rewriter, Location loc,
                                        Value &input1, Value &input2) {
  ImplicitLocOpBuilder builder(loc, rewriter);
  return EqualizeRanks(builder, input1, input2);
}

LogicalResult mlir::tosa::EqualizeRanks(ImplicitLocOpBuilder &builder,
Value &input1, Value &input2) {
auto input1Ty = llvm::dyn_cast<RankedTensorType>(input1.getType());
auto input2Ty = llvm::dyn_cast<RankedTensorType>(input2.getType());

Expand Down Expand Up @@ -140,9 +146,9 @@ LogicalResult mlir::tosa::EqualizeRanks(PatternRewriter &rewriter, Location loc,
auto reshapeOutputType = RankedTensorType::get(
ArrayRef<int64_t>(reshapeOutputShape), reshapeInputType.getElementType());

auto reshapeLower = rewriter.create<tosa::ReshapeOp>(
loc, reshapeOutputType, lowerTensorValue,
rewriter.getDenseI64ArrayAttr(reshapeOutputShape));
auto reshapeLower = builder.create<tosa::ReshapeOp>(
reshapeOutputType, lowerTensorValue,
builder.getDenseI64ArrayAttr(reshapeOutputShape));

if (input1Rank > input2Rank) {
input1 = higherTensorValue;
Expand Down
Loading