[TOSA] Move CreateOpAndInfer into ConversionUtils.h #106122
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-tosa
Author: Tai Ly (Tai78641)
Changes
This moves CreateOpAndInfer from TF legalize_util.h into ConversionUtils.h.
Renamed to CreateOpAndInferShape so that it can be upstreamed independently of TensorFlow (otherwise a redefinition error would break the TF build unless this landed together with the removal of CreateOpAndInfer in TF).
Full diff: https://github.com/llvm/llvm-project/pull/106122.diff
3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
index ceab7d9c628a54..60e7ed1ce2f876 100644
--- a/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h
@@ -15,7 +15,9 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/PatternMatch.h"
#include <optional>
@@ -79,6 +81,141 @@ checkHasDynamicBatchDims(PatternRewriter &rewriter, Op op,
LogicalResult EqualizeRanks(PatternRewriter &rewriter, Location loc,
Value &input1, Value &input2);
+LogicalResult EqualizeRanks(ImplicitLocOpBuilder &builder, Value &input1,
+ Value &input2);
+
+namespace {
+
+// Creates a TOSA operation and performs shape inference on the individual
+// op. This allows shape inference during the TFLite to TOSA lowering.
+template <typename TosaOp, typename... Args>
+TosaOp createOpAndInferShape(ImplicitLocOpBuilder &builder, Type result_ty,
+ Args &&...args) {
+ auto op = builder.create<TosaOp>(result_ty, args...);
+
+ InferShapedTypeOpInterface shapeInterface =
+ dyn_cast<InferShapedTypeOpInterface>(op.getOperation());
+ if (!shapeInterface)
+ return op;
+
+ SmallVector<ShapedTypeComponents> returnedShapes;
+ if (shapeInterface
+ .inferReturnTypeComponents(op.getContext(), builder.getLoc(),
+ op->getOperands(), op->getAttrDictionary(),
+ op->getPropertiesStorage(),
+ op->getRegions(), returnedShapes)
+ .failed())
+ return op;
+
+ // We need to use the element type of the existing result type to generate
+ // the new result shaped type. This is because rescale can include a cast to
+ // different bit-width types and does not have a TypeAttr to define the
+ // target type.
+ auto result = op->getResult(0);
+ auto predictedShape = returnedShapes[0];
+ auto currentKnowledge = ValueKnowledge::getKnowledgeFromType(result_ty);
+
+ // Compute the knowledge based on the inferred type.
+ auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
+ inferredKnowledge.dtype = mlir::cast<ShapedType>(result_ty).getElementType();
+ inferredKnowledge.hasRank = predictedShape.hasRank();
+ if (predictedShape.hasRank()) {
+ for (auto dim : predictedShape.getDims()) {
+ inferredKnowledge.sizes.push_back(dim);
+ }
+ }
+
+ // Compute the new type based on the joined version.
+ auto newKnowledge = ValueKnowledge::join(currentKnowledge, inferredKnowledge);
+ Type new_ty =
+ newKnowledge.hasRank
+ ? Type{mlir::RankedTensorType::get(llvm::ArrayRef(newKnowledge.sizes),
+ newKnowledge.dtype)}
+ : Type{mlir::UnrankedTensorType::get(newKnowledge.dtype)};
+ result.setType(new_ty);
+ return op;
+}
+
+} // namespace
+
+// Creates a TOSA operation by:
+// - first equalize ranks for ops with SameOperandsAndResultRank trait
+// - create operator
+// - performs shape inference on this operator
+template <typename TosaOp, typename... Args>
+TosaOp CreateOpAndInferShape(ImplicitLocOpBuilder &builder, Type result_ty,
+ Args &&...args) {
+ if (TosaOp::template hasTrait<OpTrait::SameOperandsAndResultRank>()) {
+ // op requires same ranks for tensor operands
+ if constexpr (sizeof...(Args) == 2) {
+ auto argX = std::get<0>(std::tie(args...));
+ auto argY = std::get<1>(std::tie(args...));
+ using ArgX = decltype(argX);
+ using ArgY = decltype(argY);
+ if constexpr (std::is_same_v<ArgX, Value> &&
+ std::is_same_v<ArgY, Value>) {
+ Value x = std::get<0>(std::tie(args...));
+ Value y = std::get<1>(std::tie(args...));
+ if (EqualizeRanks(builder, x, y).failed()) {
+ // incompatible broadcast shapes, no reshape is inserted
+ // ResultsBroadcastableShape verify will handle this
+ }
+ return createOpAndInferShape<TosaOp>(builder, result_ty, x, y);
+ }
+ }
+ if constexpr (sizeof...(Args) == 3) {
+ auto argX = std::get<0>(std::tie(args...));
+ auto argY = std::get<1>(std::tie(args...));
+ auto argZ = std::get<2>(std::tie(args...));
+ using ArgX = decltype(argX);
+ using ArgY = decltype(argY);
+ using ArgZ = decltype(argZ);
+ if constexpr (std::is_same_v<ArgX, Value> &&
+ std::is_same_v<ArgY, Value> && std::is_same_v<ArgZ, bool>) {
+ // special case for ArithmeticRightShiftOp
+ Value x = std::get<0>(std::tie(args...));
+ Value y = std::get<1>(std::tie(args...));
+ bool round = std::get<2>(std::tie(args...));
+ if (EqualizeRanks(builder, x, y).failed()) {
+ // incompatible broadcast shapes, no reshape is inserted
+ // ResultsBroadcastableShape verify will handle this
+ }
+ return createOpAndInferShape<TosaOp>(builder, result_ty, x, y, round);
+ }
+ if constexpr (std::is_same_v<ArgX, Value> &&
+ std::is_same_v<ArgY, Value> &&
+ std::is_same_v<ArgZ, Value>) {
+ // special case for Select
+ Value x = std::get<0>(std::tie(args...));
+ Value y = std::get<1>(std::tie(args...));
+ Value z = std::get<2>(std::tie(args...));
+
+ if (EqualizeRanks(builder, x, y).failed() ||
+ EqualizeRanks(builder, x, z).failed() ||
+ EqualizeRanks(builder, y, z).failed()) {
+ // incompatible broadcast shapes, no reshape is inserted
+ // ResultsBroadcastableShape verify will handle this
+ }
+
+ return createOpAndInferShape<TosaOp>(builder, result_ty, x, y, z);
+ }
+ }
+ }
+
+ return createOpAndInferShape<TosaOp>(builder, result_ty, args...);
+}
+
+// Creates a TOSA operation by:
+// - first equalize ranks for ops with SameOperandsAndResultRank trait
+// - create operator
+// - performs shape inference on this operator
+template <typename TosaOp, typename... Args>
+TosaOp CreateOpAndInferShape(PatternRewriter &rewriter, Location loc,
+ Type result_ty, Args &&...args) {
+ ImplicitLocOpBuilder builder(loc, rewriter);
+ return CreateOpAndInferShape<TosaOp>(builder, result_ty, args...);
+}
+
} // namespace tosa
} // namespace mlir
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
index a94bb3a920b1db..0779cdb9667a1a 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -26,53 +26,6 @@ using namespace mlir::tosa;
namespace {
-template <typename TosaOp, typename... Args>
-TosaOp createOpAndInfer(PatternRewriter &rewriter, Location loc, Type resultTy,
- Args &&...args) {
- auto op = rewriter.create<TosaOp>(loc, resultTy, args...);
-
- InferShapedTypeOpInterface shapeInterface =
- dyn_cast<InferShapedTypeOpInterface>(op.getOperation());
- if (!shapeInterface)
- return op;
-
- SmallVector<ShapedTypeComponents> returnedShapes;
- if (shapeInterface
- .inferReturnTypeComponents(
- op.getContext(), op.getLoc(), op->getOperands(),
- op->getDiscardableAttrDictionary(), op->getPropertiesStorage(),
- op->getRegions(), returnedShapes)
- .failed())
- return op;
-
- // We need to use the element type of the existing result type to generate
- // the new result shaped type. This is because rescale can include a cast to
- // different bit-width types and does not have a TypeAttr to define the
- // target type.
- auto result = op->getResult(0);
- auto predictedShape = returnedShapes[0];
- auto currentKnowledge =
- mlir::tosa::ValueKnowledge::getKnowledgeFromType(resultTy);
-
- // Compute the knowledge based on the inferred type.
- auto inferredKnowledge =
- mlir::tosa::ValueKnowledge::getPessimisticValueState();
- inferredKnowledge.dtype = cast<ShapedType>(resultTy).getElementType();
- inferredKnowledge.hasRank = predictedShape.hasRank();
- if (predictedShape.hasRank()) {
- for (auto dim : predictedShape.getDims()) {
- inferredKnowledge.sizes.push_back(dim);
- }
- }
-
- // Compute the new type based on the joined version.
- auto newKnowledge =
- mlir::tosa::ValueKnowledge::join(currentKnowledge, inferredKnowledge);
- auto newTy = newKnowledge.getType();
- result.setType(newTy);
- return op;
-}
-
class TransposeConvNonStridedConverter
: public OpRewritePattern<tosa::TransposeConv2DOp> {
public:
@@ -187,20 +140,20 @@ class TransposeConvStridedConverter
(weightWidth % stride[1]) ? (stride[1] - weightWidth % stride[1]) : 0;
DenseElementsAttr weightPaddingAttr = DenseIntElementsAttr::get(
RankedTensorType::get({4, 2}, rewriter.getI32Type()), weightPadding);
- Value weightPaddingVal = createOpAndInfer<tosa::ConstOp>(
+ Value weightPaddingVal = CreateOpAndInferShape<tosa::ConstOp>(
rewriter, loc, weightPaddingAttr.getType(), weightPaddingAttr);
if (op.getQuantizationInfo().has_value()) {
auto quantInfo = op.getQuantizationInfo().value();
- weight = createOpAndInfer<tosa::PadOp>(
+ weight = CreateOpAndInferShape<tosa::PadOp>(
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
weightPaddingVal, nullptr,
rewriter.getAttr<PadOpQuantizationAttr>(quantInfo.getWeightZp()));
} else {
- weight = createOpAndInfer<tosa::PadOp>(rewriter, loc,
- UnrankedTensorType::get(weightETy),
- weight, weightPaddingVal);
+ weight = CreateOpAndInferShape<tosa::PadOp>(
+ rewriter, loc, UnrankedTensorType::get(weightETy), weight,
+ weightPaddingVal);
}
weightTy = cast<ShapedType>(weight.getType());
@@ -212,7 +165,7 @@ class TransposeConvStridedConverter
outputChannels, weightHeight / stride[0],
stride[0], weightWidth / stride[1],
stride[1], inputChannels};
- weight = createOpAndInfer<tosa::ReshapeOp>(
+ weight = CreateOpAndInferShape<tosa::ReshapeOp>(
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
rewriter.getDenseI64ArrayAttr(weightReshapeDims0));
@@ -221,7 +174,7 @@ class TransposeConvStridedConverter
loc, RankedTensorType::get({6}, rewriter.getI32Type()),
rewriter.getI32TensorAttr({2, 4, 0, 1, 3, 5}));
- weight = createOpAndInfer<tosa::TransposeOp>(
+ weight = CreateOpAndInferShape<tosa::TransposeOp>(
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
transposeWeightVal);
@@ -229,15 +182,15 @@ class TransposeConvStridedConverter
llvm::SmallVector<int64_t, 6> weightReshapeDims1 = {
outputChannels * stride[0] * stride[1], weightHeight / stride[0],
weightWidth / stride[1], inputChannels};
- weight = createOpAndInfer<tosa::ReshapeOp>(
+ weight = CreateOpAndInferShape<tosa::ReshapeOp>(
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
rewriter.getDenseI64ArrayAttr(weightReshapeDims1));
ShapedType restridedWeightTy = cast<ShapedType>(weight.getType());
- weight = createOpAndInfer<tosa::ReverseOp>(
+ weight = CreateOpAndInferShape<tosa::ReverseOp>(
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
/* axis = */ rewriter.getI32IntegerAttr(1));
- weight = createOpAndInfer<tosa::ReverseOp>(
+ weight = CreateOpAndInferShape<tosa::ReverseOp>(
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
/* axis = */ rewriter.getI32IntegerAttr(2));
@@ -251,19 +204,19 @@ class TransposeConvStridedConverter
DenseElementsAttr inputPaddingAttr = DenseIntElementsAttr::get(
RankedTensorType::get({4, 2}, rewriter.getI32Type()), inputPadding);
- Value inputPaddingVal = createOpAndInfer<tosa::ConstOp>(
+ Value inputPaddingVal = CreateOpAndInferShape<tosa::ConstOp>(
rewriter, loc, inputPaddingAttr.getType(), inputPaddingAttr);
if (op.getQuantizationInfo().has_value()) {
auto quantInfo = op.getQuantizationInfo().value();
- input = createOpAndInfer<tosa::PadOp>(
+ input = CreateOpAndInferShape<tosa::PadOp>(
rewriter, loc, UnrankedTensorType::get(inputETy), input,
inputPaddingVal, nullptr,
rewriter.getAttr<PadOpQuantizationAttr>(quantInfo.getInputZp()));
} else {
- input = createOpAndInfer<tosa::PadOp>(rewriter, loc,
- UnrankedTensorType::get(inputETy),
- input, inputPaddingVal);
+ input = CreateOpAndInferShape<tosa::PadOp>(
+ rewriter, loc, UnrankedTensorType::get(inputETy), input,
+ inputPaddingVal);
}
// We use a zero bias as we need to broadcast the bias.
@@ -279,7 +232,7 @@ class TransposeConvStridedConverter
// Perform the convolution using the zero bias.
Value conv2d;
if (op.getQuantizationInfo()) {
- conv2d = createOpAndInfer<tosa::Conv2DOp>(
+ conv2d = CreateOpAndInferShape<tosa::Conv2DOp>(
rewriter, loc, UnrankedTensorType::get(resultETy), input,
weight, zeroBias,
/*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}),
@@ -288,7 +241,7 @@ class TransposeConvStridedConverter
*op.getQuantizationInfo())
.getResult();
} else {
- conv2d = createOpAndInfer<tosa::Conv2DOp>(
+ conv2d = CreateOpAndInferShape<tosa::Conv2DOp>(
rewriter, loc, UnrankedTensorType::get(resultETy), input,
weight, zeroBias,
/*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}),
@@ -307,7 +260,7 @@ class TransposeConvStridedConverter
// Factor striding out of the convolution result.
llvm::SmallVector<int64_t, 6> convReshapeDims0 = {
batch, convHeight, convWidth, stride[0], stride[1], outputChannels};
- conv2d = createOpAndInfer<tosa::ReshapeOp>(
+ conv2d = CreateOpAndInferShape<tosa::ReshapeOp>(
rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
rewriter.getDenseI64ArrayAttr(convReshapeDims0));
@@ -316,14 +269,14 @@ class TransposeConvStridedConverter
loc, RankedTensorType::get({6}, rewriter.getI32Type()),
rewriter.getI32TensorAttr({0, 1, 3, 2, 4, 5}));
- conv2d = createOpAndInfer<tosa::TransposeOp>(
+ conv2d = CreateOpAndInferShape<tosa::TransposeOp>(
rewriter, loc, UnrankedTensorType::get(convETy), conv2d,
transposeConvVal);
// Fuse striding behavior back into width / height.
llvm::SmallVector<int64_t, 6> convReshapeDims1 = {
batch, convHeight * stride[0], convWidth * stride[1], outputChannels};
- conv2d = createOpAndInfer<tosa::ReshapeOp>(
+ conv2d = CreateOpAndInferShape<tosa::ReshapeOp>(
rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
rewriter.getDenseI64ArrayAttr(convReshapeDims1));
@@ -348,7 +301,7 @@ class TransposeConvStridedConverter
sliceSize[1] = resultSliceHeight;
sliceSize[2] = resultSliceWidth;
- auto slice = createOpAndInfer<tosa::SliceOp>(
+ auto slice = CreateOpAndInferShape<tosa::SliceOp>(
rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
rewriter.getDenseI64ArrayAttr(sliceBegin),
rewriter.getDenseI64ArrayAttr(sliceSize))
@@ -363,10 +316,10 @@ class TransposeConvStridedConverter
DenseElementsAttr resultPaddingAttr = DenseIntElementsAttr::get(
RankedTensorType::get({4, 2}, rewriter.getI32Type()), resultPadding);
- Value resultPaddingVal = createOpAndInfer<tosa::ConstOp>(
+ Value resultPaddingVal = CreateOpAndInferShape<tosa::ConstOp>(
rewriter, loc, resultPaddingAttr.getType(), resultPaddingAttr);
- Value resultPad = createOpAndInfer<tosa::PadOp>(
+ Value resultPad = CreateOpAndInferShape<tosa::PadOp>(
rewriter, loc, UnrankedTensorType::get(resultETy), slice,
resultPaddingVal);
diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
index f276924a8a9f62..1f6e3b2ab83919 100644
--- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
+++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp
@@ -102,6 +102,12 @@ computeReshapeOutput(ArrayRef<int64_t> higherRankShape,
LogicalResult mlir::tosa::EqualizeRanks(PatternRewriter &rewriter, Location loc,
Value &input1, Value &input2) {
+ ImplicitLocOpBuilder builder(loc, rewriter);
+ return EqualizeRanks(builder, input1, input2);
+}
+
+LogicalResult mlir::tosa::EqualizeRanks(ImplicitLocOpBuilder &builder,
+ Value &input1, Value &input2) {
auto input1Ty = llvm::dyn_cast<RankedTensorType>(input1.getType());
auto input2Ty = llvm::dyn_cast<RankedTensorType>(input2.getType());
@@ -140,9 +146,9 @@ LogicalResult mlir::tosa::EqualizeRanks(PatternRewriter &rewriter, Location loc,
auto reshapeOutputType = RankedTensorType::get(
ArrayRef<int64_t>(reshapeOutputShape), reshapeInputType.getElementType());
- auto reshapeLower = rewriter.create<tosa::ReshapeOp>(
- loc, reshapeOutputType, lowerTensorValue,
- rewriter.getDenseI64ArrayAttr(reshapeOutputShape));
+ auto reshapeLower = builder.create<tosa::ReshapeOp>(
+ reshapeOutputType, lowerTensorValue,
+ builder.getDenseI64ArrayAttr(reshapeOutputShape));
if (input1Rank > input2Rank) {
input1 = higherTensorValue;
// - create operator
// - performs shape inference on this operator
template <typename TosaOp, typename... Args>
TosaOp CreateOpAndInferShape(ImplicitLocOpBuilder &builder, Type result_ty,
nit: Most top-level usages come from a PatternRewriter; why have a level of indirection to the ImplicitLocOpBuilder?
This was needed by the TFL lowering code.
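To illustrate the indirection discussed in this thread, here is a minimal standalone sketch of the overload pattern: the core helper is written once against a builder that carries its own location, and a thin overload adapts callers that hold a builder and a location separately. `Loc`, `Builder`, `ImplicitLocBuilder`, and `createOp` are hypothetical stand-ins invented for this sketch; they are not the MLIR classes or the functions in this PR.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for a source location (not mlir::Location).
struct Loc {
  std::string name;
};

// Hypothetical stand-in for a rewriter that needs an explicit location.
struct Builder {
  std::vector<std::string> log; // records "opName@location" per created op
  void create(const Loc &loc, const std::string &opName) {
    assert(!opName.empty() && "op name must not be empty");
    log.push_back(opName + "@" + loc.name);
  }
};

// Hypothetical stand-in for a builder bundled with a fixed location.
struct ImplicitLocBuilder {
  Loc loc;
  Builder &inner;
  void create(const std::string &opName) { inner.create(loc, opName); }
};

// Core helper, written once against the implicit-location builder.
inline void createOp(ImplicitLocBuilder &builder, const std::string &opName) {
  builder.create(opName);
}

// Thin overload for callers that pass the builder and location separately.
inline void createOp(Builder &rewriter, const Loc &loc,
                     const std::string &opName) {
  ImplicitLocBuilder builder{loc, rewriter};
  createOp(builder, opName);
}
```

With this layout, both kinds of caller share one implementation, which appears to be the motivation given in the reply above.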
// - create operator
// - performs shape inference on this operator
template <typename TosaOp, typename... Args>
TosaOp CreateOpAndInferShape(ImplicitLocOpBuilder &builder, Type result_ty,
nit: resultTy instead of result_ty
done
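The wrapper under review picks a code path at compile time from the number and types of the forwarded arguments. The sketch below shows that idea in isolation, assuming C++17; `Value` is a hypothetical stand-in for `mlir::Value`, and `classifyOperands` is an invented name that only reports which branch would be taken.

```cpp
#include <cassert>
#include <string>
#include <tuple>
#include <type_traits>

// Hypothetical stand-in for mlir::Value.
struct Value {
  int rank;
};

// Reports which compile-time branch matches the argument list.
// std::decay_t strips the references that forwarding references deduce
// for lvalue arguments, so `Value &` still compares equal to `Value`.
template <typename... Args>
std::string classifyOperands(Args &&...) {
  using Decayed = std::tuple<std::decay_t<Args>...>;
  if constexpr (std::is_same_v<Decayed, std::tuple<Value, Value>>)
    return "binary";
  else if constexpr (std::is_same_v<Decayed, std::tuple<Value, Value, bool>>)
    return "binary-with-flag";
  else if constexpr (std::is_same_v<Decayed, std::tuple<Value, Value, Value>>)
    return "ternary";
  else
    return "other";
}
```

The code in the diff reaches the same decayed types by copying each argument into an `auto` variable before taking `decltype`; comparing whole tuples is just one compact alternative, not a claim about how the PR should be written.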
// Compute the new type based on the joined version.
auto newKnowledge = ValueKnowledge::join(currentKnowledge, inferredKnowledge);
Type new_ty =
nit: newTy
done
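The line under review builds the result type from the join of the declared and the inferred shape knowledge. As a rough illustration of what such a join can look like, here is a simplified standalone model. `Knowledge`, `JoinResult`, `kDynamic`, and this `join` are assumptions made for the sketch and are not the `ValueKnowledge` API; in particular, the sketch ignores element types and treats a rank mismatch as a conflict, which may differ from the MLIR helper.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical marker for an unknown (dynamic) dimension in this sketch.
constexpr int64_t kDynamic = -1;

// Simplified shape knowledge: hasRank == false means "unranked".
struct Knowledge {
  bool hasRank = false;
  std::vector<int64_t> sizes;
};

struct JoinResult {
  Knowledge value;
  bool conflict = false; // set when the two inputs contradict each other
};

// Combines two descriptions of the same value, keeping the most specific
// information from each side.
inline JoinResult join(const Knowledge &lhs, const Knowledge &rhs) {
  JoinResult result;
  if (!lhs.hasRank && !rhs.hasRank)
    return result; // nothing is known on either side
  if (!lhs.hasRank) {
    result.value = rhs;
    return result;
  }
  if (!rhs.hasRank) {
    result.value = lhs;
    return result;
  }
  if (lhs.sizes.size() != rhs.sizes.size()) {
    result.conflict = true; // sketch-specific choice for rank mismatch
    return result;
  }
  result.value.hasRank = true;
  for (std::size_t i = 0; i < lhs.sizes.size(); ++i) {
    const int64_t a = lhs.sizes[i];
    const int64_t b = rhs.sizes[i];
    if (a == kDynamic) {
      result.value.sizes.push_back(b);
    } else if (b == kDynamic || a == b) {
      result.value.sizes.push_back(a);
    } else {
      result.conflict = true; // two different static sizes
      result.value.sizes.push_back(kDynamic);
    }
  }
  return result;
}
```

In this model, joining an unranked declared type with a fully inferred shape simply yields the inferred shape, which matches how the call sites in the diff pass an unranked result type and rely on inference to refine it.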
namespace {

// Creates a TOSA operation and performs shape inference on the individual
// op. This allows shape inference during the TFLite to TOSA lowering.
Better to rephrase this "This allows shape inference when lowering down to TOSA" instead of mentioning explicit framework names?
done
This moves CreateOpAndInfer from TF legalize_util.h into ConversionUtils.h
Renamed to CreateOpAndInferShape so we can upstream this independently of tensorflow (otherwise a redefinition error would break TF compile if not upstreamed together with removal of CreateOpAndInfer in TF)
Signed-off-by: Tai Ly <[email protected]>
Change-Id: I53f39ec63f2e3763f8e50c03d1203e8dbed6f1bf
862439a to 6195e1f
✅ With the latest revision this PR passed the C/C++ code formatter.
This moves CreateOpAndInfer from TF legalize_util.h into ConversionUtils.h.
Also removes the duplicate createOpAndInfer function from TosaDecomposeTransposeConv.cpp.
Renamed to CreateOpAndInferShape so that it can be upstreamed independently of TensorFlow (otherwise a redefinition error would break the TF build unless this landed together with the removal of CreateOpAndInfer in TF).
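For ops with the SameOperandsAndResultRank trait, CreateOpAndInferShape equalizes operand ranks before creating the op. The sketch below shows only the core idea on plain shape vectors: the lower-rank shape is left-padded with 1s until both ranks match. `Shape` and `equalizeRanks` are hypothetical names for this sketch; it skips the broadcast-compatibility check and the reshape op that the real EqualizeRanks deals with.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;

// Left-pads the lower-rank shape with 1s so that both shapes end up with the
// same rank. Shapes that already have equal rank are left unchanged.
inline void equalizeRanks(Shape &a, Shape &b) {
  Shape &lower = a.size() < b.size() ? a : b;
  const std::size_t target = a.size() < b.size() ? b.size() : a.size();
  lower.insert(lower.begin(), target - lower.size(), 1);
  assert(a.size() == b.size() && "ranks must match after padding");
}
```

For example, pairing a rank-3 shape {2, 3, 4} with a rank-1 shape {4} leaves the first unchanged and turns the second into {1, 1, 4}, so an elementwise op can then broadcast along the leading dimensions.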