@@ -74,16 +74,25 @@ def Tosa_QuantizedInt : AnyTypeOf<[Tosa_QuantizedType<"uint8", [8], 0>,
74
74
Tosa_QuantizedType<"int16", [16, 0], 1>,
75
75
Tosa_QuantizedType<"int32", [32, 0], 1>]>;
76
76
77
+ def Tosa_F8 : AnyTypeOf<[
78
+ F8E4M3FN,
79
+ F8E5M2]>;
80
+
77
81
//===----------------------------------------------------------------------===//
78
82
// Multi-category types.
79
83
//===----------------------------------------------------------------------===//
80
84
def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, AnyFloat],
81
85
"number">;
82
86
87
+ // Add F8 type support to Tosa_AnyNumber
88
+ def Tosa_AnyNumber_Extended : AnyTypeOf<[Tosa_AnyNumber, Tosa_F8],
89
+ "number_extended">;
90
+
83
91
// For weight tensors from tosa::Conv2DOp, tosa::Conv3DOp,
84
92
// tosa::DepthwiseConv2DOp, tosa::TransposeConv2DOp
85
93
def Tosa_Weight : AnyTypeOf<[Tosa_Int4, Tosa_Int8,
86
- Tosa_QuantizedInt, AnyFloat]>;
94
+ Tosa_QuantizedInt, AnyFloat, Tosa_F8]>;
95
+
87
96
88
97
//===----------------------------------------------------------------------===//
89
98
// TOSA Tensor Conformance
@@ -130,9 +139,11 @@ def Tosa_FloatTensor : TosaTensorOf<[AnyFloat]>;
130
139
131
140
// Either ranked or unranked tensor of TOSA supported element types.
132
141
def Tosa_Tensor : TosaTensorOf<[Tosa_AnyNumber]>;
142
+ def Tosa_Tensor_Extended : TosaTensorOf<[Tosa_AnyNumber_Extended]>;
133
143
134
144
// Must be ranked but no further constraints
135
- def Tosa_RankedTensor : TosaRankedTensorOf<[Tosa_AnyNumber]>;
145
+ def Tosa_RankedTensor : RankedTensorOf<[Tosa_AnyNumber]>;
146
+ def Tosa_RankedTensor_Extended : RankedTensorOf<[Tosa_AnyNumber_Extended]>;
136
147
137
148
// Any tensor element type allowed in Tosa ops.
138
149
def Tosa_ElementType : Type<Or<[Tosa_Int.predicate, Tosa_QuantizedInt.predicate,
@@ -145,23 +156,35 @@ class Tosa_TensorOfOrNone<list<Type> allowedTypes, string description = ""> :
145
156
// Tensor types with constrained ranks.
146
157
//===----------------------------------------------------------------------===//
147
158
148
- def Tosa_Rank0Tensor : TosaTensorRankOf<[Tosa_AnyNumber], [0]>;
149
-
159
+ // Scalar tensors: Rank-1 (with only one element)
150
160
def Tosa_ScalarTensor : TosaScalarTensorOf<[Tosa_AnyNumber], [1]>;
161
+ def Tosa_ScalarTensor_Extended : TosaScalarTensorOf<[Tosa_AnyNumber_Extended], [1]>;
151
162
def Tosa_ScalarInt8Tensor : TosaScalarTensorOf<[Tosa_Int8], [1]>;
152
163
153
164
// We include unranked tensors as a supported type for all possible tosa
154
165
// Tensors as unranked does not guarantee invalid. If unranked tensors exist
155
166
// they should be shape propagate used Tosa's shape inference pass and verified
156
167
// to not include any remaining unranked tensors.
157
168
def Tosa_UnrankedTensor : TosaUnrankedTensorOf<[Tosa_AnyNumber]>;
169
+ def Tosa_UnrankedTensorExtended : TosaUnrankedTensorOf<[Tosa_AnyNumber_Extended]>;
158
170
159
171
def Tosa_Tensor1D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [1]>], "1-d tosa-conformant tensor", "::mlir::TensorType">;
160
172
def Tosa_Tensor2D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [2]>], "2-d tosa-conformant tensor", "::mlir::TensorType">;
161
173
def Tosa_Tensor3D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [3]>], "3-d tosa-conformant tensor", "::mlir::TensorType">;
162
174
def Tosa_Tensor4D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [4]>], "4-d tosa-conformant tensor", "::mlir::TensorType">;
163
175
def Tosa_Tensor5D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [5]>], "5-d tosa-conformant tensor", "::mlir::TensorType">;
164
176
177
+ def Tosa_Tensor1D_Extended : AnyTypeOf<[Tosa_UnrankedTensorExtended, TosaTensorRankOf<[Tosa_AnyNumber_Extended], [1]>],
178
+ "1-d tosa-conformant tensor extended", "::mlir::TensorType">;
179
+ def Tosa_Tensor2D_Extended : AnyTypeOf<[Tosa_UnrankedTensorExtended, TosaTensorRankOf<[Tosa_AnyNumber_Extended], [2]>],
180
+ "2-d tosa-conformant tensor extended", "::mlir::TensorType">;
181
+ def Tosa_Tensor3D_Extended : AnyTypeOf<[Tosa_UnrankedTensorExtended, TosaTensorRankOf<[Tosa_AnyNumber_Extended], [3]>],
182
+ "3-d tosa-conformant tensor extended", "::mlir::TensorType">;
183
+ def Tosa_Tensor4D_Extended : AnyTypeOf<[Tosa_UnrankedTensorExtended, TosaTensorRankOf<[Tosa_AnyNumber_Extended], [4]>],
184
+ "4-d tosa-conformant tensor extended", "::mlir::TensorType">;
185
+ def Tosa_Tensor5D_Extended : AnyTypeOf<[Tosa_UnrankedTensorExtended, TosaTensorRankOf<[Tosa_AnyNumber_Extended], [5]>],
186
+ "5-d tosa-conformant tensor extended", "::mlir::TensorType">;
187
+
165
188
// Ranked tensors up to given rank.
166
189
def Tosa_Tensor1Dto4D : AnyTypeOf<[
167
190
Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [1,2,3,4]>]>;
0 commit comments