Skip to content

Commit c883452

Browse files
authored
[TOSA] Move CreateOpAndInfer into ConversionUtils.h (#106122)
This moves CreateOpAndInfer from TF legalize_util.h into ConversionUtils.h also removed duplicate createOpAndInfer function from TosaDecomposeTransposeConv.cpp Renamed to CreateOpAndInferShape so we can upstream this independently of tensorflow (otherwise a redefinition error would break TF compile if not upstreamed together with removal of CreateOpAndInfer in TF) --------- Signed-off-by: Tai Ly <[email protected]>
1 parent 6d37259 commit c883452

File tree

3 files changed

+169
-73
lines changed

3 files changed

+169
-73
lines changed

mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515

1616
#include "mlir/Dialect/Arith/IR/Arith.h"
1717
#include "mlir/Dialect/Tensor/IR/Tensor.h"
18+
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
1819
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
20+
#include "mlir/IR/ImplicitLocOpBuilder.h"
1921
#include "mlir/IR/PatternMatch.h"
2022
#include <optional>
2123

@@ -79,6 +81,141 @@ checkHasDynamicBatchDims(PatternRewriter &rewriter, Op op,
7981
LogicalResult EqualizeRanks(PatternRewriter &rewriter, Location loc,
8082
Value &input1, Value &input2);
8183

84+
LogicalResult EqualizeRanks(ImplicitLocOpBuilder &builder, Value &input1,
85+
Value &input2);
86+
87+
namespace {
88+
89+
// Creates a TOSA operation and performs shape inference on the individual
90+
// op. This allows shape inference when lowering down to TOSA.
91+
template <typename TosaOp, typename... Args>
92+
TosaOp createOpAndInferShape(ImplicitLocOpBuilder &builder, Type resultTy,
93+
Args &&...args) {
94+
auto op = builder.create<TosaOp>(resultTy, args...);
95+
96+
InferShapedTypeOpInterface shapeInterface =
97+
dyn_cast<InferShapedTypeOpInterface>(op.getOperation());
98+
if (!shapeInterface)
99+
return op;
100+
101+
SmallVector<ShapedTypeComponents> returnedShapes;
102+
if (shapeInterface
103+
.inferReturnTypeComponents(op.getContext(), builder.getLoc(),
104+
op->getOperands(), op->getAttrDictionary(),
105+
op->getPropertiesStorage(),
106+
op->getRegions(), returnedShapes)
107+
.failed())
108+
return op;
109+
110+
// We need to use the element type of the existing result type to generate
111+
// the new result shaped type. This is because rescale can include a cast to
112+
// different bit-width types and does not have a TypeAttr to define the
113+
// target type.
114+
auto result = op->getResult(0);
115+
auto predictedShape = returnedShapes[0];
116+
auto currentKnowledge = ValueKnowledge::getKnowledgeFromType(resultTy);
117+
118+
// Compute the knowledge based on the inferred type.
119+
auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
120+
inferredKnowledge.dtype = mlir::cast<ShapedType>(resultTy).getElementType();
121+
inferredKnowledge.hasRank = predictedShape.hasRank();
122+
if (predictedShape.hasRank()) {
123+
for (auto dim : predictedShape.getDims()) {
124+
inferredKnowledge.sizes.push_back(dim);
125+
}
126+
}
127+
128+
// Compute the new type based on the joined version.
129+
auto newKnowledge = ValueKnowledge::join(currentKnowledge, inferredKnowledge);
130+
Type newTy =
131+
newKnowledge.hasRank
132+
? Type{mlir::RankedTensorType::get(llvm::ArrayRef(newKnowledge.sizes),
133+
newKnowledge.dtype)}
134+
: Type{mlir::UnrankedTensorType::get(newKnowledge.dtype)};
135+
result.setType(newTy);
136+
return op;
137+
}
138+
139+
} // namespace
140+
141+
// Creates a TOSA operation by:
142+
// - first equalize ranks for ops with SameOperandsAndResultRank trait
143+
// - create operator
144+
// - performs shape inference on this operator
145+
template <typename TosaOp, typename... Args>
146+
TosaOp CreateOpAndInferShape(ImplicitLocOpBuilder &builder, Type resultTy,
147+
Args &&...args) {
148+
if (TosaOp::template hasTrait<OpTrait::SameOperandsAndResultRank>()) {
149+
// op requires same ranks for tensor operands
150+
if constexpr (sizeof...(Args) == 2) {
151+
auto argX = std::get<0>(std::tie(args...));
152+
auto argY = std::get<1>(std::tie(args...));
153+
using ArgX = decltype(argX);
154+
using ArgY = decltype(argY);
155+
if constexpr (std::is_same_v<ArgX, Value> &&
156+
std::is_same_v<ArgY, Value>) {
157+
Value x = std::get<0>(std::tie(args...));
158+
Value y = std::get<1>(std::tie(args...));
159+
if (EqualizeRanks(builder, x, y).failed()) {
160+
// incompatible broadcast shapes, no reshape is inserted
161+
// ResultsBroadcastableShape verify will handle this
162+
}
163+
return createOpAndInferShape<TosaOp>(builder, resultTy, x, y);
164+
}
165+
}
166+
if constexpr (sizeof...(Args) == 3) {
167+
auto argX = std::get<0>(std::tie(args...));
168+
auto argY = std::get<1>(std::tie(args...));
169+
auto argZ = std::get<2>(std::tie(args...));
170+
using ArgX = decltype(argX);
171+
using ArgY = decltype(argY);
172+
using ArgZ = decltype(argZ);
173+
if constexpr (std::is_same_v<ArgX, Value> &&
174+
std::is_same_v<ArgY, Value> && std::is_same_v<ArgZ, bool>) {
175+
// special case for ArithmeticRightShiftOp
176+
Value x = std::get<0>(std::tie(args...));
177+
Value y = std::get<1>(std::tie(args...));
178+
bool round = std::get<2>(std::tie(args...));
179+
if (EqualizeRanks(builder, x, y).failed()) {
180+
// incompatible broadcast shapes, no reshape is inserted
181+
// ResultsBroadcastableShape verify will handle this
182+
}
183+
return createOpAndInferShape<TosaOp>(builder, resultTy, x, y, round);
184+
}
185+
if constexpr (std::is_same_v<ArgX, Value> &&
186+
std::is_same_v<ArgY, Value> &&
187+
std::is_same_v<ArgZ, Value>) {
188+
// special case for Select
189+
Value x = std::get<0>(std::tie(args...));
190+
Value y = std::get<1>(std::tie(args...));
191+
Value z = std::get<2>(std::tie(args...));
192+
193+
if (EqualizeRanks(builder, x, y).failed() ||
194+
EqualizeRanks(builder, x, z).failed() ||
195+
EqualizeRanks(builder, y, z).failed()) {
196+
// incompatible broadcast shapes, no reshape is inserted
197+
// ResultsBroadcastableShape verify will handle this
198+
}
199+
200+
return createOpAndInferShape<TosaOp>(builder, resultTy, x, y, z);
201+
}
202+
}
203+
}
204+
205+
return createOpAndInferShape<TosaOp>(builder, resultTy, args...);
206+
}
207+
208+
// Creates a TOSA operation by:
209+
// - first equalize ranks for ops with SameOperandsAndResultRank trait
210+
// - create operator
211+
// - performs shape inference on this operator
212+
template <typename TosaOp, typename... Args>
213+
TosaOp CreateOpAndInferShape(PatternRewriter &rewriter, Location loc,
214+
Type resultTy, Args &&...args) {
215+
ImplicitLocOpBuilder builder(loc, rewriter);
216+
return CreateOpAndInferShape<TosaOp>(builder, resultTy, args...);
217+
}
218+
82219
} // namespace tosa
83220
} // namespace mlir
84221

mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp

Lines changed: 23 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -26,53 +26,6 @@ using namespace mlir::tosa;
2626

2727
namespace {
2828

29-
template <typename TosaOp, typename... Args>
30-
TosaOp createOpAndInfer(PatternRewriter &rewriter, Location loc, Type resultTy,
31-
Args &&...args) {
32-
auto op = rewriter.create<TosaOp>(loc, resultTy, args...);
33-
34-
InferShapedTypeOpInterface shapeInterface =
35-
dyn_cast<InferShapedTypeOpInterface>(op.getOperation());
36-
if (!shapeInterface)
37-
return op;
38-
39-
SmallVector<ShapedTypeComponents> returnedShapes;
40-
if (shapeInterface
41-
.inferReturnTypeComponents(
42-
op.getContext(), op.getLoc(), op->getOperands(),
43-
op->getDiscardableAttrDictionary(), op->getPropertiesStorage(),
44-
op->getRegions(), returnedShapes)
45-
.failed())
46-
return op;
47-
48-
// We need to use the element type of the existing result type to generate
49-
// the new result shaped type. This is because rescale can include a cast to
50-
// different bit-width types and does not have a TypeAttr to define the
51-
// target type.
52-
auto result = op->getResult(0);
53-
auto predictedShape = returnedShapes[0];
54-
auto currentKnowledge =
55-
mlir::tosa::ValueKnowledge::getKnowledgeFromType(resultTy);
56-
57-
// Compute the knowledge based on the inferred type.
58-
auto inferredKnowledge =
59-
mlir::tosa::ValueKnowledge::getPessimisticValueState();
60-
inferredKnowledge.dtype = cast<ShapedType>(resultTy).getElementType();
61-
inferredKnowledge.hasRank = predictedShape.hasRank();
62-
if (predictedShape.hasRank()) {
63-
for (auto dim : predictedShape.getDims()) {
64-
inferredKnowledge.sizes.push_back(dim);
65-
}
66-
}
67-
68-
// Compute the new type based on the joined version.
69-
auto newKnowledge =
70-
mlir::tosa::ValueKnowledge::join(currentKnowledge, inferredKnowledge);
71-
auto newTy = newKnowledge.getType();
72-
result.setType(newTy);
73-
return op;
74-
}
75-
7629
class TransposeConvNonStridedConverter
7730
: public OpRewritePattern<tosa::TransposeConv2DOp> {
7831
public:
@@ -187,20 +140,20 @@ class TransposeConvStridedConverter
187140
(weightWidth % stride[1]) ? (stride[1] - weightWidth % stride[1]) : 0;
188141
DenseElementsAttr weightPaddingAttr = DenseIntElementsAttr::get(
189142
RankedTensorType::get({4, 2}, rewriter.getI32Type()), weightPadding);
190-
Value weightPaddingVal = createOpAndInfer<tosa::ConstOp>(
143+
Value weightPaddingVal = CreateOpAndInferShape<tosa::ConstOp>(
191144
rewriter, loc, weightPaddingAttr.getType(), weightPaddingAttr);
192145

193146
if (op.getQuantizationInfo().has_value()) {
194147
auto quantInfo = op.getQuantizationInfo().value();
195-
weight = createOpAndInfer<tosa::PadOp>(
148+
weight = CreateOpAndInferShape<tosa::PadOp>(
196149
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
197150
weightPaddingVal, nullptr,
198151
rewriter.getAttr<PadOpQuantizationAttr>(quantInfo.getWeightZp()));
199152

200153
} else {
201-
weight = createOpAndInfer<tosa::PadOp>(rewriter, loc,
202-
UnrankedTensorType::get(weightETy),
203-
weight, weightPaddingVal);
154+
weight = CreateOpAndInferShape<tosa::PadOp>(
155+
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
156+
weightPaddingVal);
204157
}
205158

206159
weightTy = cast<ShapedType>(weight.getType());
@@ -212,7 +165,7 @@ class TransposeConvStridedConverter
212165
outputChannels, weightHeight / stride[0],
213166
stride[0], weightWidth / stride[1],
214167
stride[1], inputChannels};
215-
weight = createOpAndInfer<tosa::ReshapeOp>(
168+
weight = CreateOpAndInferShape<tosa::ReshapeOp>(
216169
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
217170
rewriter.getDenseI64ArrayAttr(weightReshapeDims0));
218171

@@ -221,23 +174,23 @@ class TransposeConvStridedConverter
221174
loc, RankedTensorType::get({6}, rewriter.getI32Type()),
222175
rewriter.getI32TensorAttr({2, 4, 0, 1, 3, 5}));
223176

224-
weight = createOpAndInfer<tosa::TransposeOp>(
177+
weight = CreateOpAndInferShape<tosa::TransposeOp>(
225178
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
226179
transposeWeightVal);
227180

228181
// Collapse the strides and output channels into a single dimension.
229182
llvm::SmallVector<int64_t, 6> weightReshapeDims1 = {
230183
outputChannels * stride[0] * stride[1], weightHeight / stride[0],
231184
weightWidth / stride[1], inputChannels};
232-
weight = createOpAndInfer<tosa::ReshapeOp>(
185+
weight = CreateOpAndInferShape<tosa::ReshapeOp>(
233186
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
234187
rewriter.getDenseI64ArrayAttr(weightReshapeDims1));
235188
ShapedType restridedWeightTy = cast<ShapedType>(weight.getType());
236189

237-
weight = createOpAndInfer<tosa::ReverseOp>(
190+
weight = CreateOpAndInferShape<tosa::ReverseOp>(
238191
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
239192
/* axis = */ rewriter.getI32IntegerAttr(1));
240-
weight = createOpAndInfer<tosa::ReverseOp>(
193+
weight = CreateOpAndInferShape<tosa::ReverseOp>(
241194
rewriter, loc, UnrankedTensorType::get(weightETy), weight,
242195
/* axis = */ rewriter.getI32IntegerAttr(2));
243196

@@ -251,19 +204,19 @@ class TransposeConvStridedConverter
251204
DenseElementsAttr inputPaddingAttr = DenseIntElementsAttr::get(
252205
RankedTensorType::get({4, 2}, rewriter.getI32Type()), inputPadding);
253206

254-
Value inputPaddingVal = createOpAndInfer<tosa::ConstOp>(
207+
Value inputPaddingVal = CreateOpAndInferShape<tosa::ConstOp>(
255208
rewriter, loc, inputPaddingAttr.getType(), inputPaddingAttr);
256209

257210
if (op.getQuantizationInfo().has_value()) {
258211
auto quantInfo = op.getQuantizationInfo().value();
259-
input = createOpAndInfer<tosa::PadOp>(
212+
input = CreateOpAndInferShape<tosa::PadOp>(
260213
rewriter, loc, UnrankedTensorType::get(inputETy), input,
261214
inputPaddingVal, nullptr,
262215
rewriter.getAttr<PadOpQuantizationAttr>(quantInfo.getInputZp()));
263216
} else {
264-
input = createOpAndInfer<tosa::PadOp>(rewriter, loc,
265-
UnrankedTensorType::get(inputETy),
266-
input, inputPaddingVal);
217+
input = CreateOpAndInferShape<tosa::PadOp>(
218+
rewriter, loc, UnrankedTensorType::get(inputETy), input,
219+
inputPaddingVal);
267220
}
268221

269222
// We use a zero bias as we need to broadcast the bias.
@@ -279,7 +232,7 @@ class TransposeConvStridedConverter
279232
// Perform the convolution using the zero bias.
280233
Value conv2d;
281234
if (op.getQuantizationInfo()) {
282-
conv2d = createOpAndInfer<tosa::Conv2DOp>(
235+
conv2d = CreateOpAndInferShape<tosa::Conv2DOp>(
283236
rewriter, loc, UnrankedTensorType::get(resultETy), input,
284237
weight, zeroBias,
285238
/*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}),
@@ -288,7 +241,7 @@ class TransposeConvStridedConverter
288241
*op.getQuantizationInfo())
289242
.getResult();
290243
} else {
291-
conv2d = createOpAndInfer<tosa::Conv2DOp>(
244+
conv2d = CreateOpAndInferShape<tosa::Conv2DOp>(
292245
rewriter, loc, UnrankedTensorType::get(resultETy), input,
293246
weight, zeroBias,
294247
/*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}),
@@ -307,7 +260,7 @@ class TransposeConvStridedConverter
307260
// Factor striding out of the convolution result.
308261
llvm::SmallVector<int64_t, 6> convReshapeDims0 = {
309262
batch, convHeight, convWidth, stride[0], stride[1], outputChannels};
310-
conv2d = createOpAndInfer<tosa::ReshapeOp>(
263+
conv2d = CreateOpAndInferShape<tosa::ReshapeOp>(
311264
rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
312265
rewriter.getDenseI64ArrayAttr(convReshapeDims0));
313266

@@ -316,14 +269,14 @@ class TransposeConvStridedConverter
316269
loc, RankedTensorType::get({6}, rewriter.getI32Type()),
317270
rewriter.getI32TensorAttr({0, 1, 3, 2, 4, 5}));
318271

319-
conv2d = createOpAndInfer<tosa::TransposeOp>(
272+
conv2d = CreateOpAndInferShape<tosa::TransposeOp>(
320273
rewriter, loc, UnrankedTensorType::get(convETy), conv2d,
321274
transposeConvVal);
322275

323276
// Fuse striding behavior back into width / height.
324277
llvm::SmallVector<int64_t, 6> convReshapeDims1 = {
325278
batch, convHeight * stride[0], convWidth * stride[1], outputChannels};
326-
conv2d = createOpAndInfer<tosa::ReshapeOp>(
279+
conv2d = CreateOpAndInferShape<tosa::ReshapeOp>(
327280
rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
328281
rewriter.getDenseI64ArrayAttr(convReshapeDims1));
329282

@@ -348,7 +301,7 @@ class TransposeConvStridedConverter
348301
sliceSize[1] = resultSliceHeight;
349302
sliceSize[2] = resultSliceWidth;
350303

351-
auto slice = createOpAndInfer<tosa::SliceOp>(
304+
auto slice = CreateOpAndInferShape<tosa::SliceOp>(
352305
rewriter, loc, UnrankedTensorType::get(resultETy), conv2d,
353306
rewriter.getDenseI64ArrayAttr(sliceBegin),
354307
rewriter.getDenseI64ArrayAttr(sliceSize))
@@ -363,10 +316,10 @@ class TransposeConvStridedConverter
363316
DenseElementsAttr resultPaddingAttr = DenseIntElementsAttr::get(
364317
RankedTensorType::get({4, 2}, rewriter.getI32Type()), resultPadding);
365318

366-
Value resultPaddingVal = createOpAndInfer<tosa::ConstOp>(
319+
Value resultPaddingVal = CreateOpAndInferShape<tosa::ConstOp>(
367320
rewriter, loc, resultPaddingAttr.getType(), resultPaddingAttr);
368321

369-
Value resultPad = createOpAndInfer<tosa::PadOp>(
322+
Value resultPad = CreateOpAndInferShape<tosa::PadOp>(
370323
rewriter, loc, UnrankedTensorType::get(resultETy), slice,
371324
resultPaddingVal);
372325

mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,12 @@ computeReshapeOutput(ArrayRef<int64_t> higherRankShape,
102102

103103
LogicalResult mlir::tosa::EqualizeRanks(PatternRewriter &rewriter, Location loc,
                                        Value &input1, Value &input2) {
  // Delegate to the ImplicitLocOpBuilder overload, which carries the location
  // implicitly.
  ImplicitLocOpBuilder implicitBuilder(loc, rewriter);
  return EqualizeRanks(implicitBuilder, input1, input2);
}
108+
109+
LogicalResult mlir::tosa::EqualizeRanks(ImplicitLocOpBuilder &builder,
110+
Value &input1, Value &input2) {
105111
auto input1Ty = llvm::dyn_cast<RankedTensorType>(input1.getType());
106112
auto input2Ty = llvm::dyn_cast<RankedTensorType>(input2.getType());
107113

@@ -140,9 +146,9 @@ LogicalResult mlir::tosa::EqualizeRanks(PatternRewriter &rewriter, Location loc,
140146
auto reshapeOutputType = RankedTensorType::get(
141147
ArrayRef<int64_t>(reshapeOutputShape), reshapeInputType.getElementType());
142148

143-
auto reshapeLower = rewriter.create<tosa::ReshapeOp>(
144-
loc, reshapeOutputType, lowerTensorValue,
145-
rewriter.getDenseI64ArrayAttr(reshapeOutputShape));
149+
auto reshapeLower = builder.create<tosa::ReshapeOp>(
150+
reshapeOutputType, lowerTensorValue,
151+
builder.getDenseI64ArrayAttr(reshapeOutputShape));
146152

147153
if (input1Rank > input2Rank) {
148154
input1 = higherTensorValue;

0 commit comments

Comments
 (0)