10
10
#ifndef TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZEUTILS_H
11
11
#define TORCHMLIR_CONVERSION_TORCHTOTOSA_TOSALEGALIZEUTILS_H
12
12
13
- #include " mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project
14
- #include " mlir/Dialect/Tosa/Utils/ShapeUtils.h" // from @llvm-project
15
- #include " mlir/IR/BuiltinAttributes.h" // from @llvm-project
16
- #include " mlir/IR/BuiltinTypes.h" // from @llvm-project
17
- #include " mlir/IR/PatternMatch.h" // from @llvm-project
18
- #include " mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project
19
- #include " mlir/Support/LLVM.h" // from @llvm-project
13
+ #include " mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project
14
+ #include " mlir/Dialect/Tosa/Utils/ConversionUtils.h" // from @llvm-project
15
+ #include " mlir/Dialect/Tosa/Utils/ShapeUtils.h" // from @llvm-project
16
+ #include " mlir/IR/BuiltinAttributes.h" // from @llvm-project
17
+ #include " mlir/IR/BuiltinTypes.h" // from @llvm-project
18
+ #include " mlir/IR/PatternMatch.h" // from @llvm-project
19
+ #include " mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project
20
+ #include " mlir/Support/LLVM.h" // from @llvm-project
20
21
21
22
namespace mlir {
22
23
namespace tosa {
@@ -45,6 +46,10 @@ bool isScale32(mlir::quant::UniformQuantizedType output_element_type);
45
46
Value getTosaConstTensorSingleF32 (PatternRewriter &rewriter, Operation *op,
46
47
float val);
47
48
49
+ // Create an int8_t const tosa.mul shift tensor from an int
50
+ Value getTosaMulShiftConstTensor (PatternRewriter &rewriter, Operation *op,
51
+ int32_t shift);
52
+
48
53
// Create a zero constant tensor of the desired type and shape.
49
54
std::optional<Value> getZerosLikeTensor (PatternRewriter &rewriter,
50
55
Operation *op, Type type);
@@ -58,55 +63,24 @@ std::optional<Value> getConstTensor(PatternRewriter &rewriter, Operation *op,
58
63
ArrayRef<T> vec, ArrayRef<int64_t > shape,
59
64
std::optional<Type> dtype = {});
60
65
61
- LogicalResult tosaCastTensorToType (PatternRewriter &rewriter, Operation *op,
62
- Value src, Type destType, Value &result);
63
-
64
- Value promoteType (PatternRewriter &rewriter, Value input, TensorType outType );
66
+ // Default function to create tosa.cast op. This should be called instead of
67
+ // directly calling rewriter.create<tosa::CastOp>.
68
+ std::optional<Value> tosaCastTensorToType (PatternRewriter &rewriter, Value src,
69
+ TensorType destType );
65
70
66
71
// Creates a TOSA operation and performs shape inference on the individual
67
72
// op. This allows shape inference during the framework to TOSA lowering.
73
+ template <typename TosaOp, typename ... Args>
74
+ TosaOp CreateOpAndInfer (ImplicitLocOpBuilder &builder, Type result_ty,
75
+ Args &&...args) {
76
+ return CreateOpAndInferShape<TosaOp>(builder, result_ty, args...);
77
+ }
78
+
68
79
template <typename TosaOp, typename ... Args>
69
80
TosaOp CreateOpAndInfer (PatternRewriter &rewriter, Location loc, Type result_ty,
70
81
Args &&...args) {
71
- auto op = rewriter.create <TosaOp>(loc, result_ty, args...);
72
-
73
- InferShapedTypeOpInterface shapeInterface =
74
- dyn_cast<InferShapedTypeOpInterface>(op.getOperation ());
75
- if (!shapeInterface)
76
- return op;
77
-
78
- SmallVector<ShapedTypeComponents> returnedShapes;
79
- if (shapeInterface
80
- .inferReturnTypeComponents (op.getContext (), op.getLoc (),
81
- op->getOperands (), op->getAttrDictionary (),
82
- op->getPropertiesStorage (),
83
- op->getRegions (), returnedShapes)
84
- .failed ())
85
- return op;
86
-
87
- // We need to use the element type of the existing result type to generate
88
- // the new result shaped type. This is because rescale can include a cast to
89
- // different bit-width types and does not have a TypeAttr to define the
90
- // target type.
91
- auto result = op->getResult (0 );
92
- auto predictedShape = returnedShapes[0 ];
93
- auto currentKnowledge = ValueKnowledge::getKnowledgeFromType (result_ty);
94
-
95
- // Compute the knowledge based on the inferred type.
96
- auto inferredKnowledge = ValueKnowledge::getPessimisticValueState ();
97
- inferredKnowledge.dtype = cast<ShapedType>(result_ty).getElementType ();
98
- inferredKnowledge.hasRank = predictedShape.hasRank ();
99
- if (predictedShape.hasRank ()) {
100
- for (auto dim : predictedShape.getDims ()) {
101
- inferredKnowledge.sizes .push_back (dim);
102
- }
103
- }
104
-
105
- // Compute the new type based on the joined version.
106
- auto newKnowledge = ValueKnowledge::join (currentKnowledge, inferredKnowledge);
107
- auto new_ty = newKnowledge.getType ();
108
- result.setType (new_ty);
109
- return op;
82
+ ImplicitLocOpBuilder builder (loc, rewriter);
83
+ return CreateOpAndInfer<TosaOp>(builder, result_ty, args...);
110
84
}
111
85
112
86
template <typename TosaOp, typename ... Args>
0 commit comments