Skip to content

Commit 8dea784

Browse files
committed
[mlir][tosa] Add tosa shape inference with InferReturnTypeComponent
Added InferReturnTypeComponents for NAry operations, reshape, and reverse. With the additional tosa-infer-shapes pass, we can infer/propagate shapes across a set of TOSA operations. Current version does not modify the FuncOp type by inserting an unrealized conversion cast prior to any new non-matchin returns. Differential Revision: https://reviews.llvm.org/D105312
1 parent 0176ac9 commit 8dea784

File tree

11 files changed

+835
-64
lines changed

11 files changed

+835
-64
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
#include "mlir/Dialect/Quant/QuantOps.h"
1717
#include "mlir/Dialect/Traits.h"
18+
#include "mlir/Interfaces/InferTypeOpInterface.h"
1819
#include "mlir/Interfaces/LoopLikeInterface.h"
1920
#include "mlir/Interfaces/SideEffectInterfaces.h"
2021

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 145 additions & 62 deletions
Large diffs are not rendered by default.

mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,13 @@
1313
#ifndef MLIR_DIALECT_TOSA_TRANSFORMS_PASSES_H
1414
#define MLIR_DIALECT_TOSA_TRANSFORMS_PASSES_H
1515

16+
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1617
#include "mlir/Pass/Pass.h"
1718

