Skip to content

Commit a5f0b23

Browse files
tatwaichongeric-k256
authored andcommitted
[mlir][tosa][fix] Add proper type checking trait for tosa mul
when operating integer type tensors, tosa elementwise multiplication requires the element type of result to be a 32-bit integer rather than the same type as inputs. Change-Id: Ifd3d7ebd879be5c6b2c8e23aa6d7ef41f39c6d41 Reviewed By: mgehre-amd Differential Revision: https://reviews.llvm.org/D154988
1 parent e4ad1f9 commit a5f0b23

File tree

5 files changed

+65
-6
lines changed

5 files changed

+65
-6
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515

1616
#include "mlir/Bytecode/BytecodeOpInterface.h"
1717
#include "mlir/Dialect/Traits.h"
18+
#include "mlir/IR/OpDefinition.h"
1819
#include "mlir/IR/OpImplementation.h"
20+
#include "mlir/IR/TypeUtilities.h"
1921
#include "mlir/Interfaces/InferTypeOpInterface.h"
2022
#include "mlir/Interfaces/LoopLikeInterface.h"
2123
#include "mlir/Interfaces/SideEffectInterfaces.h"
@@ -35,6 +37,49 @@ namespace tosa {
3537
#include "mlir/Dialect/Tosa/IR/TosaInterfaces.h.inc"
3638

3739
} // namespace tosa
40+
41+
namespace OpTrait {
42+
namespace tosa {
43+
44+
// This trait verifies if the element type amoung operands and result
45+
// of multiplication match tosa specification.
46+
template <typename ConcreteType>
47+
class MulOperandsAndResultElementType
48+
: public TraitBase<ConcreteType, MulOperandsAndResultElementType> {
49+
public:
50+
static LogicalResult verifyTrait(Operation *op) {
51+
auto resElemType = getElementTypeOrSelf(op->getResult(0));
52+
53+
// In cases of floating point type, op requires the same element
54+
// type for all operands and result.
55+
if (llvm::isa<FloatType>(resElemType))
56+
return impl::verifySameOperandsAndResultElementType(op);
57+
58+
if (auto resIntType = resElemType.dyn_cast<IntegerType>()) {
59+
IntegerType lhsIntType =
60+
getElementTypeOrSelf(op->getOperand(0)).cast<IntegerType>();
61+
IntegerType rhsIntType =
62+
getElementTypeOrSelf(op->getOperand(1)).cast<IntegerType>();
63+
if (lhsIntType != rhsIntType)
64+
return op->emitOpError(
65+
"requires the same element type for all operands");
66+
67+
// Though the spec requires the element type of result to be i32, a more
68+
// relaxed way is provided at dialect level for easier cooperating with
69+
// other dialects.
70+
if (lhsIntType.getWidth() > resIntType.getWidth())
71+
return op->emitOpError("invalid data type size for operands or result");
72+
73+
return success();
74+
}
75+
76+
return failure();
77+
}
78+
};
79+
80+
} // namespace tosa
81+
} // namespace OpTrait
82+
3883
} // namespace mlir
3984

4085
#define GET_ATTRDEF_CLASSES

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -747,12 +747,17 @@ def Tosa_MinimumOp : Tosa_ElementwiseOp<"minimum", [
747747
);
748748
}
749749

750+
def MulOperandsAndResultElementType :
751+
NativeOpTrait<"MulOperandsAndResultElementType"> {
752+
let cppNamespace = "mlir::OpTrait::tosa";
753+
}
754+
750755
//===----------------------------------------------------------------------===//
751756
// Operator: mul
752757
//===----------------------------------------------------------------------===//
753758
def Tosa_MulOp : Tosa_ElementwiseOp<"mul", [
754759
Commutative,
755-
SameOperandsAndResultElementType]> {
760+
MulOperandsAndResultElementType]> {
756761
let summary = "Multiplication operator";
757762

758763
let description = [{

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -538,8 +538,10 @@ func.func @test_simple_f16(%arg0: tensor<1xf16>) -> () {
538538
// CHECK-LABEL: @test_simple_i16
539539
func.func @test_simple_i16(%arg0: tensor<1xi16>) -> () {
540540
// CHECK: linalg.generic
541+
// CHECK: arith.extsi
542+
// CHECK: arith.extsi
541543
// CHECK: arith.muli
542-
%0 = "tosa.mul"(%arg0, %arg0) {shift = 0 : i32} : (tensor<1xi16>, tensor<1xi16>) -> tensor<1xi16>
544+
%0 = "tosa.mul"(%arg0, %arg0) {shift = 0 : i32} : (tensor<1xi16>, tensor<1xi16>) -> tensor<1xi32>
543545

544546
return
545547
}

mlir/test/Dialect/Tosa/constant-op-fold.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -294,13 +294,13 @@ func.func @fold_mul_one_lhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
294294
// -----
295295

296296
// CHECK-LABEL: @fold_mul_splat_i8
297-
func.func @fold_mul_splat_i8() -> tensor<10xi8> {
297+
func.func @fold_mul_splat_i8() -> tensor<10xi32> {
298298
%one = "tosa.const"() {value = dense<17> : tensor<10xi8>} : () -> tensor<10xi8>
299299
%two = "tosa.const"() {value = dense<32> : tensor<10xi8>} : () -> tensor<10xi8>
300-
%mul = "tosa.mul"(%one, %two) {shift = 3 : i32} : (tensor<10xi8>, tensor<10xi8>) -> tensor<10xi8>
301-
// CHECK: %[[THREE:.+]] = "tosa.const"() <{value = dense<68> : tensor<10xi8>}
300+
%mul = "tosa.mul"(%one, %two) {shift = 3 : i32} : (tensor<10xi8>, tensor<10xi8>) -> tensor<10xi32>
301+
// CHECK: %[[THREE:.+]] = "tosa.const"() <{value = dense<68> : tensor<10xi32>}
302302
// CHECK: return %[[THREE]]
303-
return %mul : tensor<10xi8>
303+
return %mul : tensor<10xi32>
304304
}
305305

306306
// -----

mlir/test/Dialect/Tosa/ops.mlir

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,13 @@ func.func @test_mul(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> te
229229
return %0 : tensor<13x21x3xf32>
230230
}
231231

232+
// -----
233+
// CHECK-LABEL: mul
234+
func.func @test_mul_relaxed_result_type(%arg0: tensor<13x21x3xi16>, %arg1: tensor<13x1x3xi16>) -> tensor<13x21x3xi16> {
235+
%0 = "tosa.mul"(%arg0, %arg1) { shift = 1 : i32 } : (tensor<13x21x3xi16>, tensor<13x1x3xi16>) -> tensor<13x21x3xi16>
236+
return %0 : tensor<13x21x3xi16>
237+
}
238+
232239
// -----
233240
// CHECK-LABEL: pow
234241
func.func @test_pow(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x1xf32>) -> tensor<13x21x3xf32> {

0 commit comments

Comments
 (0)