Skip to content

Commit 26a7f42

Browse files
tatwaichong authored and rsuderman committed
[mlir][tosa] Add accumulator type attribute to TOSA dialect
Tosa supports fp16 and fp32 accumulator types for fp16 input, but there is no way to tell, for computational operators, whether the accumulator should be fp16 or fp32 from the input type alone. Add this new attribute to specify the type. Set to fp32 by default for now. When fp16 is supported, the accumulator type can be selected based on the trade-off between performance and accuracy. Reviewed By: jpienaar Differential Revision: https://reviews.llvm.org/D146317
1 parent 050c09f commit 26a7f42

File tree

7 files changed

+38
-18
lines changed

7 files changed

+38
-18
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,10 +174,10 @@ def Tosa_MatMulOpQuantInfoBuilder : OpBuilder<
174174
def Tosa_AvgPool2dOpQuantInfoBuilder : OpBuilder<
175175
(ins "::mlir::Type":$outputType, "::mlir::Value":$input,
176176
"::mlir::DenseI64ArrayAttr":$kernel, "::mlir::DenseI64ArrayAttr":$stride,
177-
"::mlir::DenseI64ArrayAttr":$pad),
177+
"::mlir::DenseI64ArrayAttr":$pad, "::mlir::TypeAttr":$acc_type),
178178
[{
179179
buildAvgPool2dOpWithQuantInfo($_builder, $_state, outputType,
180-
input, kernel, stride, pad);
180+
input, kernel, stride, pad, acc_type);
181181
}]>;
182182

183183
// This builder is called on single-parameter unary operators that have a scale

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,12 @@ def Tosa_ArgMaxOp : Tosa_Op<"argmax", [
5353
);
5454
}
5555

56+
//===----------------------------------------------------------------------===//
57+
// Accumulator types.
58+
//===----------------------------------------------------------------------===//
59+
60+
def Tosa_AccType : AnyTypeOf<[I<32>, SI<32>, F16, F32]>;
61+
5662
//===----------------------------------------------------------------------===//
5763
// Operator: avg_pool2d
5864
//===----------------------------------------------------------------------===//
@@ -74,6 +80,7 @@ def Tosa_AvgPool2dOp : Tosa_Op<"avg_pool2d", [
7480
Tosa_IntArrayAttr2:$kernel,
7581
Tosa_IntArrayAttr2:$stride,
7682
Tosa_IntArrayAttr4:$pad,
83+
TypeAttrOf<Tosa_AccType>:$acc_type,
7784
OptionalAttr<Tosa_UnaryOpQuantizationAttr>:$quantization_info
7885
);
7986

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -751,8 +751,7 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
751751
ShapedType resultTy = cast<ShapedType>(op.getType());
752752
Type resultETy = cast<ShapedType>(op.getType()).getElementType();
753753

754-
Type accETy =
755-
isa<IntegerType>(inElementTy) ? rewriter.getI32Type() : inElementTy;
754+
Type accETy = op.getAccType();
756755
ShapedType accTy = resultTy.clone(accETy);
757756

758757
auto dynamicDimsOr =

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,17 @@ LogicalResult tosa::AvgPool2dOp::verify() {
153153
llvm::dyn_cast<mlir::quant::UniformQuantizedType>(resultETy))
154154
resultETy = quantType.getStorageType();
155155

156+
auto accType = getAccType();
157+
if (inputETy.isa<IntegerType>() && !accType.isInteger(32))
158+
return emitOpError("accumulator type for integer tensor is not i32");
159+
160+
if ((inputETy.isBF16() || inputETy.isF16()) &&
161+
!(accType.isF16() || accType.isF32()))
162+
return emitOpError("accumulator type for f16/bf16 tensor is not f16/f32");
163+
164+
if (inputETy.isF32() && !accType.isF32())
165+
return emitOpError("accumulator type for f32 tensor is not f32");
166+
156167
if (inputETy.isF32() && resultETy.isF32())
157168
return success();
158169
if (inputETy.isInteger(8) && resultETy.isInteger(8))
@@ -268,13 +279,16 @@ static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
268279
/// Both the tosa.avg_pool2d and unary ops use the same UnaruOpQuantizationAttr
269280
/// but avg_pool operator has its own builder as it has additional parameters
270281
/// not part of the unary ops.
271-
static void buildAvgPool2dOpWithQuantInfo(
272-
OpBuilder &builder, OperationState &result, Type outputType, Value input,
273-
DenseArrayAttr kernel, DenseArrayAttr stride, DenseArrayAttr pad) {
282+
static void
283+
buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
284+
Type outputType, Value input,
285+
DenseArrayAttr kernel, DenseArrayAttr stride,
286+
DenseArrayAttr pad, TypeAttr acc_type) {
274287
result.addOperands(input);
275288
result.addAttribute("kernel", kernel);
276289
result.addAttribute("stride", stride);
277290
result.addAttribute("pad", pad);
291+
result.addAttribute("acc_type", acc_type);
278292
auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
279293
if (quantAttr)
280294
result.addAttribute("quantization_info", quantAttr);

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ func.func @avg_pool_f32(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x5x33x62xf32>)
286286
// CHECK: %[[FLT:.+]] = arith.sitofp %[[CAST]]
287287
// CHECK: %[[DIV:.+]] = arith.divf %[[IN]], %[[FLT]]
288288
// CHECK: linalg.yield %[[DIV]]
289-
%0 = "tosa.avg_pool2d"(%arg0) {pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xf32>) -> (tensor<1x5x33x62xf32>)
289+
%0 = "tosa.avg_pool2d"(%arg0) {acc_type = f32, pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xf32>) -> (tensor<1x5x33x62xf32>)
290290
return %0 : tensor<1x5x33x62xf32>
291291
}
292292