1819
namespace mlir {
1920
namespace tosa {
2021

22+
std::unique_ptr<Pass> createTosaInferShapesPass();
2123
std::unique_ptr<Pass> createTosaMakeBroadcastablePass();
2224
std::unique_ptr<Pass> createTosaTestQuantUtilAPIPass();
2325

mlir/include/mlir/Dialect/Tosa/Transforms/Passes.td

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,21 @@
1515

1616
include "mlir/Pass/PassBase.td"
1717

18+
def TosaInferShapes : FunctionPass<"tosa-infer-shapes"> {
19+
let summary = "Propagate shapes across TOSA operations";
20+
let description = [{
21+
Pass that uses operand types and propagates shapes to TOSA operations.
22+
This includes legalizing rankless and dynamic shapes towards static.
23+
}];
24+
25+
let constructor = "createTosaInferShapesPass()";
26+
let dependentDialects = [
27+
"StandardOpsDialect",
28+
"tensor::TensorDialect",
29+
"tosa::TosaDialect",
30+
];
31+
}
32+
1833
def TosaMakeBroadcastable : FunctionPass<"tosa-make-broadcastable"> {
1934
let summary = "TOSA rank Reshape to enable Broadcasting";
2035
let description = [{

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,148 @@ static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
291291
result.types.push_back(outputType);
292292
}
293293

294+
//===----------------------------------------------------------------------===//
295+
// TOSA Operator Return Type Inference.
296+
//===----------------------------------------------------------------------===//
297+
298+
static void getI64Values(ArrayAttr arrayAttr, SmallVector<int64_t> &values) {
299+
for (auto it : arrayAttr) {
300+
values.push_back(it.cast<IntegerAttr>().getValue().getSExtValue());
301+
}
302+
}
303+
304+
LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
305+
MLIRContext *context, ::llvm::Optional<Location> location,
306+
ValueRange operands, DictionaryAttr attributes, RegionRange regions,
307+
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
308+
ShapedType type = operands.front().getType().cast<ShapedType>();
309+
310+
auto newShape = attributes.get("new_shape").cast<ArrayAttr>();
311+
llvm::SmallVector<int64_t> newShapeValue;
312+
getI64Values(newShape, newShapeValue);
313+
314+
// We cannot infer from the total number of elements so we must take the
315+
// shape attribute as exact.
316+
if (!type.hasRank() || !type.hasStaticShape()) {
317+
inferredReturnShapes.push_back(ShapedTypeComponents(newShapeValue));
318+
return success();
319+
}
320+
321+
// Determine the number of elements covered by the slice of all static
322+
// dimensions. This allows us to infer the length of the remaining dynamic
323+
// dimension.
324+
int64_t numElements = type.getNumElements();
325+
int64_t staticMul = 1;
326+
for (auto val : newShapeValue) {
327+
if (val != -1) {
328+
staticMul *= val;
329+
}
330+
}
331+
332+
// Determine the length of the dynamic dimension.
333+
for (auto &val : newShapeValue) {
334+
if (val == -1)
335+
val = numElements / staticMul;
336+
}
337+
338+
inferredReturnShapes.push_back(ShapedTypeComponents(newShapeValue));
339+
return success();
340+
}
341+
342+
static LogicalResult resolveBroadcastShape(ValueRange operands,
343+
SmallVector<int64_t> &outShape) {
344+
int64_t outRank = 0;
345+
for (auto operand : operands) {
346+
auto type = operand.getType().cast<ShapedType>();
347+
if (!type.hasRank())
348+
return failure();
349+
outRank = std::max<int64_t>(outRank, type.getRank());
350+
}
351+
352+
outShape.resize(outRank, 1);
353+
354+
for (auto operand : operands) {
355+
auto type = operand.getType().cast<ShapedType>();
356+
auto shape = type.getShape();
357+
auto rankDiff = outShape.size() - shape.size();
358+
359+
for (size_t i = 0; i < shape.size(); i++) {
360+
auto dim1 = outShape[i + rankDiff];
361+
auto dim2 = shape[i];
362+
auto resolvedDim = dim1;
363+
364+
if (dim1 == 1) {
365+
resolvedDim = dim2;
366+
} else if (dim2 == 1) {
367+
resolvedDim = dim1;
368+
} else if (dim1 != dim2) {
369+
return failure();
370+
}
371+
outShape[i + rankDiff] = resolvedDim;
372+
}
373+
}
374+
375+
return success();
376+
}
377+
378+
static LogicalResult NAryInferReturnTypes(
379+
ValueRange operands,
380+
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
381+
llvm::SmallVector<int64_t> outShape;
382+
if (resolveBroadcastShape(operands, outShape).failed()) {
383+
inferredReturnShapes.push_back(ShapedTypeComponents());
384+
} else {
385+
inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
386+
}
387+
return success();
388+
}
389+
390+
#define NARY_SHAPE_INFER(OP) \
391+
LogicalResult OP::inferReturnTypeComponents( \
392+
MLIRContext *context, ::llvm::Optional<Location> location, \
393+
ValueRange operands, DictionaryAttr attributes, RegionRange regions, \
394+
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { \
395+
return NAryInferReturnTypes(operands, inferredReturnShapes); \
396+
}
397+
398+
NARY_SHAPE_INFER(tosa::AbsOp)
399+
NARY_SHAPE_INFER(tosa::AddOp)
400+
NARY_SHAPE_INFER(tosa::ArithmeticRightShiftOp)
401+
NARY_SHAPE_INFER(tosa::BitwiseAndOp)
402+
NARY_SHAPE_INFER(tosa::BitwiseOrOp)
403+
NARY_SHAPE_INFER(tosa::BitwiseXorOp)
404+
NARY_SHAPE_INFER(tosa::BitwiseNotOp)
405+
NARY_SHAPE_INFER(tosa::CeilOp)
406+
NARY_SHAPE_INFER(tosa::ClampOp)
407+
NARY_SHAPE_INFER(tosa::ClzOp)
408+
NARY_SHAPE_INFER(tosa::DivOp)
409+
NARY_SHAPE_INFER(tosa::EqualOp)
410+
NARY_SHAPE_INFER(tosa::ExpOp)
411+
NARY_SHAPE_INFER(tosa::FloorOp)
412+
NARY_SHAPE_INFER(tosa::GreaterEqualOp)
413+
NARY_SHAPE_INFER(tosa::GreaterOp)
414+
NARY_SHAPE_INFER(tosa::LogOp)
415+
NARY_SHAPE_INFER(tosa::LogicalAndOp)
416+
NARY_SHAPE_INFER(tosa::LogicalLeftShiftOp)
417+
NARY_SHAPE_INFER(tosa::LogicalNotOp)
418+
NARY_SHAPE_INFER(tosa::LogicalOrOp)
419+
NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
420+
NARY_SHAPE_INFER(tosa::LogicalXorOp)
421+
NARY_SHAPE_INFER(tosa::MaximumOp)
422+
NARY_SHAPE_INFER(tosa::MinimumOp)
423+
NARY_SHAPE_INFER(tosa::MulOp)
424+
NARY_SHAPE_INFER(tosa::NegateOp)
425+
NARY_SHAPE_INFER(tosa::PowOp)
426+
NARY_SHAPE_INFER(tosa::ReciprocalOp)
427+
NARY_SHAPE_INFER(tosa::ReluNOp)
428+
NARY_SHAPE_INFER(tosa::ReverseOp)
429+
NARY_SHAPE_INFER(tosa::RsqrtOp)
430+
NARY_SHAPE_INFER(tosa::SelectOp)
431+
NARY_SHAPE_INFER(tosa::SubOp)
432+
NARY_SHAPE_INFER(tosa::TanhOp)
433+
NARY_SHAPE_INFER(tosa::SigmoidOp)
434+
#undef PRED_SHAPE_INFER
435+
294436
//===----------------------------------------------------------------------===//
295437
// TOSA Operator Definitions.
296438
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Tosa/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
add_mlir_dialect_library(MLIRTosaTransforms
2+
TosaInferShapes.cpp
23
TosaMakeBroadcastable.cpp
34

45
ADDITIONAL_HEADER_DIRS

0 commit comments

Comments
 (0)