[mlir][tosa] Add acc_type to Tosa-v1.0 Conv Ops #121466


Merged 1 commit into llvm:main on Jan 8, 2025

Conversation

FranklandJack
Contributor

Tosa v1.0 adds accumulator type attributes to the various convolution operations defined in the spec. Update the dialect and any lit tests to include these attributes.
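For illustration, a minimal sketch of the updated assembly with the new attribute (shapes and values here are arbitrary; the form mirrors the lit-test updates in the diff below):

%0 = tosa.conv2d %input, %weights, %bias {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<1x4x4x4xf32>, tensor<8x1x1x4xf32>, tensor<8xf32>) -> tensor<1x4x4x8xf32>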

@llvmbot
Member

llvmbot commented Jan 2, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Jack Frankland (FranklandJack)

Changes

Tosa v1.0 adds accumulator type attributes to the various convolution operations defined in the spec. Update the dialect and any lit tests to include these attributes.


Patch is 109.05 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/121466.diff

15 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td (+6-4)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+5-1)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+69-5)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp (+7-4)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir (+18-18)
  • (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+4-4)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (+5-5)
  • (modified) mlir/test/Dialect/Tosa/level_check.mlir (+36-36)
  • (modified) mlir/test/Dialect/Tosa/ops.mlir (+8-8)
  • (modified) mlir/test/Dialect/Tosa/quant-test.mlir (+3-3)
  • (modified) mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir (+4-4)
  • (modified) mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir (+3-3)
  • (modified) mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir (+9-7)
  • (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+30-30)
  • (modified) mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp (+1-1)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index f5536927dc251d..d3f12c34421b06 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -126,11 +126,12 @@ def Tosa_ConvOpQuantInfoBuilder : OpBuilder<
   (ins "::mlir::Type":$outputType, "::mlir::Value":$input,
        "::mlir::Value":$weight, "::mlir::Value":$bias,
        "::mlir::DenseI64ArrayAttr":$pad, "::mlir::DenseI64ArrayAttr":$stride,
-       "::mlir::DenseI64ArrayAttr":$dilation),
+       "::mlir::DenseI64ArrayAttr":$dilation,
+       "::mlir::TypeAttr":$acc_type),
   [{
     buildConvOpWithQuantInfo($_builder, $_state, outputType,
                              input, weight, bias,
-                             pad, stride, dilation);
+                             pad, stride, dilation, acc_type);
   }]>;
 
 // Handles tosa.transpose_conv2d which has an outpad and output shape attribute.
@@ -139,12 +140,13 @@ def Tosa_TransConvOpQuantInfoBuilder : OpBuilder<
        "::mlir::Value":$weight, "mlir::Value":$bias,
        "::mlir::DenseI64ArrayAttr":$outpad,
        "::mlir::DenseI64ArrayAttr":$stride,
-       "::mlir::DenseI64ArrayAttr":$outputShape),
+       "::mlir::DenseI64ArrayAttr":$outputShape,
+       "::mlir::TypeAttr":$acc_type),
   [{
     buildTransConvOpWithQuantInfo($_builder, $_state, outputType,
                                   input, weight, bias,
                                   outpad, stride,
-                                  outputShape);
+                                  outputShape, acc_type);
   }]>;
 
 // The tosa.fully_connected op has its own builder as it does not have
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 8ae5d3ab417b69..ec7fcd7749848b 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -57,7 +57,7 @@ def Tosa_ArgMaxOp : Tosa_InferShapedTypeOp<"argmax"> {
 // Accumulator types.
 //===----------------------------------------------------------------------===//
 
-def Tosa_AccType : AnyTypeOf<[I<32>, SI<32>, F16, F32]>;
+def Tosa_AccType : AnyTypeOf<[I<32>, I<48>, F16, F32]>;
 
 //===----------------------------------------------------------------------===//
 // Operator: avg_pool2d
@@ -106,6 +106,7 @@ def Tosa_Conv2DOp : Tosa_InferShapedTypeOp<"conv2d"> {
     Tosa_IntArrayAttr4:$pad,
     Tosa_IntArrayAttr2:$stride,
     Tosa_IntArrayAttr2:$dilation,
+    TypeAttrOf<Tosa_AccType>:$acc_type,
     OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
     DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
   );
@@ -135,6 +136,7 @@ def Tosa_Conv3DOp : Tosa_InferShapedTypeOp<"conv3d"> {
     Tosa_IntArrayAttr6:$pad,
     Tosa_IntArrayAttr3:$stride,
     Tosa_IntArrayAttr3:$dilation,
+    TypeAttrOf<Tosa_AccType>:$acc_type,
     OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
     DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
   );
@@ -165,6 +167,7 @@ def Tosa_DepthwiseConv2DOp : Tosa_InferShapedTypeOp<"depthwise_conv2d"> {
     Tosa_IntArrayAttr4:$pad,
     Tosa_IntArrayAttr2:$stride,
     Tosa_IntArrayAttr2:$dilation,
+    TypeAttrOf<Tosa_AccType>:$acc_type,
     OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
     DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
   );
@@ -348,6 +351,7 @@ def Tosa_TransposeConv2DOp : Tosa_InferShapedTypeOp<"transpose_conv2d"> {
     Tosa_IntArrayAttr4:$out_pad,
     Tosa_IntArrayAttr2:$stride,
     Tosa_IntArrayAttr4:$out_shape,
+    TypeAttrOf<Tosa_AccType>:$acc_type,
     OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
     DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
   );
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 631d3c48f2df02..cb844117508a45 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -271,6 +271,55 @@ LogicalResult tosa::ConstOp::verify() {
   return success();
 }
 
+template <typename T>
+static LogicalResult verifyConvOpModes(T op) {
+  auto inputEType =
+      llvm::cast<ShapedType>(op.getInput().getType()).getElementType();
+
+  if (auto quantType =
+          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType))
+    inputEType = quantType.getStorageType();
+
+  auto accType = op.getAccType();
+  if (inputEType.isInteger(8) && !accType.isInteger(32))
+    return op.emitOpError("accumulator type for i8 tensor is not i32");
+
+  if (inputEType.isInteger(16) && !accType.isInteger(48))
+    return op.emitOpError("accumulator type for i16 tensor is not i48");
+
+  if ((inputEType.isFloat8E5M2() || inputEType.isFloat8E4M3FN()) &&
+      !accType.isF16())
+    return op.emitOpError("accumulator type for f8 tensor is not f16");
+
+  if (inputEType.isF16() && !(accType.isF16() || accType.isF32()))
+    return op.emitOpError("accumulator type for f16 tensor is not f16/f32");
+
+  if (inputEType.isBF16() && !accType.isF32())
+    return op.emitOpError("accumulator type for bf16 tensor is not f32");
+
+  if (inputEType.isF32() && !accType.isF32())
+    return op.emitOpError("accumulator type for f32 tensor is not f32");
+
+  auto resultEType =
+      llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
+
+  if (auto quantType =
+          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(resultEType))
+    resultEType = quantType.getStorageType();
+
+  // check allowed input/result element types combinations
+  if ((inputEType.isInteger(8) && resultEType.isInteger(32)) ||
+      (inputEType.isInteger(16) && resultEType.isInteger(48)) ||
+      (inputEType.isFloat8E5M2() && resultEType.isF16()) ||
+      (inputEType.isFloat8E4M3FN() && resultEType.isF16()) ||
+      (inputEType.isF16() && resultEType.isF16()) ||
+      (inputEType.isBF16() && resultEType.isBF16()) ||
+      (inputEType.isF32() && resultEType.isF32()))
+    return success();
+
+  return op.emitOpError("input/output element types are incompatible.");
+}
+
 LogicalResult tosa::ArgMaxOp::verify() {
   // Ensure output is of 32-bit integer
   const auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();
@@ -368,12 +417,14 @@ static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                                      Type outputType, Value input, Value weight,
                                      Value bias, DenseI64ArrayAttr pad,
                                      DenseI64ArrayAttr stride,
-                                     DenseI64ArrayAttr dilation) {
+                                     DenseI64ArrayAttr dilation,
+                                     TypeAttr accType) {
 
   result.addOperands({input, weight, bias});
   result.addAttribute("pad", pad);
   result.addAttribute("stride", stride);
   result.addAttribute("dilation", dilation);
+  result.addAttribute("acc_type", accType);
 
   auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
   if (quantAttr) {
@@ -390,11 +441,12 @@ static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
 static void buildTransConvOpWithQuantInfo(
     OpBuilder &builder, OperationState &result, Type outputType, Value input,
     Value weight, Value bias, DenseI64ArrayAttr outpad,
-    DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape) {
+    DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape, TypeAttr accType) {
   result.addOperands({input, weight, bias});
   result.addAttribute("out_pad", outpad);
   result.addAttribute("stride", stride);
   result.addAttribute("out_shape", outputShape);
+  result.addAttribute("acc_type", accType);
   auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);
 
   if (quantAttr) {
@@ -1595,7 +1647,11 @@ LogicalResult Conv2DOp::inferReturnTypeComponents(
   return success();
 }
 
-LogicalResult Conv2DOp::verify() { return verifyConvOp(*this); }
+LogicalResult Conv2DOp::verify() {
+  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
+    return failure();
+  return success();
+}
 
 LogicalResult Conv3DOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
@@ -1667,7 +1723,11 @@ LogicalResult Conv3DOp::inferReturnTypeComponents(
   return success();
 }
 
-LogicalResult Conv3DOp::verify() { return verifyConvOp(*this); }
+LogicalResult Conv3DOp::verify() {
+  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
+    return failure();
+  return success();
+}
 
 LogicalResult AvgPool2dOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
@@ -1762,7 +1822,11 @@ LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
   return success();
 }
 
-LogicalResult DepthwiseConv2DOp::verify() { return verifyConvOp(*this); }
+LogicalResult DepthwiseConv2DOp::verify() {
+  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
+    return failure();
+  return success();
+}
 
 LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
index 0779cdb9667a1a..275c2f80f7f49d 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -75,13 +75,15 @@ class TransposeConvNonStridedConverter
           loc, resultTy, input, reverse2, bias,
           rewriter.getDenseI64ArrayAttr(convPad),
           rewriter.getDenseI64ArrayAttr(stride),
-          rewriter.getDenseI64ArrayAttr({1, 1}), *op.getQuantizationInfo());
+          rewriter.getDenseI64ArrayAttr({1, 1}),
+          /* acc_type = */ op.getAccType(), *op.getQuantizationInfo());
     } else {
       conv2d = rewriter.create<tosa::Conv2DOp>(
           loc, resultTy, input, reverse2, bias,
           rewriter.getDenseI64ArrayAttr(convPad),
           rewriter.getDenseI64ArrayAttr(stride),
-          rewriter.getDenseI64ArrayAttr({1, 1}));
+          rewriter.getDenseI64ArrayAttr({1, 1}),
+          /* acc_type = */ op.getAccTypeAttr());
     }
 
     rewriter.replaceOp(op, conv2d);
@@ -238,7 +240,7 @@ class TransposeConvStridedConverter
                    /*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}),
                    /*stride=*/rewriter.getDenseI64ArrayAttr({1, 1}),
                    /*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}),
-                   *op.getQuantizationInfo())
+                   /* acc_type = */ op.getAccType(), *op.getQuantizationInfo())
                    .getResult();
     } else {
       conv2d = CreateOpAndInferShape<tosa::Conv2DOp>(
@@ -246,7 +248,8 @@ class TransposeConvStridedConverter
                    weight, zeroBias,
                    /*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}),
                    /*stride=*/rewriter.getDenseI64ArrayAttr({1, 1}),
-                   /*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}))
+                   /*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}),
+                   /* acc_type = */ op.getAccTypeAttr())
                    .getResult();
     }
 
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index bfdc72ee07e97f..453a8610e7169a 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -510,7 +510,7 @@ func.func @avg_pool_dyn(%arg0: tensor<?x6x34x62xf32>) -> (tensor<?x5x33x62xf32>)
 func.func @conv2d_scalar_bias_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<1xf32>) -> () {
   // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x45x40x28xf32>
   // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%[[INIT]] : tensor<1x45x40x28xf32>) {
-  %0 = tosa.conv2d %input, %weights, %bias {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<1xf32>) -> tensor<1x45x40x28xf32>
+  %0 = tosa.conv2d %input, %weights, %bias {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<1xf32>) -> tensor<1x45x40x28xf32>
   return
 }
 
@@ -531,7 +531,7 @@ func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi
   // CHECK: linalg.conv_2d_nhwc_fhwc_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, %c0_i32, %c0_i32_0 : tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, i32, i32) outs(%[[BROADCAST]] : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
   // HWCF: linalg.conv_2d_nhwc_hwcf_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[TRANSPOSE]], %c0_i32, %c0_i32_0 : tensor<1x49x42x27xi8>, tensor<1x1x27x28xi8>, i32, i32) outs(%{{[a-zA-Z0-9_]*}} : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
 
-  %0 = tosa.conv2d %input, %weights, %bias {dilation = array<i64: 2, 1>, pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = 0, weight_zp = 0>, stride = array<i64: 1, 1>} : (tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, tensor<28xi8>) -> tensor<1x45x40x28xi32>
+  %0 = tosa.conv2d %input, %weights, %bias {acc_type = i32, dilation = array<i64: 2, 1>, pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = 0, weight_zp = 0>, stride = array<i64: 1, 1>} : (tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, tensor<28xi8>) -> tensor<1x45x40x28xi32>
   return
 }
 
@@ -552,7 +552,7 @@ func.func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27
   // CHECK: linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>) outs(%1 : tensor<1x45x40x28xf32>) -> tensor<1x45x40x28xf32>
 
   // HWCF: linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[TRANSPOSE]] : tensor<1x49x42x27xf32>, tensor<3x3x27x28xf32>) outs(%{{[a-zA-Z0-9_]*}} : tensor<1x45x40x28xf32>
-  %0 = tosa.conv2d %input, %weights, %bias {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>) -> tensor<1x45x40x28xf32>
+  %0 = tosa.conv2d %input, %weights, %bias {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>) -> tensor<1x45x40x28xf32>
   return
 }
 
@@ -571,7 +571,7 @@ func.func @conv2d_dyn(%input: tensor<?x49x42x27xf32>, %weights: tensor<28x3x3x27
   // CHECK:   linalg.yield %[[IN]] : f32
   // CHECK: } -> tensor<?x45x40x28xf32>
   // CHECK: %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<?x49x42x27xf32>, tensor<28x3x3x27xf32>) outs(%[[BROADCAST]] : tensor<?x45x40x28xf32>) -> tensor<?x45x40x28xf32>
-  %0 = tosa.conv2d %input, %weights, %bias {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<?x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>) -> tensor<?x45x40x28xf32>
+  %0 = tosa.conv2d %input, %weights, %bias {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<?x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>) -> tensor<?x45x40x28xf32>
   return
 }
 
@@ -627,7 +627,7 @@ func.func @conv2d_dyn_w_h(%input: tensor<1x?x?x27xf32>, %weights: tensor<28x3x3x
   // CHECK: } -> tensor<1x?x?x28xf32>
   // CHECK: linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x?x?x27xf32>, tensor<28x3x3x27xf32>) outs(%17 : tensor<1x?x?x28xf32>) -> tensor<1x?x?x28xf32>
 
-  %0 = tosa.conv2d %input, %weights, %bias {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x?x?x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>) -> tensor<1x?x?x28xf32>
+  %0 = tosa.conv2d %input, %weights, %bias {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x?x?x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>) -> tensor<1x?x?x28xf32>
   return
 }
 
@@ -650,7 +650,7 @@ func.func @conv2d_dyn_output(%input: tensor<2x6x5x4xf32>, %weights: tensor<4x3x3
   //   linalg.yield %[[ADD]] : f32
   // } -> tensor<?x4x3x4xf32>
 
-  %0 = tosa.conv2d %input, %weights, %bias {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<2x6x5x4xf32    >, tensor<4x3x3x4xf32>, tensor<4xf32>) -> tensor<?x4x3x4xf32>
+  %0 = tosa.conv2d %input, %weights, %bias {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<2x6x5x4xf32    >, tensor<4x3x3x4xf32>, tensor<4xf32>) -> tensor<?x4x3x4xf32>
   return
 }
 
@@ -662,7 +662,7 @@ func.func @conv2d_padded_f32(%input: tensor<1x47x40x28xf32>, %weights: tensor<28
   // CHECK: tensor.pad %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0]
   // CHECK:   tensor.yield %[[C0]]
   // CHECK: linalg.conv_2d_nhwc_fhwc
-  %0 = tosa.conv2d %input, %weights, %bias {pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x47x40x28xf32>, tensor<28x3x3x28xf32>, tensor<28xf32>) -> tensor<1x45x40x28xf32>
+  %0 = tosa.conv2d %input, %weights, %bias {acc_type = f32, pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x47x40x28xf32>, tensor<28x3x3x28xf32>, tensor<28xf32>) -> tensor<1x45x40x28xf32>
   return
 }
 
@@ -674,7 +674,7 @@ func.func @conv2d_quant(%arg0 : tensor<1x12x12x1xi8>, %arg1 : tensor<1024x3x3x1x
   // CHECK: tensor.pad %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0]
   // CHECK:   tensor.yield %[[C22]]
   // CHECK: linalg.conv_2d_nhwc_fhwc_q
-  %0 = tosa.conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 1, 1, 1, 1>, quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, stride = array<i64: 1, 1>} : (tensor<1x12x12x1xi8>, tensor<1024x3x3x1xi8>, tensor<1024xi32>) -> tensor<1x12x12x1024xi32>
+  %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 1, 1, 1, 1>, quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, stride = array<i64: 1, 1>} : (tensor<1x12x12x1xi8>, tensor<1024x3x3x1xi8>, tensor<1024xi32>) -> tensor<1x12x12x1024xi32>
   return
 }
 
@@ -696,7 +696,7 @@ func.func @depthwise_conv(%arg0 : tensor<1x7x5x3xf32>, %arg1 : tensor<3x1x3x11xf
   // CHECK:   [[ADD:%.+]] = arith.addf %[[ARG3]], %[[ARG4]] : f32
   // CHECK:   linalg.yield [[ADD]] : f32
   // CHECK: } -> tensor<1x5x5x33xf32>
-  %2 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 { pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> } : (tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>, tensor<33xf32>)  -> tensor<1x5x5x33xf32>
+  %2 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 { acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> } : (tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>, tensor<33xf32>)  -> tensor<1x5x5x33xf32>
   return
 }
 
@@ -712,7 +712,7 @@ func.func @depthwise_conv_scalar_bias(%arg0 : tensor<1x7x5x3xf32>, %arg1 : tenso
   // CHECK:   [[ADD:%.+]] = arith.addf %[[ARG3]], %[[ARG4]] : f32
   // CHECK:   linalg.yield [[ADD]] : f32
   // CHECK: } -> tensor<1x5x5x33xf32>
-  %2 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 { pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> } : (tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>, tensor<1xf32>)  -> tensor<1x5x5x33xf32>
+  %2 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> } : (tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>, tensor<1xf32>)  -> tensor<1x5x5x33xf32>
   return
 }
 
@@ -736,7 +736,7...
[truncated]
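For quick reference, verifyConvOpModes in the patch pairs input element types with accumulator types as follows:

  • i8 → i32
  • i16 → i48
  • f8 (E5M2 or E4M3FN) → f16
  • f16 → f16 or f32
  • bf16 → f32
  • f32 → f32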

@GeorgeARM
Contributor

GeorgeARM commented Jan 2, 2025

As per @lhutton1's comment, you probably need to perform some of the checks in the TosaValidation pass, as per the specification requirements.

Contributor

@GeorgeARM GeorgeARM left a comment

Looking at the existing code and this patch, I see a discrepancy in how the verifiers and validations have been laid out. I know we suggested moving the check into the TosaValidationPass, but this is not how AvgPool2d is handled at the moment.
So could you please move the checks back to the verifier, and decide what needs to move into the validation pass altogether in subsequent patches, when profile intersection takes place?

@FranklandJack
Contributor Author

> Looking at the existing code and this patch, I see a discrepancy in how the verifiers and validations have been laid out. I know we suggested moving the check into the TosaValidationPass, but this is not how AvgPool2d is handled at the moment. So could you please move the checks back to the verifier, and decide what needs to move into the validation pass altogether in subsequent patches, when profile intersection takes place?

I've reverted to the original commit.

Contributor

@lhutton1 lhutton1 left a comment

Apologies, I hadn't noticed AvgPool2D. Happy to split moving the acc_type checks from the verifier to the validation pass into a separate change, especially since it doesn't impose any new restrictions (relative to v0.80 expectations).

LGTM, with an additional comment to consider and a small nit: does it need some negative tests to exercise the verifier?
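On the negative-test nit, a sketch of what such a case could look like, assuming it lands in mlir/test/Dialect/Tosa/invalid.mlir; the expected-error text mirrors the diagnostic emitted by verifyConvOpModes in the diff, while the function name, shapes, and bias type are illustrative:

// -----
// Hypothetical negative test: an i8 input must accumulate in i32.
func.func @test_conv2d_acc_type_invalid(%arg0: tensor<1x4x4x4xi8>, %arg1: tensor<8x1x1x4xi8>, %arg2: tensor<8xi32>) -> tensor<1x4x4x8xi32> {
  // expected-error@+1 {{accumulator type for i8 tensor is not i32}}
  %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>, quantization_info = #tosa.conv_quant<input_zp = 0, weight_zp = 0>} : (tensor<1x4x4x4xi8>, tensor<8x1x1x4xi8>, tensor<8xi32>) -> tensor<1x4x4x8xi32>
  return %0 : tensor<1x4x4x8xi32>
}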

resultEType = quantType.getStorageType();

// check allowed input/result element types combinations
if ((inputEType.isInteger(8) && resultEType.isInteger(32)) ||
Contributor

We might want to be more careful about adding these checks, since they impose a new restriction on the allowed data types of the operator. Are they necessary for this PR? If not, perhaps we can remove them for now and re-introduce in the validation pass in a separate change to minimise the possible impact?
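For context, an example of a combination this check would reject, since an i8 input with an f32 result appears in none of the allowed pairs (shapes arbitrary):

// Would be rejected with: input/output element types are incompatible.
%0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<1x4x4x4xi8>, tensor<8x1x1x4xi8>, tensor<8xi8>) -> tensor<1x4x4x8xf32>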

Tosa v1.0 adds accumulator type attributes to the various convolution
operations defined in the spec. Update the dialect and any lit tests to
include these attributes.

Signed-off-by: Tai Ly <[email protected]>
Contributor

@lhutton1 lhutton1 left a comment

Thanks for the updates, LGTM! It would be good for @sjarus and @eric-k256 to check as well, since this breaks existing legalizations to TOSA.

@GeorgeARM GeorgeARM merged commit 360a03c into llvm:main Jan 8, 2025
8 checks passed