Skip to content

Commit a54efdb

Browse files
authored
[MLIR][TOSA] add additional verification to TOSA (#108133)
---------- Motivation: ---------- Spec conformance. Allows assumptions to be made in TOSA code. ------------ Changes Made: ------------ Add full permutation tensor verification to tosa.TRANSPOSE. Priorly would not verify that permuted values were between 0 - (rank - 1). Update tosa.TRANSPOSE perms data type to be strictly i32. Verify input/output shapes for tosa.TRANSPOSE. Add verifier to tosa.CONST, with consideration for quantization. Fix TOSA conformance of tensor type to disallow dimensions with size 0 for ranked tensors, per spec. This is not the same as rank 0 tensors. Here is an example of a disallowed tensor: tensor<3x0xi32>. Naturally, this means that the number of elements in a TOSA tensor will always be greater than 0. Signed-off-by: Arteen Abrishami <[email protected]>
1 parent e55d6f5 commit a54efdb

File tree

12 files changed

+301
-161
lines changed

12 files changed

+301
-161
lines changed

mlir/include/mlir/Dialect/Tosa/IR/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ add_mlir_doc(TosaOps TosaOps Dialects/ -gen-op-doc)
33
add_mlir_interface(TosaInterfaces)
44

55
set(LLVM_TARGET_DEFINITIONS TosaOps.td)
6-
mlir_tablegen(TosaAttributes.h.inc -gen-attrdef-decls)
7-
mlir_tablegen(TosaAttributes.cpp.inc -gen-attrdef-defs)
6+
mlir_tablegen(TosaAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=tosa)
7+
mlir_tablegen(TosaAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=tosa)
88
add_public_tablegen_target(MLIRTosaAttributesIncGen)
99

1010
set(LLVM_TARGET_DEFINITIONS TosaDialectBytecode.td)

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 27 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,6 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> {
7373

7474
let arguments = (ins
7575
Tosa_Tensor4D:$input,
76-
7776
Tosa_IntArrayAttr2:$kernel,
7877
Tosa_IntArrayAttr2:$stride,
7978
Tosa_IntArrayAttr4:$pad,
@@ -102,9 +101,8 @@ def Tosa_Conv2DOp : Tosa_InferShapedTypeOp<"conv2d"> {
102101

103102
let arguments = (ins
104103
Tosa_Tensor4D:$input,
105-
4DTensorOf<[Tosa_Weight]>:$weight,
104+
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
106105
Tosa_Tensor1D:$bias,
107-
108106
Tosa_IntArrayAttr4:$pad,
109107
Tosa_IntArrayAttr2:$stride,
110108
Tosa_IntArrayAttr2:$dilation,
@@ -132,9 +130,8 @@ def Tosa_Conv3DOp : Tosa_InferShapedTypeOp<"conv3d"> {
132130

133131
let arguments = (ins
134132
Tosa_Tensor5D:$input,
135-
TensorRankOf<[Tosa_Weight], [5]>:$weight,
133+
TosaTensorRankOf<[Tosa_Weight], [5]>:$weight,
136134
Tosa_Tensor1D:$bias,
137-
138135
Tosa_IntArrayAttr6:$pad,
139136
Tosa_IntArrayAttr3:$stride,
140137
Tosa_IntArrayAttr3:$dilation,
@@ -163,9 +160,8 @@ def Tosa_DepthwiseConv2DOp : Tosa_InferShapedTypeOp<"depthwise_conv2d"> {
163160

164161
let arguments = (ins
165162
Tosa_Tensor4D:$input,
166-
4DTensorOf<[Tosa_Weight]>:$weight,
163+
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
167164
Tosa_Tensor1D:$bias,
168-
169165
Tosa_IntArrayAttr4:$pad,
170166
Tosa_IntArrayAttr2:$stride,
171167
Tosa_IntArrayAttr2:$dilation,
@@ -232,7 +228,7 @@ def Tosa_FullyConnectedOp : Tosa_InferShapedTypeOp<"fully_connected"> {
232228

233229
let arguments = (ins
234230
Tosa_Tensor2D:$input,
235-
2DTensorOf<[Tosa_Weight]>:$weight,
231+
TosaTensorRankOf<[Tosa_Weight], [2]>:$weight,
236232
Tosa_Tensor1D:$bias,
237233
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info
238234
);
@@ -347,9 +343,8 @@ def Tosa_TransposeConv2DOp : Tosa_InferShapedTypeOp<"transpose_conv2d"> {
347343

348344
let arguments = (ins
349345
Tosa_Tensor4D:$input,
350-
4DTensorOf<[Tosa_Weight]>:$filter,
346+
TosaTensorRankOf<[Tosa_Weight], [4]>:$filter,
351347
Tosa_Tensor1D:$bias,
352-
353348
Tosa_IntArrayAttr4:$out_pad,
354349
Tosa_IntArrayAttr2:$stride,
355350
Tosa_IntArrayAttrUpto4:$out_shape,
@@ -641,12 +636,12 @@ def Tosa_LogicalAndOp : Tosa_ElementwiseOp<"logical_and", [
641636
}];
642637

643638
let arguments = (ins
644-
I1Tensor:$input1,
645-
I1Tensor:$input2
639+
Tosa_I1Tensor:$input1,
640+
Tosa_I1Tensor:$input2
646641
);
647642

648643
let results = (outs
649-
I1Tensor:$z
644+
Tosa_I1Tensor:$z
650645
);
651646
}
652647

@@ -708,12 +703,12 @@ def Tosa_LogicalOrOp : Tosa_ElementwiseOp<"logical_or", [
708703
}];
709704

710705
let arguments = (ins
711-
I1Tensor:$input1,
712-
I1Tensor:$input2
706+
Tosa_I1Tensor:$input1,
707+
Tosa_I1Tensor:$input2
713708
);
714709

715710
let results = (outs
716-
I1Tensor:$z
711+
Tosa_I1Tensor:$z
717712
);
718713
}
719714

@@ -731,12 +726,12 @@ def Tosa_LogicalXorOp : Tosa_ElementwiseOp<"logical_xor", [
731726
}];
732727

733728
let arguments = (ins
734-
I1Tensor:$input1,
735-
I1Tensor:$input2
729+
Tosa_I1Tensor:$input1,
730+
Tosa_I1Tensor:$input2
736731
);
737732

738733
let results = (outs
739-
I1Tensor:$z
734+
Tosa_I1Tensor:$z
740735
);
741736
}
742737

@@ -1085,11 +1080,11 @@ def Tosa_LogicalNotOp : Tosa_ElementwiseOp<"logical_not",
10851080
}];
10861081

10871082
let arguments = (ins
1088-
I1Tensor:$input1
1083+
Tosa_I1Tensor:$input1
10891084
);
10901085

10911086
let results = (outs
1092-
I1Tensor:$output
1087+
Tosa_I1Tensor:$output
10931088
);
10941089
}
10951090

@@ -1208,7 +1203,7 @@ def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> {
12081203
}];
12091204

12101205
let arguments = (ins
1211-
I1Tensor:$pred,
1206+
Tosa_I1Tensor:$pred,
12121207
Tosa_Tensor:$on_true,
12131208
Tosa_Tensor:$on_false
12141209
);
@@ -1249,7 +1244,7 @@ def Tosa_EqualOp : Tosa_ElementwiseOp<"equal", [
12491244
);
12501245

12511246
let results = (outs
1252-
I1Tensor:$output
1247+
Tosa_I1Tensor:$output
12531248
);
12541249

12551250
let extraClassDeclaration = [{
@@ -1277,7 +1272,7 @@ def Tosa_GreaterOp : Tosa_ElementwiseOp<"greater", [SameOperandsElementType]> {
12771272
);
12781273

12791274
let results = (outs
1280-
I1Tensor:$output
1275+
Tosa_I1Tensor:$output
12811276
);
12821277

12831278
let hasFolder = 1;
@@ -1300,7 +1295,7 @@ def Tosa_GreaterEqualOp : Tosa_ElementwiseOp<"greater_equal",
13001295
);
13011296

13021297
let results = (outs
1303-
I1Tensor:$output
1298+
Tosa_I1Tensor:$output
13041299
);
13051300

13061301
let hasFolder = 1;
@@ -1721,15 +1716,15 @@ def Tosa_TransposeOp : Tosa_InferShapedTypeOp<"transpose",
17211716

17221717
let arguments = (ins
17231718
Tosa_Tensor:$input1,
1724-
Tosa_Int32Or64Tensor:$perms
1719+
Tosa_Int32Tensor:$perms
17251720
);
17261721

17271722
let results = (
17281723
outs Tosa_Tensor:$output
17291724
);
17301725

17311726
let extraClassDeclaration = [{
1732-
LogicalResult getConstantPerms(llvm::SmallVector<int64_t> &perms);
1727+
LogicalResult getConstantPerms(llvm::SmallVector<int32_t> &perms);
17331728
}];
17341729

17351730
let hasCanonicalizer = 1;
@@ -1755,7 +1750,7 @@ def Tosa_GatherOp : Tosa_InferShapedTypeOp<"gather"> {
17551750

17561751
let arguments = (ins
17571752
Tosa_Tensor3D:$values,
1758-
2DTensorOf<[Tosa_Int32]>:$indices
1753+
TosaTensorRankOf<[Tosa_Int32], [2]>:$indices
17591754
);
17601755

17611756
let results = (outs
@@ -1776,7 +1771,7 @@ def Tosa_ScatterOp : Tosa_InferShapedTypeOp<"scatter"> {
17761771

17771772
let arguments = (ins
17781773
Tosa_Tensor3D:$values_in,
1779-
2DTensorOf<[Tosa_Int32]>:$indices,
1774+
TosaTensorRankOf<[Tosa_Int32], [2]>:$indices,
17801775
Tosa_Tensor3D:$input
17811776
);
17821777

@@ -1947,10 +1942,11 @@ def Tosa_ConstOp : Tosa_Op<"const", [ConstantLike, Pure,
19471942
);
19481943

19491944
let results = (outs
1950-
TensorOf<[AnyTypeOf<[Tosa_AnyNumber]>]>:$output
1945+
TosaTensorOf<[AnyTypeOf<[Tosa_AnyNumber]>]>:$output
19511946
);
19521947

19531948
let hasFolder = 1;
1949+
let hasVerifier = 1;
19541950
}
19551951

19561952
//===----------------------------------------------------------------------===//
@@ -2054,7 +2050,7 @@ def Tosa_IfOp : Tosa_Op<"cond_if",
20542050
}];
20552051

20562052
let arguments = (ins
2057-
I1Tensor:$cond,
2053+
Tosa_I1Tensor:$cond,
20582054
Variadic<Tosa_Tensor>:$inputs
20592055
);
20602056

mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td

Lines changed: 43 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -82,58 +82,83 @@ def Tosa_AnyNumber : AnyTypeOf<[Tosa_Int, Tosa_QuantizedInt, AnyFloat],
8282
def Tosa_Weight : AnyTypeOf<[Tosa_Int4, Tosa_Int8,
8383
Tosa_QuantizedInt, AnyFloat]>;
8484

85+
//===----------------------------------------------------------------------===//
86+
// TOSA Tensor Conformance
87+
//===----------------------------------------------------------------------===//
88+
89+
def HasNo0Dimensions : And<[
90+
IsRankedTensorTypePred,
91+
CPred<"::llvm::all_of(::llvm::cast<::mlir::RankedTensorType>($_self).getShape(), [](auto v) { return v != 0; })">]>;
92+
93+
class TosaTensorOf<
94+
list<Type> allowedTypes, string summary = "tosa-conformant tensor">
95+
: TensorOf<allowedTypes, [Or<[HasNo0Dimensions, IsUnrankedTensorTypePred]>], summary>;
96+
97+
class TosaRankedTensorOf<
98+
list<Type> allowedTypes, list<Pred> preds = [], string summary = "tosa-conformant ranked tensor">
99+
: RankedTensorOf<allowedTypes, !listconcat([HasNo0Dimensions], preds), summary>;
100+
101+
class TosaUnrankedTensorOf<list<Type> allowedTypes, list<Pred> preds = [], string summary = "tosa-conformant unranked tensor">
102+
: UnrankedTensorOf<allowedTypes, preds, summary>;
103+
104+
class TosaTensorRankOf<list<Type> allowedTypes, list<int> ranks>
105+
: TosaRankedTensorOf<allowedTypes,
106+
[HasAnyRankOfPred<ranks>],
107+
!interleave(!foreach(rank, ranks, rank # "D"), "/") # " tensor">;
108+
85109
//===----------------------------------------------------------------------===//
86110
// Tensor types
87111
//===----------------------------------------------------------------------===//
88112

89-
def Tosa_Int32Tensor : TensorOf<[Tosa_Int32]>;
90-
def Tosa_Int32Or64Tensor : TensorOf<[Tosa_Int32Or64]>;
113+
def Tosa_I1Tensor : TosaTensorOf<[I1]>;
114+
def Tosa_Int32Tensor : TosaTensorOf<[Tosa_Int32]>;
115+
def Tosa_Int32Or64Tensor :TosaTensorOf<[Tosa_Int32Or64]>;
91116

92-
def Tosa_FloatTensor : TensorOf<[AnyFloat]>;
117+
def Tosa_FloatTensor : TosaTensorOf<[AnyFloat]>;
93118

94119
// Either ranked or unranked tensor of TOSA supported element types.
95-
def Tosa_Tensor : TensorOf<[Tosa_AnyNumber]>;
120+
def Tosa_Tensor : TosaTensorOf<[Tosa_AnyNumber]>;
96121

97122
// Must be ranked but no further constraints
98-
def Tosa_RankedTensor : RankedTensorOf<[Tosa_AnyNumber]>;
123+
def Tosa_RankedTensor : TosaRankedTensorOf<[Tosa_AnyNumber]>;
99124

100125
// Any tensor element type allowed in Tosa ops.
101126
def Tosa_ElementType : Type<Or<[Tosa_Int.predicate, Tosa_QuantizedInt.predicate,
102127
AnyFloat.predicate]>, "tosa.dtype">;
103128

104129
class Tosa_TensorOfOrNone<list<Type> allowedTypes, string description = ""> :
105-
AnyTypeOf<[TensorOf<allowedTypes>, NoneType], description>;
130+
AnyTypeOf<[TosaTensorOf<allowedTypes>, NoneType], description>;
106131

107132
//===----------------------------------------------------------------------===//
108133
// Tensor types with constrained ranks.
109134
//===----------------------------------------------------------------------===//
110135

111136
// Rank-0 (scalar) tensor
112-
def Tosa_ScalarTensor : TensorRankOf<[Tosa_AnyNumber], [0]>;
137+
def Tosa_ScalarTensor : TosaTensorRankOf<[Tosa_AnyNumber], [0]>;
113138

114139
// We include unranked tensors as a supported type for all possible tosa
115140
// Tensors as unranked does not guarantee invalid. If unranked tensors exist
116141
// they should be shape propagate used Tosa's shape inference pass and verified
117142
// to not include any remaining unranked tensors.
118-
def Tosa_UnrankedTensor : UnrankedTensorOf<[Tosa_AnyNumber]>;
143+
def Tosa_UnrankedTensor : TosaUnrankedTensorOf<[Tosa_AnyNumber]>;
119144

120-
def Tosa_Tensor1D : AnyTypeOf<[Tosa_UnrankedTensor, 1DTensorOf<[Tosa_AnyNumber]>], "1-d tensor", "::mlir::TensorType">;
121-
def Tosa_Tensor2D : AnyTypeOf<[Tosa_UnrankedTensor, 2DTensorOf<[Tosa_AnyNumber]>], "2-d tensor", "::mlir::TensorType">;
122-
def Tosa_Tensor3D : AnyTypeOf<[Tosa_UnrankedTensor, 3DTensorOf<[Tosa_AnyNumber]>], "3-d tensor", "::mlir::TensorType">;
123-
def Tosa_Tensor4D : AnyTypeOf<[Tosa_UnrankedTensor, 4DTensorOf<[Tosa_AnyNumber]>], "4-d tensor", "::mlir::TensorType">;
124-
def Tosa_Tensor5D : AnyTypeOf<[Tosa_UnrankedTensor, TensorRankOf<[Tosa_AnyNumber], [5]>], "5-d tensor", "::mlir::TensorType">;
145+
def Tosa_Tensor1D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [1]>], "1-d tosa-conformant tensor", "::mlir::TensorType">;
146+
def Tosa_Tensor2D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [2]>], "2-d tosa-conformant tensor", "::mlir::TensorType">;
147+
def Tosa_Tensor3D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [3]>], "3-d tosa-conformant tensor", "::mlir::TensorType">;
148+
def Tosa_Tensor4D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [4]>], "4-d tosa-conformant tensor", "::mlir::TensorType">;
149+
def Tosa_Tensor5D : AnyTypeOf<[Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [5]>], "5-d tosa-conformant tensor", "::mlir::TensorType">;
125150

126151
// Ranked tensors up to given rank.
127152
def Tosa_Tensor1Dto4D : AnyTypeOf<[
128-
Tosa_UnrankedTensor, TensorRankOf<[Tosa_AnyNumber], [1,2,3,4]>]>;
153+
Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [1,2,3,4]>]>;
129154
def Tosa_Tensor1Dto6D : AnyTypeOf<[
130-
Tosa_UnrankedTensor, TensorRankOf<[Tosa_AnyNumber], [1,2,3,4,5,6]>]>;
155+
Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [1,2,3,4,5,6]>]>;
131156

