Skip to content

Commit a58e774

Browse files
[mlir][tosa] Make TOSA MUL's Shift an Input (#121953)
The TOSA-v1.0 specification makes the shift attribute of the MUL (Hammard product) operator an input. Move the `shift` parameter of the MUL operator in the MILR TOSA dialect from an attribute to an input and update any lit tests appropriately. Expand the verifier of the `tosa::MulOp` operation to check the various constraints defined in the TOSA-v1.0 specification. Specifically, ensure that all input operands (excluding the optional shift) are of the same rank. This means that broadcasting tests which previously checked rank-0 tensors would be broadcast are no longer valid and are removed. Signed-off-by: Jack Frankland <[email protected]> Co-authored-by: TatWai Chong <[email protected]>
1 parent 5a8fe9e commit a58e774

File tree

17 files changed

+339
-132
lines changed

17 files changed

+339
-132
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,9 +239,7 @@ class Tosa_ElementwiseOp<string mnemonic, list<Trait> traits = []> :
239239
Tosa_Op<mnemonic, !listconcat(traits, [
240240
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
241241
["inferReturnTypeComponents"]>,
242-
ResultsBroadcastableShape,
243242
TosaElementwiseOperator,
244-
SameOperandsAndResultRank,
245243
Pure])> {
246244
let assemblyFormat =
247245
"operands attr-dict `:` functional-type(operands, results)";
@@ -250,6 +248,8 @@ class Tosa_ElementwiseOp<string mnemonic, list<Trait> traits = []> :
250248
class Tosa_ElementwiseUnaryOp<string mnemonic, list<Trait> traits = []> :
251249
Tosa_ElementwiseOp<mnemonic, !listconcat(traits, [
252250
SameOperandsAndResultShape,
251+
ResultsBroadcastableShape,
252+
SameOperandsAndResultRank,
253253
SameOperandsAndResultElementType])> {}
254254

255255
class Tosa_InferTensorTypeOp<string mnemonic, list<Trait> traits = []>

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h

Lines changed: 33 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#define MLIR_DIALECT_TOSA_IR_TOSAOPS_H
1515

1616
#include "mlir/Bytecode/BytecodeOpInterface.h"
17+
#include "mlir/Dialect/Quant/IR/QuantTypes.h"
1718
#include "mlir/Dialect/Traits.h"
1819
#include "mlir/IR/OpDefinition.h"
1920
#include "mlir/IR/OpImplementation.h"
@@ -53,34 +54,43 @@ class MulOperandsAndResultElementType
5354
: public TraitBase<ConcreteType, MulOperandsAndResultElementType> {
5455
public:
5556
static LogicalResult verifyTrait(Operation *op) {
56-
auto resElemType = getElementTypeOrSelf(op->getResult(0));
57-
58-
// In cases of floating point type, op requires the same element
59-
// type for all operands and result.
60-
if (llvm::isa<FloatType>(resElemType))
61-
return impl::verifySameOperandsAndResultElementType(op);
62-
57+
// Check we have a single result.
58+
if (failed(impl::verifyOneResult(op)))
59+
return failure();
60+
Type resElemType = getElementTypeOrSelf(op->getResult(0));
61+
62+
// Check we have lhs and rhs.
63+
if (failed(impl::verifyAtLeastNOperands(op, 2)))
64+
return failure();
65+
66+
Type lhsElemType = getElementTypeOrSelf(op->getOperand(0));
67+
Type rhsElemType = getElementTypeOrSelf(op->getOperand(1));
68+
69+
// Check that for i32 a shift has been explicitly provided.
70+
if (lhsElemType.isInteger(32) && failed(impl::verifyNOperands(op, 3)))
71+
return failure();
72+
73+
// Verify operands type match (ignoring the shift parameter which will
74+
// always be i8).
75+
if (lhsElemType != rhsElemType)
76+
return op->emitOpError("requires the same element type for all operands");
77+
78+
// Though the spec requires the element type of result to be i32, a more
79+
// relaxed way is provided at dialect level for easier cooperating with
80+
// other dialects.
6381
if (auto resIntType = dyn_cast<IntegerType>(resElemType)) {
64-
IntegerType lhsIntType =
65-
cast<IntegerType>(getElementTypeOrSelf(op->getOperand(0)));
66-
IntegerType rhsIntType =
67-
cast<IntegerType>(getElementTypeOrSelf(op->getOperand(1)));
68-
if (lhsIntType != rhsIntType)
69-
return op->emitOpError(
70-
"requires the same element type for all operands");
71-
72-
// Though the spec requires the element type of result to be i32, a more
73-
// relaxed way is provided at dialect level for easier cooperating with
74-
// other dialects.
82+
auto lhsIntType = cast<IntegerType>(lhsElemType);
7583
if (lhsIntType.getWidth() > resIntType.getWidth())
7684
return op->emitOpError("invalid data type size for operands or result");
77-
78-
return success();
85+
} else {
86+
// In cases of floating point type or quant types, op requires the same
87+
// element type for all operands and result (excluding shift).
88+
if (resElemType != lhsElemType)
89+
return op->emitOpError(
90+
"requires the same element type for all operands and results");
7991
}
8092

81-
// In cases of all other types, op requires the same element
82-
// type for all operands and result.
83-
return impl::verifySameOperandsAndResultElementType(op);
93+
return llvm::success();
8494
}
8595
};
8696

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 67 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -482,7 +482,9 @@ def Tosa_ErfOp : Tosa_ElementwiseUnaryOp<"erf"> {
482482
//===----------------------------------------------------------------------===//
483483
def Tosa_AddOp : Tosa_ElementwiseOp<"add", [
484484
Commutative,
485-
SameOperandsAndResultElementType]> {
485+
ResultsBroadcastableShape,
486+
SameOperandsAndResultElementType,
487+
SameOperandsAndResultRank]> {
486488
let summary = "Elementwise addition operator";
487489

488490
let description = [{
@@ -515,8 +517,10 @@ def Tosa_AddOp : Tosa_ElementwiseOp<"add", [
515517
//===----------------------------------------------------------------------===//
516518
// Operator: arithmetic_right_shift
517519
//===----------------------------------------------------------------------===//
518-
def Tosa_ArithmeticRightShiftOp : Tosa_ElementwiseOp<"arithmetic_right_shift",
519-
[SameOperandsAndResultElementType]> {
520+
def Tosa_ArithmeticRightShiftOp : Tosa_ElementwiseOp<"arithmetic_right_shift", [
521+
ResultsBroadcastableShape,
522+
SameOperandsAndResultElementType,
523+
SameOperandsAndResultRank]> {
520524
let summary = "Elementwise Arithmetic Right Shift";
521525

522526
let description = [{
@@ -540,7 +544,9 @@ def Tosa_ArithmeticRightShiftOp : Tosa_ElementwiseOp<"arithmetic_right_shift",
540544
//===----------------------------------------------------------------------===//
541545
def Tosa_BitwiseAndOp : Tosa_ElementwiseOp<"bitwise_and", [
542546
Commutative,
543-
SameOperandsAndResultElementType]> {
547+
ResultsBroadcastableShape,
548+
SameOperandsAndResultElementType,
549+
SameOperandsAndResultRank]> {
544550
let summary = "Bitwise AND operator";
545551

546552
let description = [{
@@ -563,7 +569,9 @@ def Tosa_BitwiseAndOp : Tosa_ElementwiseOp<"bitwise_and", [
563569
//===----------------------------------------------------------------------===//
564570
def Tosa_BitwiseOrOp : Tosa_ElementwiseOp<"bitwise_or", [
565571
Commutative,
566-
SameOperandsAndResultElementType]> {
572+
ResultsBroadcastableShape,
573+
SameOperandsAndResultElementType,
574+
SameOperandsAndResultRank]> {
567575
let summary = "Bitwise OR operator";
568576

569577
let description = [{
@@ -586,7 +594,9 @@ def Tosa_BitwiseOrOp : Tosa_ElementwiseOp<"bitwise_or", [
586594
//===----------------------------------------------------------------------===//
587595
def Tosa_BitwiseXorOp : Tosa_ElementwiseOp<"bitwise_xor", [
588596
Commutative,
589-
SameOperandsAndResultElementType]> {
597+
ResultsBroadcastableShape,
598+
SameOperandsAndResultElementType,
599+
SameOperandsAndResultRank]> {
590600
let summary = "Bitwise XOR operator";
591601

592602
let description = [{
@@ -607,7 +617,10 @@ def Tosa_BitwiseXorOp : Tosa_ElementwiseOp<"bitwise_xor", [
607617
//===----------------------------------------------------------------------===//
608618
// Operator: int_div
609619
//===----------------------------------------------------------------------===//
610-
def Tosa_IntDivOp : Tosa_ElementwiseOp<"int_div", [SameOperandsAndResultElementType]> {
620+
def Tosa_IntDivOp : Tosa_ElementwiseOp<"int_div", [
621+
ResultsBroadcastableShape,
622+
SameOperandsAndResultRank,
623+
SameOperandsAndResultElementType]> {
611624
let summary = "Integer divide operator";
612625

613626
let description = [{
@@ -632,7 +645,9 @@ def Tosa_IntDivOp : Tosa_ElementwiseOp<"int_div", [SameOperandsAndResultElementT
632645
//===----------------------------------------------------------------------===//
633646
def Tosa_LogicalAndOp : Tosa_ElementwiseOp<"logical_and", [
634647
Commutative,
635-
SameOperandsAndResultElementType]> {
648+
ResultsBroadcastableShape,
649+
SameOperandsAndResultElementType,
650+
SameOperandsAndResultRank]> {
636651
let summary = "Returns the truth value of x AND y element-wise.";
637652

638653
let description = [{
@@ -653,8 +668,10 @@ def Tosa_LogicalAndOp : Tosa_ElementwiseOp<"logical_and", [
653668
//===----------------------------------------------------------------------===//
654669
// Operator: logical_left_shift
655670
//===----------------------------------------------------------------------===//
656-
def Tosa_LogicalLeftShiftOp : Tosa_ElementwiseOp<"logical_left_shift",
657-
[SameOperandsAndResultElementType]> {
671+
def Tosa_LogicalLeftShiftOp : Tosa_ElementwiseOp<"logical_left_shift", [
672+
ResultsBroadcastableShape,
673+
SameOperandsAndResultElementType,
674+
SameOperandsAndResultRank]> {
658675
let summary = "Elementwise Logical Left Shift";
659676

660677
let description = [{
@@ -675,8 +692,10 @@ def Tosa_LogicalLeftShiftOp : Tosa_ElementwiseOp<"logical_left_shift",
675692
//===----------------------------------------------------------------------===//
676693
// Operator: logical_right_shift
677694
//===----------------------------------------------------------------------===//
678-
def Tosa_LogicalRightShiftOp : Tosa_ElementwiseOp<"logical_right_shift",
679-
[SameOperandsAndResultElementType]> {
695+
def Tosa_LogicalRightShiftOp : Tosa_ElementwiseOp<"logical_right_shift", [
696+
ResultsBroadcastableShape,
697+
SameOperandsAndResultElementType,
698+
SameOperandsAndResultRank]> {
680699
let summary = "Elementwise Logical Right Shift";
681700

682701
let description = [{
@@ -699,7 +718,9 @@ def Tosa_LogicalRightShiftOp : Tosa_ElementwiseOp<"logical_right_shift",
699718
//===----------------------------------------------------------------------===//
700719
def Tosa_LogicalOrOp : Tosa_ElementwiseOp<"logical_or", [
701720
Commutative,
702-
SameOperandsAndResultElementType]> {
721+
ResultsBroadcastableShape,
722+
SameOperandsAndResultElementType,
723+
SameOperandsAndResultRank]> {
703724
let summary = "Returns the truth value of x OR y element-wise.";
704725

705726
let description = [{
@@ -722,7 +743,9 @@ def Tosa_LogicalOrOp : Tosa_ElementwiseOp<"logical_or", [
722743
//===----------------------------------------------------------------------===//
723744
def Tosa_LogicalXorOp : Tosa_ElementwiseOp<"logical_xor", [
724745
Commutative,
725-
SameOperandsAndResultElementType]> {
746+
ResultsBroadcastableShape,
747+
SameOperandsAndResultElementType,
748+
SameOperandsAndResultRank]> {
726749
let summary = "Returns the truth value of x XOR y element-wise.";
727750

728751
let description = [{
@@ -745,7 +768,9 @@ def Tosa_LogicalXorOp : Tosa_ElementwiseOp<"logical_xor", [
745768
//===----------------------------------------------------------------------===//
746769
def Tosa_MaximumOp : Tosa_ElementwiseOp<"maximum", [
747770
Commutative,
748-
SameOperandsAndResultElementType]> {
771+
ResultsBroadcastableShape,
772+
SameOperandsAndResultElementType,
773+
SameOperandsAndResultRank]> {
749774
let summary = "Elementwise Maximum";
750775

751776
let description = [{
@@ -769,7 +794,9 @@ def Tosa_MaximumOp : Tosa_ElementwiseOp<"maximum", [
769794
//===----------------------------------------------------------------------===//
770795
def Tosa_MinimumOp : Tosa_ElementwiseOp<"minimum", [
771796
Commutative,
772-
SameOperandsAndResultElementType]> {
797+
ResultsBroadcastableShape,
798+
SameOperandsAndResultElementType,
799+
SameOperandsAndResultRank]> {
773800
let summary = "Elementwise Minimum";
774801

775802
let description = [{
@@ -810,7 +837,7 @@ def Tosa_MulOp : Tosa_ElementwiseOp<"mul", [
810837
let arguments = (ins
811838
Tosa_Tensor:$input1,
812839
Tosa_Tensor:$input2,
813-
I8Attr:$shift
840+
Optional<TosaTensorRankOf<[Tosa_Int8], [1]>>:$shift
814841
);
815842

816843
let results = (outs
@@ -824,7 +851,10 @@ def Tosa_MulOp : Tosa_ElementwiseOp<"mul", [
824851
//===----------------------------------------------------------------------===//
825852
// Operator: pow
826853
//===----------------------------------------------------------------------===//
827-
def Tosa_PowOp : Tosa_ElementwiseOp<"pow", [SameOperandsAndResultElementType]> {
854+
def Tosa_PowOp : Tosa_ElementwiseOp<"pow", [
855+
ResultsBroadcastableShape,
856+
SameOperandsAndResultElementType,
857+
SameOperandsAndResultRank]> {
828858
let summary = "Computes the power of one value to another.";
829859

830860
let description = [{
@@ -845,7 +875,10 @@ def Tosa_PowOp : Tosa_ElementwiseOp<"pow", [SameOperandsAndResultElementType]> {
845875
//===----------------------------------------------------------------------===//
846876
// Operator: sub
847877
//===----------------------------------------------------------------------===//
848-
def Tosa_SubOp : Tosa_ElementwiseOp<"sub", [SameOperandsAndResultElementType]> {
878+
def Tosa_SubOp : Tosa_ElementwiseOp<"sub", [
879+
ResultsBroadcastableShape,
880+
SameOperandsAndResultElementType,
881+
SameOperandsAndResultRank]> {
849882
let summary = "Elementwise subtraction operator";
850883

851884
let description = [{
@@ -1196,7 +1229,9 @@ def Tosa_SinOp : Tosa_ElementwiseUnaryOp<"sin"> {
11961229
//===----------------------------------------------------------------------===//
11971230
// Operator: select
11981231
//===----------------------------------------------------------------------===//
1199-
def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> {
1232+
def Tosa_SelectOp : Tosa_ElementwiseOp<"select", [
1233+
ResultsBroadcastableShape,
1234+
SameOperandsAndResultRank]> {
12001235
let summary = "Elementwise select operator";
12011236

12021237
let description = [{
@@ -1232,7 +1267,9 @@ def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> {
12321267
def Tosa_EqualOp : Tosa_ElementwiseOp<"equal", [
12331268
InferTensorType,
12341269
Commutative,
1235-
SameOperandsElementType]> {
1270+
ResultsBroadcastableShape,
1271+
SameOperandsElementType,
1272+
SameOperandsAndResultRank]> {
12361273
let summary = "Returns the truth value of (x == y) element-wise.";
12371274

12381275
let description = [{
@@ -1260,7 +1297,10 @@ def Tosa_EqualOp : Tosa_ElementwiseOp<"equal", [
12601297
//===----------------------------------------------------------------------===//
12611298
// Operator: greater
12621299
//===----------------------------------------------------------------------===//
1263-
def Tosa_GreaterOp : Tosa_ElementwiseOp<"greater", [SameOperandsElementType]> {
1300+
def Tosa_GreaterOp : Tosa_ElementwiseOp<"greater", [
1301+
ResultsBroadcastableShape,
1302+
SameOperandsElementType,
1303+
SameOperandsAndResultRank]> {
12641304
let summary = "Returns the truth value of (x > y) element-wise.";
12651305

12661306
let description = [{
@@ -1282,8 +1322,11 @@ def Tosa_GreaterOp : Tosa_ElementwiseOp<"greater", [SameOperandsElementType]> {
12821322
//===----------------------------------------------------------------------===//
12831323
// Operator: greater_equal
12841324
//===----------------------------------------------------------------------===//
1285-
def Tosa_GreaterEqualOp : Tosa_ElementwiseOp<"greater_equal",
1286-
[SameOperandsElementType]> {
1325+
def Tosa_GreaterEqualOp : Tosa_ElementwiseOp<"greater_equal", [
1326+
ResultsBroadcastableShape,
1327+
SameOperandsElementType,
1328+
SameOperandsAndResultRank
1329+
]> {
12871330
let summary = "Returns the truth value of (x >= y) element-wise.";
12881331

12891332
let description = [{

0 commit comments

Comments
 (0)