@@ -329,7 +329,7 @@ func.func @avg_pool_i8(%arg0: tensor<1x6x34x62xi8>) -> (tensor<1x5x33x62xi8>) {
329329
// CHECK: %[[CLAMP:.+]] = arith.select %[[CMP]], %[[CMAX]], %[[SEL]]
330330
// CHECK: %[[TRUNC:.+]] = arith.trunci %[[CLAMP]]
331331
// CHECK: linalg.yield %[[TRUNC]]
332-
%0 = "tosa.avg_pool2d"(%arg0) {pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xi8>) -> (tensor<1x5x33x62xi8>)
332+
%0 = "tosa.avg_pool2d"(%arg0) {acc_type = i32, pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<1x6x34x62xi8>) -> (tensor<1x5x33x62xi8>)
333333
return %0 : tensor<1x5x33x62xi8>
334334
}
335335

@@ -352,7 +352,7 @@ func.func @avg_pool_dyn(%arg0: tensor<?x6x34x62xf32>) -> (tensor<?x5x33x62xf32>)
352352
// CHECK-SAME: outs(%[[FILL]] : tensor<?x5x33x62xf32>) -> tensor<?x5x33x62xf32>
353353
// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x5x33x62xf32>
354354
// CHECK: %[[GENERIC:.+]] = linalg.generic
355-
%0 = "tosa.avg_pool2d"(%arg0) {pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<?x6x34x62xf32>) -> (tensor<?x5x33x62xf32>)
355+
%0 = "tosa.avg_pool2d"(%arg0) {acc_type = f32, pad = array<i64: 1, 1, 1, 1>, kernel = array<i64: 4, 4>, stride = array<i64: 1, 1>} : (tensor<?x6x34x62xf32>) -> (tensor<?x5x33x62xf32>)
356356
return %0 : tensor<?x5x33x62xf32>
357357
}
358358

