[mlir][tosa] Add acc_type to Tosa-v1.0 Conv Ops #121466
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg

Author: Jack Frankland (FranklandJack)

Changes

Tosa v1.0 adds accumulator type attributes to the various convolution operations defined in the spec. Update the dialect and any lit tests to include these attributes.

Patch is 109.05 KiB, truncated to 20.00 KiB below; full version: https://github.com/llvm/llvm-project/pull/121466.diff

15 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index f5536927dc251d..d3f12c34421b06 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -126,11 +126,12 @@ def Tosa_ConvOpQuantInfoBuilder : OpBuilder<
(ins "::mlir::Type":$outputType, "::mlir::Value":$input,
"::mlir::Value":$weight, "::mlir::Value":$bias,
"::mlir::DenseI64ArrayAttr":$pad, "::mlir::DenseI64ArrayAttr":$stride,
- "::mlir::DenseI64ArrayAttr":$dilation),
+ "::mlir::DenseI64ArrayAttr":$dilation,
+ "::mlir::TypeAttr":$acc_type),
[{
buildConvOpWithQuantInfo($_builder, $_state, outputType,
input, weight, bias,
- pad, stride, dilation);
+ pad, stride, dilation, acc_type);
}]>;
// Handles tosa.transpose_conv2d which has an outpad and output shape attribute.
@@ -139,12 +140,13 @@ def Tosa_TransConvOpQuantInfoBuilder : OpBuilder<
"::mlir::Value":$weight, "mlir::Value":$bias,
"::mlir::DenseI64ArrayAttr":$outpad,
"::mlir::DenseI64ArrayAttr":$stride,
- "::mlir::DenseI64ArrayAttr":$outputShape),
+ "::mlir::DenseI64ArrayAttr":$outputShape,
+ "::mlir::TypeAttr":$acc_type),
[{
buildTransConvOpWithQuantInfo($_builder, $_state, outputType,
input, weight, bias,
outpad, stride,
- outputShape);
+ outputShape, acc_type);
}]>;
// The tosa.fully_connected op has its own builder as it does not have
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 8ae5d3ab417b69..ec7fcd7749848b 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -57,7 +57,7 @@ def Tosa_ArgMaxOp : Tosa_InferShapedTypeOp<"argmax"> {
// Accumulator types.
//===----------------------------------------------------------------------===//
-def Tosa_AccType : AnyTypeOf<[I<32>, SI<32>, F16, F32]>;
+def Tosa_AccType : AnyTypeOf<[I<32>, I<48>, F16, F32]>;
//===----------------------------------------------------------------------===//
// Operator: avg_pool2d
@@ -106,6 +106,7 @@ def Tosa_Conv2DOp : Tosa_InferShapedTypeOp<"conv2d"> {
Tosa_IntArrayAttr4:$pad,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr2:$dilation,
+ TypeAttrOf<Tosa_AccType>:$acc_type,
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
);
@@ -135,6 +136,7 @@ def Tosa_Conv3DOp : Tosa_InferShapedTypeOp<"conv3d"> {
Tosa_IntArrayAttr6:$pad,
Tosa_IntArrayAttr3:$stride,
Tosa_IntArrayAttr3:$dilation,
+ TypeAttrOf<Tosa_AccType>:$acc_type,
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
);
@@ -165,6 +167,7 @@ def Tosa_DepthwiseConv2DOp : Tosa_InferShapedTypeOp<"depthwise_conv2d"> {
Tosa_IntArrayAttr4:$pad,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr2:$dilation,
+ TypeAttrOf<Tosa_AccType>:$acc_type,
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
);
@@ -348,6 +351,7 @@ def Tosa_TransposeConv2DOp : Tosa_InferShapedTypeOp<"transpose_conv2d"> {
Tosa_IntArrayAttr4:$out_pad,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr4:$out_shape,
+ TypeAttrOf<Tosa_AccType>:$acc_type,
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
);
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 631d3c48f2df02..cb844117508a45 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -271,6 +271,55 @@ LogicalResult tosa::ConstOp::verify() {
return success();
}
+template <typename T>
+static LogicalResult verifyConvOpModes(T op) {
+ auto inputEType =
+ llvm::cast<ShapedType>(op.getInput().getType()).getElementType();
+
+ if (auto quantType =
+ llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType))
+ inputEType = quantType.getStorageType();
+
+ auto accType = op.getAccType();
+ if (inputEType.isInteger(8) && !accType.isInteger(32))
+ return op.emitOpError("accumulator type for i8 tensor is not i32");
+
+ if (inputEType.isInteger(16) && !accType.isInteger(48))
+ return op.emitOpError("accumulator type for i16 tensor is not i48");
+
+ if ((inputEType.isFloat8E5M2() || inputEType.isFloat8E4M3FN()) &&
+ !accType.isF16())
+ return op.emitOpError("accumulator type for f8 tensor is not f16");
+
+ if (inputEType.isF16() && !(accType.isF16() || accType.isF32()))
+ return op.emitOpError("accumulator type for f16 tensor is not f16/f32");
+
+ if (inputEType.isBF16() && !accType.isF32())
+ return op.emitOpError("accumulator type for bf16 tensor is not f32");
+
+ if (inputEType.isF32() && !accType.isF32())
+ return op.emitOpError("accumulator type for f32 tensor is not f32");
+
+ auto resultEType =
+ llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
+
+ if (auto quantType =
+ llvm::dyn_cast<mlir::quant::UniformQuantizedType>(resultEType))
+ resultEType = quantType.getStorageType();
+
+ // check allowed input/result element types combinations
+ if ((inputEType.isInteger(8) && resultEType.isInteger(32)) ||
+ (inputEType.isInteger(16) && resultEType.isInteger(48)) ||
+ (inputEType.isFloat8E5M2() && resultEType.isF16()) ||
+ (inputEType.isFloat8E4M3FN() && resultEType.isF16()) ||
+ (inputEType.isF16() && resultEType.isF16()) ||
+ (inputEType.isBF16() && resultEType.isBF16()) ||
+ (inputEType.isF32() && resultEType.isF32()))
+ return success();
+
+ return op.emitOpError("input/output element types are incompatible.");
+}
+
LogicalResult tosa::ArgMaxOp::verify() {
// Ensure output is of 32-bit integer
const auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();
@@ -368,12 +417,14 @@ static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
Type outputType, Value input, Value weight,
Value bias, DenseI64ArrayAttr pad,
DenseI64ArrayAttr stride,
- DenseI64ArrayAttr dilation) {
+ DenseI64ArrayAttr dilation,
+ TypeAttr accType) {
result.addOperands({input, weight, bias});
result.addAttribute("pad", pad);
result.addAttribute("stride", stride);
result.addAttribute("dilation", dilation);
+ result.addAttribute("acc_type", accType);
auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
if (quantAttr) {
@@ -390,11 +441,12 @@ static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
static void buildTransConvOpWithQuantInfo(
OpBuilder &builder, OperationState &result, Type outputType, Value input,
Value weight, Value bias, DenseI64ArrayAttr outpad,
- DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape) {
+ DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape, TypeAttr accType) {
result.addOperands({input, weight, bias});
result.addAttribute("out_pad", outpad);
result.addAttribute("stride", stride);
result.addAttribute("out_shape", outputShape);
+ result.addAttribute("acc_type", accType);
auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);
if (quantAttr) {
@@ -1595,7 +1647,11 @@ LogicalResult Conv2DOp::inferReturnTypeComponents(
return success();
}
-LogicalResult Conv2DOp::verify() { return verifyConvOp(*this); }
+LogicalResult Conv2DOp::verify() {
+ if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
+ return failure();
+ return success();
+}
LogicalResult Conv3DOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
@@ -1667,7 +1723,11 @@ LogicalResult Conv3DOp::inferReturnTypeComponents(
return success();
}
-LogicalResult Conv3DOp::verify() { return verifyConvOp(*this); }
+LogicalResult Conv3DOp::verify() {
+ if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
+ return failure();
+ return success();
+}
LogicalResult AvgPool2dOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
@@ -1762,7 +1822,11 @@ LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
return success();
}
-LogicalResult DepthwiseConv2DOp::verify() { return verifyConvOp(*this); }
+LogicalResult DepthwiseConv2DOp::verify() {
+ if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
+ return failure();
+ return success();
+}
LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
index 0779cdb9667a1a..275c2f80f7f49d 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -75,13 +75,15 @@ class TransposeConvNonStridedConverter
loc, resultTy, input, reverse2, bias,
rewriter.getDenseI64ArrayAttr(convPad),
rewriter.getDenseI64ArrayAttr(stride),
- rewriter.getDenseI64ArrayAttr({1, 1}), *op.getQuantizationInfo());
+ rewriter.getDenseI64ArrayAttr({1, 1}),
+ /* acc_type = */ op.getAccType(), *op.getQuantizationInfo());
} else {
conv2d = rewriter.create<tosa::Conv2DOp>(
loc, resultTy, input, reverse2, bias,
rewriter.getDenseI64ArrayAttr(convPad),
rewriter.getDenseI64ArrayAttr(stride),
- rewriter.getDenseI64ArrayAttr({1, 1}));
+ rewriter.getDenseI64ArrayAttr({1, 1}),
+ /* acc_type = */ op.getAccTypeAttr());
}
rewriter.replaceOp(op, conv2d);
@@ -238,7 +240,7 @@ class TransposeConvStridedConverter
/*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}),
/*stride=*/rewriter.getDenseI64ArrayAttr({1, 1}),
/*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}),
- *op.getQuantizationInfo())
+ /* acc_type = */ op.getAccType(), *op.getQuantizationInfo())
.getResult();
} else {
conv2d = CreateOpAndInferShape<tosa::Conv2DOp>(
@@ -246,7 +248,8 @@ class TransposeConvStridedConverter
weight, zeroBias,
/*pad=*/rewriter.getDenseI64ArrayAttr({0, 0, 0, 0}),
/*stride=*/rewriter.getDenseI64ArrayAttr({1, 1}),
- /*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}))
+ /*dilation=*/rewriter.getDenseI64ArrayAttr({1, 1}),
+ /* acc_type = */ op.getAccTypeAttr())
.getResult();
}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index bfdc72ee07e97f..453a8610e7169a 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -510,7 +510,7 @@ func.func @avg_pool_dyn(%arg0: tensor<?x6x34x62xf32>) -> (tensor<?x5x33x62xf32>)
func.func @conv2d_scalar_bias_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<1xf32>) -> () {
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x45x40x28xf32>
// CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<1xf32>) outs(%[[INIT]] : tensor<1x45x40x28xf32>) {
- %0 = tosa.conv2d %input, %weights, %bias {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<1xf32>) -> tensor<1x45x40x28xf32>
+ %0 = tosa.conv2d %input, %weights, %bias {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<1xf32>) -> tensor<1x45x40x28xf32>
return
}
@@ -531,7 +531,7 @@ func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi
// CHECK: linalg.conv_2d_nhwc_fhwc_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, %c0_i32, %c0_i32_0 : tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, i32, i32) outs(%[[BROADCAST]] : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
// HWCF: linalg.conv_2d_nhwc_hwcf_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[TRANSPOSE]], %c0_i32, %c0_i32_0 : tensor<1x49x42x27xi8>, tensor<1x1x27x28xi8>, i32, i32) outs(%{{[a-zA-Z0-9_]*}} : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
- %0 = tosa.conv2d %input, %weights, %bias {dilation = array<i64: 2, 1>, pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = 0, weight_zp = 0>, stride = array<i64: 1, 1>} : (tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, tensor<28xi8>) -> tensor<1x45x40x28xi32>
+ %0 = tosa.conv2d %input, %weights, %bias {acc_type = i32, dilation = array<i64: 2, 1>, pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = 0, weight_zp = 0>, stride = array<i64: 1, 1>} : (tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, tensor<28xi8>) -> tensor<1x45x40x28xi32>
return
}
@@ -552,7 +552,7 @@ func.func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27
// CHECK: linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>) outs(%1 : tensor<1x45x40x28xf32>) -> tensor<1x45x40x28xf32>
// HWCF: linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[TRANSPOSE]] : tensor<1x49x42x27xf32>, tensor<3x3x27x28xf32>) outs(%{{[a-zA-Z0-9_]*}} : tensor<1x45x40x28xf32>
- %0 = tosa.conv2d %input, %weights, %bias {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>) -> tensor<1x45x40x28xf32>
+ %0 = tosa.conv2d %input, %weights, %bias {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>) -> tensor<1x45x40x28xf32>
return
}
@@ -571,7 +571,7 @@ func.func @conv2d_dyn(%input: tensor<?x49x42x27xf32>, %weights: tensor<28x3x3x27
// CHECK: linalg.yield %[[IN]] : f32
// CHECK: } -> tensor<?x45x40x28xf32>
// CHECK: %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<?x49x42x27xf32>, tensor<28x3x3x27xf32>) outs(%[[BROADCAST]] : tensor<?x45x40x28xf32>) -> tensor<?x45x40x28xf32>
- %0 = tosa.conv2d %input, %weights, %bias {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<?x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>) -> tensor<?x45x40x28xf32>
+ %0 = tosa.conv2d %input, %weights, %bias {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<?x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>) -> tensor<?x45x40x28xf32>
return
}
@@ -627,7 +627,7 @@ func.func @conv2d_dyn_w_h(%input: tensor<1x?x?x27xf32>, %weights: tensor<28x3x3x
// CHECK: } -> tensor<1x?x?x28xf32>
// CHECK: linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x?x?x27xf32>, tensor<28x3x3x27xf32>) outs(%17 : tensor<1x?x?x28xf32>) -> tensor<1x?x?x28xf32>
- %0 = tosa.conv2d %input, %weights, %bias {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x?x?x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>) -> tensor<1x?x?x28xf32>
+ %0 = tosa.conv2d %input, %weights, %bias {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x?x?x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>) -> tensor<1x?x?x28xf32>
return
}
@@ -650,7 +650,7 @@ func.func @conv2d_dyn_output(%input: tensor<2x6x5x4xf32>, %weights: tensor<4x3x3
// linalg.yield %[[ADD]] : f32
// } -> tensor<?x4x3x4xf32>
- %0 = tosa.conv2d %input, %weights, %bias {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<2x6x5x4xf32 >, tensor<4x3x3x4xf32>, tensor<4xf32>) -> tensor<?x4x3x4xf32>
+ %0 = tosa.conv2d %input, %weights, %bias {acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<2x6x5x4xf32 >, tensor<4x3x3x4xf32>, tensor<4xf32>) -> tensor<?x4x3x4xf32>
return
}
@@ -662,7 +662,7 @@ func.func @conv2d_padded_f32(%input: tensor<1x47x40x28xf32>, %weights: tensor<28
// CHECK: tensor.pad %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0]
// CHECK: tensor.yield %[[C0]]
// CHECK: linalg.conv_2d_nhwc_fhwc
- %0 = tosa.conv2d %input, %weights, %bias {pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x47x40x28xf32>, tensor<28x3x3x28xf32>, tensor<28xf32>) -> tensor<1x45x40x28xf32>
+ %0 = tosa.conv2d %input, %weights, %bias {acc_type = f32, pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x47x40x28xf32>, tensor<28x3x3x28xf32>, tensor<28xf32>) -> tensor<1x45x40x28xf32>
return
}
@@ -674,7 +674,7 @@ func.func @conv2d_quant(%arg0 : tensor<1x12x12x1xi8>, %arg1 : tensor<1024x3x3x1x
// CHECK: tensor.pad %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0]
// CHECK: tensor.yield %[[C22]]
// CHECK: linalg.conv_2d_nhwc_fhwc_q
- %0 = tosa.conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 1, 1, 1, 1>, quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, stride = array<i64: 1, 1>} : (tensor<1x12x12x1xi8>, tensor<1024x3x3x1xi8>, tensor<1024xi32>) -> tensor<1x12x12x1024xi32>
+ %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, dilation = array<i64: 1, 1>, pad = array<i64: 1, 1, 1, 1>, quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, stride = array<i64: 1, 1>} : (tensor<1x12x12x1xi8>, tensor<1024x3x3x1xi8>, tensor<1024xi32>) -> tensor<1x12x12x1024xi32>
return
}
@@ -696,7 +696,7 @@ func.func @depthwise_conv(%arg0 : tensor<1x7x5x3xf32>, %arg1 : tensor<3x1x3x11xf
// CHECK: [[ADD:%.+]] = arith.addf %[[ARG3]], %[[ARG4]] : f32
// CHECK: linalg.yield [[ADD]] : f32
// CHECK: } -> tensor<1x5x5x33xf32>
- %2 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 { pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> } : (tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>, tensor<33xf32>) -> tensor<1x5x5x33xf32>
+ %2 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 { acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> } : (tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>, tensor<33xf32>) -> tensor<1x5x5x33xf32>
return
}
@@ -712,7 +712,7 @@ func.func @depthwise_conv_scalar_bias(%arg0 : tensor<1x7x5x3xf32>, %arg1 : tenso
// CHECK: [[ADD:%.+]] = arith.addf %[[ARG3]], %[[ARG4]] : f32
// CHECK: linalg.yield [[ADD]] : f32
// CHECK: } -> tensor<1x5x5x33xf32>
- %2 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 { pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> } : (tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>, tensor<1xf32>) -> tensor<1x5x5x33xf32>
+ %2 = tosa.depthwise_conv2d %arg0, %arg1, %arg2 {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> } : (tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>, tensor<1xf32>) -> tensor<1x5x5x33xf32>
return
}
@@ -736,7 +736,7...
[truncated]
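For illustration, after this patch each convolution op carries an explicit accumulator type in its attribute dictionary. A minimal sketch of the resulting textual IR, with illustrative shapes mirroring the updated lit tests above:

    // f32 inputs pair with an f32 accumulator; an i8 input would instead
    // require acc_type = i32, per the new verifier checks.
    %0 = tosa.conv2d %input, %weights, %bias {
           acc_type = f32,
           pad = array<i64: 0, 0, 0, 0>,
           stride = array<i64: 1, 1>,
           dilation = array<i64: 1, 1>
         } : (tensor<1x8x8x4xf32>, tensor<8x3x3x4xf32>, tensor<8xf32>) -> tensor<1x6x6x8xf32>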
As per @lhutton1's comment, you probably need to perform some of the checks in the TosaValidationPass.
Looking at the existing code and this patch, I see a discrepancy in how the verifiers and validations have been laid out. I know we suggested moving the check into the TosaValidationPass, but that is not how AvgPool2d is handled at the moment.
So, could you please move the checks back to the verifier, and decide what needs to move into the validation pass altogether in subsequent patches, when the profile intersection work takes place?
5152221
to
131e59a
Compare
I've reverted to the original commit.
Apologies, I hadn't noticed AvgPool2D. I'm happy to pull moving the acc_type checks from the verifier to the validation pass out into a separate change, especially since it doesn't impose any new restrictions (relative to the v0.80 expectations).
LGTM, with an additional comment to consider and a small nit: does it need some negative tests to check the verifier?
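Such a negative test might look like the following sketch (function name and shapes are hypothetical; the diagnostic text is the one added in TosaOps.cpp):

    // expected-error@+1 {{accumulator type for i8 tensor is not i32}}
    %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>, quantization_info = #tosa.conv_quant<input_zp = 0, weight_zp = 0>} : (tensor<1x4x4x4xi8>, tensor<8x1x1x4xi8>, tensor<8xi32>) -> tensor<1x4x4x8xi32>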
mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (outdated)

    resultEType = quantType.getStorageType();

    // check allowed input/result element types combinations
    if ((inputEType.isInteger(8) && resultEType.isInteger(32)) ||
We might want to be more careful about adding these checks, since they impose a new restriction on the allowed data types of the operator. Are they necessary for this PR? If not, perhaps we can remove them for now and re-introduce them in the validation pass in a separate change, to minimise the possible impact?
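For instance, here is a sketch of IR that this combination check would newly reject (hypothetical shapes; i8 inputs with an i16 result are not in the allowed list, so verification fails with "input/output element types are incompatible"):

    %0 = tosa.conv2d %arg0, %arg1, %arg2 {acc_type = i32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<1x4x4x4xi8>, tensor<8x1x1x4xi8>, tensor<8xi16>) -> tensor<1x4x4x8xi16>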
Tosa v1.0 adds accumulator type attributes to the various convolution operations defined in the spec. Update the dialect and any lit tests to include these attributes.

Signed-off-by: Tai Ly <[email protected]>
Thanks for the updates, LGTM! It would be good for @sjarus and @eric-k256 to check as well, since this breaks existing legalizations to TOSA.