
[mlir][tosa-to-linalg] Add acc_type lowering Support #134267


Merged: 1 commit into llvm:main on Apr 7, 2025

Conversation

FranklandJack
Contributor

Add support for lowering of convolution operations where the acc_type attribute differs from the result type of the operation. The only such case for convolutions in the TOSA v1.0 specification is an fp16 convolution which internally uses an fp32 accumulator; all other operations have accumulator types that match their output/result types.

Add lit tests for the fp16 convolutions with fp32 accumulators described above.
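
For illustration, a minimal sketch of the case this enables, mirroring the conv2d lit test added in this patch; the comments summarizing the lowered form describe what the tosa-to-linalg-named conversion is expected to produce and are illustrative, not part of the patch:

// An fp16 conv2d that requests an fp32 accumulator via acc_type.
func.func @conv2d_f16_f32_acc(%input: tensor<1x49x42x27xf16>, %weights: tensor<28x3x3x27xf16>, %bias: tensor<28xf16>) -> () {
  %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
  %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
  // After this change the op below lowers to:
  //   1. a linalg.generic that broadcasts %bias and extends it f16 -> f32 (arith.extf),
  //   2. a linalg.conv_2d_nhwc_fhwc accumulating into a tensor<1x45x40x28xf32>,
  //   3. a tosa.cast truncating the f32 accumulator back to tensor<1x45x40x28xf16>.
  %0 = tosa.conv2d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x49x42x27xf16>, tensor<28x3x3x27xf16>, tensor<28xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x45x40x28xf16>
  return
}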

Add support for lowering of convolution operations where the `acc_type`
attribute differs from the result type of the operation. The only such
case for convolutions in the TOSA v1.0 specification is an fp16
convolution which internally uses an fp32 accumulator; all other
operations have accumulator types that match their output/result types.

Add lit tests for the fp16 convolutions with fp32 accumulators described
above.

Signed-off-by: Jack Frankland <[email protected]>
@llvmbot
Member

llvmbot commented Apr 3, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Jack Frankland (FranklandJack)

Changes

Add support for lowering of convolution operations where the acc_type attribute differs from the result type of the operation. The only such case for convolutions in the TOSA v1.0 specification is an fp16 convolution which internally uses an fp32 accumulator; all other operations have accumulator types that match their output/result types.

Add lit tests for the fp16 convolutions with fp32 accumulators described above.


Full diff: https://github.com/llvm/llvm-project/pull/134267.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp (+37-13)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir (+40)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index fc1cad2423450..86f5e9baf4a94 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -119,10 +119,11 @@ static AffineMap getBroadcastingMap(PatternRewriter &rewriter, Value source,
 }
 
 // Broadcast the source value to all the outer dimensions of the result value.
-// If required, the element type is expanded using an arith.extsi operation.
-static mlir::Value linalgBroadcastAndMaybeExtSI(PatternRewriter &rewriter,
-                                                Location loc, Value source,
-                                                Value result) {
+// If required, the element type is expanded using an arith.extsi or arith.extf
+// operation as appropriate.
+static mlir::Value linalgBroadcastAndMaybeExt(PatternRewriter &rewriter,
+                                              Location loc, Value source,
+                                              Value result) {
   ShapedType resultTy = cast<ShapedType>(result.getType());
   const int64_t resultRank = resultTy.getRank();
   // Creating maps for the input and output of the broacast-like generic op.
@@ -135,11 +136,16 @@ static mlir::Value linalgBroadcastAndMaybeExtSI(PatternRewriter &rewriter,
       .create<linalg::GenericOp>(
           loc, resultTy, ValueRange({source}), result, indexingMaps,
           getNParallelLoopsAttrs(resultTy.getRank()),
-          [](OpBuilder &builder, Location loc, ValueRange args) {
+          [&resultTy](OpBuilder &builder, Location loc, ValueRange args) {
             Value biasVal = args[0];
             Type resType = args[1].getType();
             if (resType != biasVal.getType()) {
-              biasVal = builder.create<arith::ExtSIOp>(loc, resType, biasVal);
+              biasVal =
+                  resultTy.getElementType().isFloat()
+                      ? builder.create<arith::ExtFOp>(loc, resType, biasVal)
+                            .getResult()
+                      : builder.create<arith::ExtSIOp>(loc, resType, biasVal)
+                            .getResult();
             }
             builder.create<linalg::YieldOp>(loc, biasVal);
           })
@@ -253,12 +259,14 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
     ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());
 
     Type inputETy = inputTy.getElementType();
-    Type resultETy = resultTy.getElementType();
 
     DenseI64ArrayAttr padAttr = op.getPadAttr();
     DenseI64ArrayAttr strideTosaAttr = op.getStrideAttr();
     DenseI64ArrayAttr dilationTosaAttr = op.getDilationAttr();
 
+    Type accETy = op.getAccType();
+    Type accTy = RankedTensorType::get(resultTy.getShape(), accETy);
+
     // Get and verify zero points.
     FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
     if (failed(maybeIZp))
@@ -385,10 +393,10 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
     auto dilationAttr = rewriter.getI64TensorAttr(dilation);
 
     Value biasEmptyTensor = rewriter.create<tensor::EmptyOp>(
-        loc, resultTy.getShape(), resultETy, filteredDims);
+        loc, resultTy.getShape(), accETy, filteredDims);
 
     Value broadcastBias =
-        linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor);
+        linalgBroadcastAndMaybeExt(rewriter, loc, bias, biasEmptyTensor);
 
     if (hasZp) {
       auto iZp = rewriter.getI32IntegerAttr(inputZpVal);
@@ -410,10 +418,15 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
 
     Value conv = rewriter
                      .create<LinalgConvOp>(
-                         loc, resultTy, ValueRange{input, weight},
+                         loc, accTy, ValueRange{input, weight},
                          ValueRange{broadcastBias}, strideAttr, dilationAttr)
                      ->getResult(0);
 
+    // We may need to truncate back to the result type if the accumulator was
+    // wider than the result.
+    if (resultTy != accTy)
+      conv = rewriter.create<tosa::CastOp>(loc, resultTy, conv);
+
     rewriter.replaceOp(op, conv);
     return success();
   }
@@ -444,6 +457,8 @@ class DepthwiseConvConverter
     auto strideTosaAttr = cast<DenseI64ArrayAttr>(op->getAttr("stride"));
     auto dilationTosaAttr = cast<DenseI64ArrayAttr>(op->getAttr("dilation"));
 
+    Type accETy = op.getAccType();
+
     if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
       return rewriter.notifyMatchFailure(
           op, "tosa.depthwise_conv ops require static shapes");
@@ -516,11 +531,11 @@ class DepthwiseConvConverter
     ShapedType linalgConvTy =
         RankedTensorType::get({resultShape[0], resultShape[1], resultShape[2],
                                weightShape[2], weightShape[3]},
-                              resultETy);
+                              accETy);
 
-    auto resultZeroAttr = rewriter.getZeroAttr(resultETy);
+    auto resultZeroAttr = rewriter.getZeroAttr(accETy);
     Value emptyTensor = rewriter.create<tensor::EmptyOp>(
-        loc, linalgConvTy.getShape(), resultETy, filteredDims);
+        loc, linalgConvTy.getShape(), accETy, filteredDims);
     Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
     Value zeroTensor = rewriter
                            .create<linalg::FillOp>(loc, ValueRange{zero},
@@ -543,6 +558,15 @@ class DepthwiseConvConverter
                            ValueRange{zeroTensor}, strideAttr, dilationAttr)
                        .getResult(0);
 
+      // We may need to truncate back to the result type if the accumulator was
+      // wider than the result.
+      if (accETy != resultETy)
+        conv = rewriter.create<tosa::CastOp>(
+            loc,
+            RankedTensorType::get(cast<ShapedType>(conv.getType()).getShape(),
+                                  resultETy),
+            conv);
+
       SmallVector<ReassociationExprs, 4> reassociationMap;
       createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter);
       Value convReshape = rewriter.create<tensor::CollapseShapeOp>(
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 19c12ba3edbd4..242772fe5cdcf 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -658,6 +658,20 @@ func.func @conv2d_quant(%arg0 : tensor<1x12x12x1xi8>, %arg1 : tensor<1024x3x3x1x
 
 // -----
 
+// CHECK-LABEL: @conv2d_f16_f32_acc
+func.func @conv2d_f16_f32_acc(%input: tensor<1x49x42x27xf16>, %weights: tensor<28x3x3x27xf16>, %bias: tensor<28xf16>) -> () {
+  %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
+  %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
+  // CHECK: linalg.generic {{{.*}}} ins(%{{.*}} : tensor<28xf16>) outs(%{{.*}} : tensor<1x45x40x28xf32>)
+  // CHECK: arith.extf %{{.*}} : f16 to f32
+  // CHECK: %[[CONV:.*]] = linalg.conv_2d_nhwc_fhwc {{{.*}}} ins(%{{.*}}, %{{.*}} : tensor<1x49x42x27xf16>, tensor<28x3x3x27xf16>) outs(%{{.*}} : tensor<1x45x40x28xf32>) -> tensor<1x45x40x28xf32>
+  // CHECK: tosa.cast %[[CONV]] : (tensor<1x45x40x28xf32>) -> tensor<1x45x40x28xf16>
+  %0 = tosa.conv2d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x49x42x27xf16>, tensor<28x3x3x27xf16>, tensor<28xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x45x40x28xf16>
+  return
+}
+
+// -----
+
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d3)>
 // CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 
@@ -848,6 +862,18 @@ func.func @depthwise_int_conv_zero_zp(%arg0 : tensor<1x7x5x3xi8>, %arg1 : tensor
 
 // -----
 
+// CHECK-LABEL: @depthwise_conv2d_f16_f32_acc
+func.func @depthwise_conv2d_f16_f32_acc(%arg0 : tensor<1x7x5x3xf16>, %arg1 : tensor<3x1x3x11xf16>, %arg2 : tensor<33xf16>) -> () {
+  %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
+  %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
+  // CHECK: %[[CONV:.*]] = linalg.depthwise_conv_2d_nhwc_hwcm {{{.*}}} ins(%{{.*}}, %{{.*}} : tensor<1x7x5x3xf16>, tensor<3x1x3x11xf16>) outs(%{{.*}} : tensor<1x5x5x3x11xf32>) -> tensor<1x5x5x3x11xf32>
+  // CHECK: tosa.cast %[[CONV]] : (tensor<1x5x5x3x11xf32>) -> tensor<1x5x5x3x11xf16>
+  %2 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> } : (tensor<1x7x5x3xf16>, tensor<3x1x3x11xf16>, tensor<33xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x5x5x33xf16>
+  return
+}
+
+// -----
+
 // CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d4)>
 // CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
 
@@ -918,6 +944,20 @@ func.func @conv3d_i8(%input: tensor<1x49x48x47x27xi8>, %weights: tensor<28x3x4x5
 
 // -----
 
+// CHECK-LABEL: @conv3d_f16_f32_acc
+func.func @conv3d_f16_f32_acc(%input: tensor<1x49x48x47x27xf16>, %weights: tensor<28x3x4x5x27xf16>, %bias: tensor<28xf16>) -> () {
+  %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
+  %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
+  // CHECK: linalg.generic {{{.*}}} ins(%{{.*}} : tensor<28xf16>) outs(%{{.*}} : tensor<1x47x45x43x28xf32>)
+  // CHECK: arith.extf %{{.*}} : f16 to f32
+  // CHECK: %[[CONV:.*]] = linalg.conv_3d_ndhwc_dhwcf {{{.*}}} ins(%{{.*}}, %{{.*}} : tensor<1x49x48x47x27xf16>, tensor<3x4x5x27x28xf16>) outs(%{{.*}} : tensor<1x47x45x43x28xf32>) -> tensor<1x47x45x43x28xf32>
+  // CHECK: tosa.cast %[[CONV]] : (tensor<1x47x45x43x28xf32>) -> tensor<1x47x45x43x28xf16>
+  %0 = tosa.conv3d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xf16>, tensor<28x3x4x5x27xf16>, tensor<28xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x47x45x43x28xf16>
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @test_transpose
 // CHECK-SAME: (%[[ARG0:.+]]: tensor<1x2x3xi32>)
 func.func @test_transpose(%arg0: tensor<1x2x3xi32>) -> () {

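The depthwise case follows the same pattern but casts the wider accumulator before collapsing back to four dimensions. A sketch mirroring the depthwise_conv2d lit test above (again, the comments describing the lowered form are illustrative, not part of the patch):

// An fp16 depthwise_conv2d that requests an fp32 accumulator via acc_type.
func.func @depthwise_conv2d_f16_f32_acc(%arg0 : tensor<1x7x5x3xf16>, %arg1 : tensor<3x1x3x11xf16>, %arg2 : tensor<33xf16>) -> () {
  %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
  %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
  // After this change the op below lowers to:
  //   1. a linalg.depthwise_conv_2d_nhwc_hwcm accumulating into a tensor<1x5x5x3x11xf32>,
  //   2. a tosa.cast truncating that accumulator back to tensor<1x5x5x3x11xf16>,
  //   3. a tensor.collapse_shape to tensor<1x5x5x33xf16> followed by the bias addition.
  %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<1x7x5x3xf16>, tensor<3x1x3x11xf16>, tensor<33xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x5x5x33xf16>
  return
}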
@llvmbot
Member

llvmbot commented Apr 3, 2025

@llvm/pr-subscribers-mlir-tosa

Author: Jack Frankland (FranklandJack)

Changes: same summary and diff as in the llvmbot comment above.

Contributor

@lhutton1 lhutton1 left a comment


LGTM - thanks! Probably needs someone more familiar with linalg to have a look as well

@FranklandJack FranklandJack requested a review from sjarus April 6, 2025 20:42
@FranklandJack FranklandJack merged commit 7b007c0 into llvm:main Apr 7, 2025
15 checks passed
4 participants