mlir/test/Dialect/Tosa/ops.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,28 +12,28 @@ func.func @test_argmax(%arg0: tensor<14x19xf32>) -> tensor<14xi32> {
1212
// -----
1313
// CHECK-LABEL: avg_pool2d_f32
1414
func.func @test_avg_pool2d_f32(%arg0: tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32> {
15-
%0 = "tosa.avg_pool2d"(%arg0) {kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32>
15+
%0 = "tosa.avg_pool2d"(%arg0) {acc_type = f32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xf32>) -> tensor<1x7x7x9xf32>
1616
return %0 : tensor<1x7x7x9xf32>
1717
}
1818

1919
// -----
2020
// CHECK-LABEL: avg_pool2d_i8
2121
func.func @test_avg_pool2d_i8(%arg0: tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi8> {
22-
%0 = "tosa.avg_pool2d"(%arg0) {kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi8>
22+
%0 = "tosa.avg_pool2d"(%arg0) {acc_type = i32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi8>
2323
return %0 : tensor<1x7x7x9xi8>
2424
}
2525

2626
// -----
2727
// CHECK-LABEL: avg_pool2d_i16
2828
func.func @test_avg_pool2d_i16(%arg0: tensor<1x7x7x9xi16>) -> tensor<1x7x7x9xi16> {
29-
%0 = "tosa.avg_pool2d"(%arg0) {kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xi16>) -> tensor<1x7x7x9xi16>
29+
%0 = "tosa.avg_pool2d"(%arg0) {acc_type = i32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xi16>) -> tensor<1x7x7x9xi16>
3030
return %0 : tensor<1x7x7x9xi16>
3131
}
3232

3333
// -----
3434
// CHECK-LABEL: avg_pool2d_q8
3535
func.func @test_avg_pool2d_q8(%arg0: tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>) -> tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>> {
36-
%0 = "tosa.avg_pool2d"(%arg0) {kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>) -> tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>
36+
%0 = "tosa.avg_pool2d"(%arg0) {acc_type = i32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 1, 0, 1>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>) -> tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>
3737
return %0 : tensor<1x7x7x9x!quant.uniform<i8:f32, 0.01>>
3838
}
3939

mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -659,7 +659,7 @@ func.func @scatter_minimum_static(%arg0 : tensor<?x4x?xi32>, %arg1 : tensor<3x?x
659659
// CHECK-LABEL: @test_pool_static
660660
func.func @test_pool_static(%arg0: tensor<3x5x6x7xf32>) {
661661
// CHECK: -> tensor<3x2x4x7xf32>
662-
%0 = "tosa.avg_pool2d"(%arg0) {kernel = array<i64: 4, 3>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<3x5x6x7xf32>) -> tensor<?x?x?x?xf32>
662+
%0 = "tosa.avg_pool2d"(%arg0) {acc_type = f32, kernel = array<i64: 4, 3>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<3x5x6x7xf32>) -> tensor<?x?x?x?xf32>
663663

664664
// CHECK: -> tensor<3x2x4x7xf32>
665665
%1 = "tosa.max_pool2d"(%arg0) {kernel = array<i64: 4, 3>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<3x5x6x7xf32>) -> tensor<?x?x?x?xf32>
@@ -689,7 +689,7 @@ func.func @conv2d_dynamic_input(%input: tensor<?x?x?x?xf32>, %weights: tensor<5x
689689
// CHECK-LABEL: @test_pool_dynamic_input
690690
func.func @test_pool_dynamic_input(%arg0: tensor<?x?x?x?xf32>) {
691691
// CHECK: -> tensor<?x?x?x?xf32>
692-
%0 = "tosa.avg_pool2d"(%arg0) {kernel = array<i64: 4, 3>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
692+
%0 = "tosa.avg_pool2d"(%arg0) {acc_type = f32, kernel = array<i64: 4, 3>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
693693

694694
// CHECK: -> tensor<?x?x?x?xf32>
695695
%1 = "tosa.max_pool2d"(%arg0) {kernel = array<i64: 4, 3>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
@@ -701,7 +701,7 @@ func.func @test_pool_dynamic_input(%arg0: tensor<?x?x?x?xf32>) {
701701
// CHECK-LABEL: @test_pool_padded
702702
func.func @test_pool_padded(%arg0: tensor<3x5x6x7xf32>) {
703703
// CHECK: -> tensor<3x5x11x7xf32>
704-
%0 = "tosa.avg_pool2d"(%arg0) {kernel = array<i64: 4, 3>, pad = array<i64: 1, 2, 3, 4>, stride = array<i64: 1, 1>} : (tensor<3x5x6x7xf32>) -> tensor<?x?x?x?xf32>
704+
%0 = "tosa.avg_pool2d"(%arg0) {acc_type = f32, kernel = array<i64: 4, 3>, pad = array<i64: 1, 2, 3, 4>, stride = array<i64: 1, 1>} : (tensor<3x5x6x7xf32>) -> tensor<?x?x?x?xf32>
705705

706706
// CHECK: -> tensor<3x5x11x7xf32>
707707
%1 = "tosa.max_pool2d"(%arg0) {kernel = array<i64: 4, 3>, pad = array<i64: 1, 2, 3, 4>, stride = array<i64: 1, 1>} : (tensor<3x5x6x7xf32>) -> tensor<?x?x?x?xf32>
@@ -731,7 +731,7 @@ func.func @conv2d_dynamic_bias(%input: tensor<2x8x9x3xf32>, %weights: tensor<5x3
731731
// CHECK-LABEL: @test_pool_stride
732732
func.func @test_pool_stride(%arg0: tensor<3x11x12x7xf32>) {
733733
// CHECK: -> tensor<3x4x4x7xf32>
734-
%0 = "tosa.avg_pool2d"(%arg0) {kernel = array<i64: 4, 3>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 3>} : (tensor<3x11x12x7xf32>) -> tensor<?x?x?x?xf32>
734+
%0 = "tosa.avg_pool2d"(%arg0) {acc_type = f32, kernel = array<i64: 4, 3>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 3>} : (tensor<3x11x12x7xf32>) -> tensor<?x?x?x?xf32>
735735

736736
// CHECK: -> tensor<3x4x4x7xf32>
737737
%1 = "tosa.max_pool2d"(%arg0) {kernel = array<i64: 4, 3>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 3>} : (tensor<3x11x12x7xf32>) -> tensor<?x?x?x?xf32>

0 commit comments

Comments (0)