132157
def Tosa_TensorUpto4D : AnyTypeOf<[
133-
Tosa_UnrankedTensor, TensorRankOf<[Tosa_AnyNumber], [0,1,2,3,4]>]>;
158+
Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_AnyNumber], [0,1,2,3,4]>]>;
134159

135160
def Tosa_Int32TensorUpto4D : AnyTypeOf<[
136-
Tosa_UnrankedTensor, TensorRankOf<[Tosa_Int32], [0,1,2,3,4]>]>;
161+
Tosa_UnrankedTensor, TosaTensorRankOf<[Tosa_Int32], [0,1,2,3,4]>]>;
137162

138163
//===----------------------------------------------------------------------===//
139164
// Generic scalar, vector, or tensor of a particular type.
@@ -142,7 +167,7 @@ def Tosa_Int32TensorUpto4D : AnyTypeOf<[
142167
class Tosa_TypeLike<list<Type> types, string description = ""> : TypeConstraint<Or<[
143168
AnyTypeOf<types>.predicate,
144169
VectorOf<types>.predicate,
145-
TensorOf<types>.predicate]>,
170+
TosaTensorOf<types>.predicate]>,
146171
description>;
147172

148173
def Tosa_IntLike : Tosa_TypeLike<[Tosa_Int], "signless-integer-like">;

mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,19 @@ TosaOp CreateOpAndInferShape(PatternRewriter &rewriter, Location loc,
216216
return CreateOpAndInferShape<TosaOp>(builder, resultTy, args...);
217217
}
218218

219+
// Apply an int32_t permutation to some input, that should be of the same
220+
// size as perms. Perms should contain some permutation of 0 - perms.size() - 1.
221+
template <typename T>
222+
SmallVector<T> applyTOSAPermutation(ArrayRef<T> input,
223+
ArrayRef<int32_t> perms) {
224+
SmallVector<T> permuted;
225+
size_t N = input.size();
226+
permuted.resize_for_overwrite(N);
227+
for (size_t i = 0; i < N; i++)
228+
permuted[i] = input[perms[i]];
229+
return permuted;
230+
}
231+
219232
} // namespace tosa
220233
} // namespace mlir
221234

0 commit comments

Comments
 (0)