
[mlir][tosa] Improve lowering of tosa.conv2d #74143


Merged
merged 1 commit into llvm:main on Dec 2, 2023

Conversation

@sabauma sabauma (Contributor) commented Dec 1, 2023

The existing lowering of tosa.conv2d emits a separate linalg.generic operation to add the bias after computing the convolution.

This change eliminates that additional step by broadcasting the bias to the result shape and feeding it in as the output (outs) operand of the generated linalg.conv_2d_* operation, so the convolution accumulates directly on top of the bias.

Rather than:

    %init = tensor.empty()
    %conv = linalg.conv_2d ins(%A, %B) outs(%init)

    %init2 = tensor.empty()
    %result = linalg.generic ins(%conv, %bias) outs(%init2) {
      // perform add operation
    }

The lowering now produces:

    %init = tensor.empty()
    %bias_expanded = linalg.broadcast ins(%bias) outs(%init)
    %conv = linalg.conv_2d ins(%A, %B) outs(%bias_expanded)

This is the same strategy as #73049 applied to convolutions.
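Concretely, following the updated conv2d_f32 test below, a static float convolution now lowers to IR along these lines (the value names are illustrative, not the ones the pass actually assigns):

    #bias_map   = affine_map<(d0, d1, d2, d3) -> (d3)>
    #result_map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>

    // Broadcast the 1-D bias across the NHWC result shape.
    %init = tensor.empty() : tensor<1x45x40x28xf32>
    %broadcast = linalg.generic
        {indexing_maps = [#bias_map, #result_map],
         iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
        ins(%bias : tensor<28xf32>) outs(%init : tensor<1x45x40x28xf32>) {
      ^bb0(%in: f32, %out: f32):
        linalg.yield %in : f32
    } -> tensor<1x45x40x28xf32>

    // The convolution accumulates directly into the broadcast bias.
    %conv = linalg.conv_2d_nhwc_fhwc
        {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
        ins(%input, %weights : tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>)
        outs(%broadcast : tensor<1x45x40x28xf32>) -> tensor<1x45x40x28xf32>

Compared with the old lowering, the linalg.fill of a zero tensor and the trailing bias-add linalg.generic disappear; the bias broadcast (plus its tensor.empty destination) is all that precedes the convolution.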

@llvmbot llvmbot (Member) commented Dec 1, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Spenser Bauman (sabauma)

Changes


Full diff: https://github.com/llvm/llvm-project/pull/74143.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp (+9-37)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir (+59-66)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 0accd9d1986a1ed..b3fbc7dd0b22c19 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -344,15 +344,6 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
                                                   weightPermValue);
     }
 
-    auto resultZeroAttr = rewriter.getZeroAttr(resultETy);
-    Value emptyTensor = rewriter.create<tensor::EmptyOp>(
-        loc, resultTy.getShape(), resultETy, filteredDims);
-    Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
-    Value zeroTensor = rewriter
-                           .create<linalg::FillOp>(loc, ValueRange{zero},
-                                                   ValueRange{emptyTensor})
-                           .result();
-
     // Extract the attributes for convolution.
     ArrayRef<int64_t> stride = strideTosaAttr;
     ArrayRef<int64_t> dilation = dilationTosaAttr;
@@ -361,18 +352,12 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
     auto strideAttr = rewriter.getI64TensorAttr(stride);
     auto dilationAttr = rewriter.getI64TensorAttr(dilation);
 
-    // Create maps for the bias broadcasting
-    SmallVector<AffineMap, 4> indexingMaps;
-    indexingMaps.push_back(AffineMap::get(
-        /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
-        {rewriter.getAffineDimExpr(resultTy.getRank() - 1)},
-        rewriter.getContext()));
-    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank()));
-    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank()));
-
     Value biasEmptyTensor = rewriter.create<tensor::EmptyOp>(
         loc, resultTy.getShape(), resultETy, filteredDims);
 
+    Value broadcastBias =
+        linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor);
+
     if (isQuantized) {
       auto quantizationInfo = *op.getQuantizationInfo();
       auto iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp());
@@ -380,38 +365,25 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
 
       auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
       auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);
+
       Value conv =
           rewriter
               .create<LinalgConvQOp>(
                   loc, resultTy, ValueRange{input, weight, iZpVal, kZpVal},
-                  ValueRange{zeroTensor}, strideAttr, dilationAttr)
+                  ValueRange{broadcastBias}, strideAttr, dilationAttr)
               ->getResult(0);
-      Value result = linalgIntBroadcastExtSIAdd(rewriter, loc, bias, conv,
-                                                biasEmptyTensor, indexingMaps);
-      rewriter.replaceOp(op, result);
+
+      rewriter.replaceOp(op, conv);
       return success();
     }
 
     Value conv = rewriter
                      .create<LinalgConvOp>(
                          loc, resultTy, ValueRange{input, weight},
-                         ValueRange{zeroTensor}, strideAttr, dilationAttr)
+                         ValueRange{broadcastBias}, strideAttr, dilationAttr)
                      ->getResult(0);
 
-    Value result =
-        rewriter
-            .create<linalg::GenericOp>(
-                loc, resultTy, ValueRange({bias, conv}), biasEmptyTensor,
-                indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()),
-                [&](OpBuilder &nestedBuilder, Location nestedLoc,
-                    ValueRange args) {
-                  Value added = nestedBuilder.create<arith::AddFOp>(
-                      loc, args[0], args[1]);
-                  nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
-                })
-            .getResult(0);
-
-    rewriter.replaceOp(op, result);
+    rewriter.replaceOp(op, conv);
     return success();
   }
 };
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 230001f7633b570..aa010e759a0f201 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -378,16 +378,14 @@ func.func @avg_pool_dyn(%arg0: tensor<?x6x34x62xf32>) -> (tensor<?x5x33x62xf32>)
 func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi8>, %bias: tensor<28xi8>) -> () {
   // HWCF: %[[TRANSPOSE_DIMS:.+]] = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi64>
   // HWCF: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[TRANSPOSE_DIMS]] : (tensor<28x1x1x27xi8>, tensor<4xi64>) -> tensor<1x1x27x28xi8>
-  // CHECK: %[[M_IN:.+]] = tensor.empty()
-  // CHECK: %[[CST:.+]] = arith.constant 0
-  // CHECK: %[[FILL:.+]] = linalg.fill
-  // CHECK: %[[B_IN:.+]] = tensor.empty()
-  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, %c0_i32_0, %c0_i32_1 : tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, i32, i32) outs(%[[FILL]] : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
-  // HWCF: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[TRANSPOSE]], %c0_i32_0, %c0_i32_1 : tensor<1x49x42x27xi8>, tensor<1x1x27x28xi8>, i32, i32) outs(%{{[a-zA-Z0-9_]*}} : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
-  // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xi8>, tensor<1x45x40x28xi32>) outs(%[[B_IN]] : tensor<1x45x40x28xi32>)
+  // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x45x40x28xi32>
+  // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xi8>) outs(%[[INIT]] : tensor<1x45x40x28xi32>) {
   // CHECK:   arith.extsi
-  // CHECK:   arith.addi
   // CHECK:   linalg.yield
+  // CHECK: } -> tensor<1x45x40x28xi32>
+  // CHECK: linalg.conv_2d_nhwc_fhwc_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, %c0_i32, %c0_i32_0 : tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, i32, i32) outs(%[[BROADCAST]] : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
+  // HWCF: linalg.conv_2d_nhwc_hwcf_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[TRANSPOSE]], %c0_i32, %c0_i32_0 : tensor<1x49x42x27xi8>, tensor<1x1x27x28xi8>, i32, i32) outs(%{{[a-zA-Z0-9_]*}} : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
+
   %0 = tosa.conv2d %input, %weights, %bias {dilation = array<i64: 2, 1>, pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = 0, weight_zp = 0>, stride = array<i64: 1, 1>} : (tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, tensor<28xi8>) -> tensor<1x45x40x28xi32>
   return
 }
@@ -401,15 +399,14 @@ func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi
 func.func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () {
   // HWCF: %[[TRANSPOSE_DIMS:.+]] = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi64>
   // HWCF: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[TRANSPOSE_DIMS]] : (tensor<28x3x3x27xf32>, tensor<4xi64>) -> tensor<3x3x27x28xf32>
-  // CHECK: %[[M_IN:.+]] = tensor.empty()
-  // CHECK: %[[CST:.+]] = arith.constant 0
-  // CHECK: %[[FILL:.+]] = linalg.fill
-  // CHECK: %[[B_IN:.+]] = tensor.empty()
-  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>) outs(%[[FILL]] : tensor<1x45x40x28xf32>)
-  // HWCF: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[TRANSPOSE]] : tensor<1x49x42x27xf32>, tensor<3x3x27x28xf32>) outs(%{{[a-zA-Z0-9_]*}} : tensor<1x45x40x28xf32>
-  // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<1x45x40x28xf32>) outs(%[[B_IN]] : tensor<1x45x40x28xf32>)
-  // CHECK:   arith.addf
+
+  // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x45x40x28xf32>
+  // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xf32>) outs(%[[INIT]] : tensor<1x45x40x28xf32>) {
   // CHECK:   linalg.yield
+  // CHECK: } -> tensor<1x45x40x28xf32>
+  // CHECK: linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>) outs(%1 : tensor<1x45x40x28xf32>) -> tensor<1x45x40x28xf32>
+
+  // HWCF: linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[TRANSPOSE]] : tensor<1x49x42x27xf32>, tensor<3x3x27x28xf32>) outs(%{{[a-zA-Z0-9_]*}} : tensor<1x45x40x28xf32>
   %0 = tosa.conv2d %input, %weights, %bias {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>) -> tensor<1x45x40x28xf32>
   return
 }
@@ -421,16 +418,14 @@ func.func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27
 
 // CHECK-LABEL: @conv2d_dyn
 func.func @conv2d_dyn(%input: tensor<?x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () {
-  // CHECK: %[[C0:.+]] = arith.constant 0
-  // CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]]
-  // CHECK: %[[M_IN:.+]] = tensor.empty(%[[BATCH]])
-  // CHECK: %[[CST:.+]] = arith.constant 0
-  // CHECK: %[[FILL:.+]] = linalg.fill
-  // CHECK: %[[B_IN:.+]] = tensor.empty(%[[BATCH]])
-  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<?x49x42x27xf32>, tensor<28x3x3x27xf32>) outs(%[[FILL]] : tensor<?x45x40x28xf32>)
-  // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<?x45x40x28xf32>) outs(%[[B_IN]] : tensor<?x45x40x28xf32>)
-  // CHECK:   %[[ADD:.+]] = arith.addf
-  // CHECK:   linalg.yield %[[ADD]] : f32
+  // CHECK: %[[C0:.+]] = arith.constant 0 : index
+  // CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]] : tensor<?x49x42x27xf32>
+  // CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x45x40x28xf32>
+  // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xf32>) outs(%[[INIT]] : tensor<?x45x40x28xf32>) {
+  // CHECK: ^bb0(%[[IN:.+]]: f32, %{{.+}}: f32):
+  // CHECK:   linalg.yield %[[IN]] : f32
+  // CHECK: } -> tensor<?x45x40x28xf32>
+  // CHECK: %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<?x49x42x27xf32>, tensor<28x3x3x27xf32>) outs(%[[BROADCAST]] : tensor<?x45x40x28xf32>) -> tensor<?x45x40x28xf32>
   %0 = tosa.conv2d %input, %weights, %bias {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<?x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>) -> tensor<?x45x40x28xf32>
   return
 }
@@ -481,14 +476,12 @@ func.func @conv2d_dyn_w_h(%input: tensor<1x?x?x27xf32>, %weights: tensor<28x3x3x
   // CHECK: %[[W_OUT:.+]] = arith.addi %[[DIVIDED_0]], %[[ONE_0]] : index
 
   // Running convolution
-  // CHECK: %[[M_IN:.+]] = tensor.empty(%[[H_OUT]], %[[W_OUT]])
-  // CHECK: %[[CST:.+]] = arith.constant 0
-  // CHECK: %[[FILL:.+]] = linalg.fill
-  // CHECK: %[[B_IN:.+]] = tensor.empty(%[[H_OUT]], %[[W_OUT]])
-  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x?x?x27xf32>, tensor<28x3x3x27xf32>) outs(%[[FILL]] : tensor<1x?x?x28xf32>)
-  // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<1x?x?x28xf32>) outs(%[[B_IN]] : tensor<1x?x?x28xf32>)
-  // CHECK:   %[[ADD:.+]] = arith.addf
-  // CHECK:   linalg.yield %[[ADD]] : f32
+  // CHECK: %[[INIT:.+]] = tensor.empty(%[[H_OUT]], %[[W_OUT]]) : tensor<1x?x?x28xf32>
+  // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xf32>) outs(%[[INIT]] : tensor<1x?x?x28xf32>) {
+  // CHECK:   linalg.yield
+  // CHECK: } -> tensor<1x?x?x28xf32>
+  // CHECK: linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x?x?x27xf32>, tensor<28x3x3x27xf32>) outs(%17 : tensor<1x?x?x28xf32>) -> tensor<1x?x?x28xf32>
+
   %0 = tosa.conv2d %input, %weights, %bias {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x?x?x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>) -> tensor<1x?x?x28xf32>
   return
 }
@@ -678,52 +671,52 @@ func.func @depthwise_conv2d_dyn_w_h(%arg0: tensor<2x?x?x3xf32>, %arg1: tensor<3x
 
 // -----
 
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d4)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+
 // CHECK-LABEL: @conv3d_f32
 func.func @conv3d_f32(%input: tensor<1x49x48x47x27xf32>, %weights: tensor<28x3x4x5x27xf32>, %bias: tensor<28xf32>) -> () {
-  // CHECK-DAG: %[[PERMS:.+]] = arith.constant dense<[1, 2, 3, 4, 0]>
-  // CHECK-DAG: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[PERMS]]
-  // CHECK-DAG: %[[EMPTY:.+]] = tensor.empty()
-  // CHECK-DAG: %[[ZERO:.+]] = arith.constant 0
-  // CHECK-DAG: %[[FILL:.+]] = linalg.fill ins(%[[ZERO]] : f32) outs(%[[EMPTY]] : tensor<1x47x45x43x28xf32>)
-  // CHECK-DAG: %[[EMPTY:.+]] = tensor.empty()
-  // CHECK-DAG: %[[CONV3D:.+]] = linalg.conv_3d_ndhwc_dhwcf
+  // CHECK-DAG:  %[[PERMS:.+]] = arith.constant dense<[1, 2, 3, 4, 0]>
+  // CHECK-DAG:  %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[PERMS]]
+  // CHECK-DAG:  %[[INIT:.+]] = tensor.empty() : tensor<1x47x45x43x28xf32>
+  // CHECK:      %[[BROADCAST:.+]] = linalg.generic
+  // CHECK-SAME: {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
+  // CHECK-SAME: ins(%arg2 : tensor<28xf32>) outs(%1 : tensor<1x47x45x43x28xf32>) {
+  // CHECK:      ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+  // CHECK:        linalg.yield %[[IN]] : f32
+  // CHECK:      } -> tensor<1x47x45x43x28xf32>
+  // CHECK:      linalg.conv_3d_ndhwc_dhwcf
   // CHECK-SAME: {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
   // CHECK-SAME: ins(%arg0, %[[TRANSPOSE]] : tensor<1x49x48x47x27xf32>, tensor<3x4x5x27x28xf32>)
-  // CHECK-SAME: outs(%[[FILL]] : tensor<1x47x45x43x28xf32>) -> tensor<1x47x45x43x28xf32>
-  // CHECK: %[[GENERIC:.+]] = linalg.generic
-  // CHECK-SAME: {indexing_maps = [#map, #map1, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
-  // CHECK-SAME: ins(%arg2, %[[CONV3D]] : tensor<28xf32>, tensor<1x47x45x43x28xf32>)
-  // CHECK-SAME: outs(%[[EMPTY]] : tensor<1x47x45x43x28xf32>) {
-  // CHECK: ^bb0(%[[A1:.+]]: f32, %[[A2:.+]]: f32, %{{.+}}: f32):
-  // CHECK: %[[ADD:.+]] = arith.addf %[[A1]], %[[A2]] : f32
-  // CHECK: linalg.yield %[[ADD]]
+  // CHECK-SAME: outs(%[[BROADCAST]] : tensor<1x47x45x43x28xf32>) -> tensor<1x47x45x43x28xf32>
   %0 = tosa.conv3d %input, %weights, %bias {pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xf32>, tensor<28x3x4x5x27xf32>, tensor<28xf32>)  -> tensor<1x47x45x43x28xf32>
   return
 }
 
 // -----
 
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d4)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+
 // CHECK-LABEL: @conv3d_i8
 func.func @conv3d_i8(%input: tensor<1x49x48x47x27xi8>, %weights: tensor<28x3x4x5x27xi8>, %bias: tensor<28xi32>) -> () {
-    // CHECK-DAG: %[[PERMS:.+]] = arith.constant dense<[1, 2, 3, 4, 0]>
-  // CHECK-DAG: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[PERMS]]
-  // CHECK-DAG: %[[EMPTY:.+]] = tensor.empty()
-  // CHECK-DAG: %[[ZERO:.+]] = arith.constant 0
-  // CHECK-DAG: %[[FILL:.+]] = linalg.fill ins(%[[ZERO]] : i32) outs(%[[EMPTY]] : tensor<1x47x45x43x28xi32>)
-  // CHECK-DAG: %[[EMPTY:.+]] = tensor.empty()
-  // CHECK-DAG: %[[IZP:.+]] = arith.constant -128 : i32
-  // CHECK-DAG: %[[FZP:.+]] = arith.constant 42 : i32
-  // CHECK-DAG: %[[CONV3D:.+]] = linalg.conv_3d_ndhwc_dhwcf_q
+  // CHECK-DAG:  %[[PERMS:.+]] = arith.constant dense<[1, 2, 3, 4, 0]>
+  // CHECK-DAG:  %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[PERMS]]
+  // CHECK-DAG:  %[[INIT:.+]] = tensor.empty() : tensor<1x47x45x43x28xi32>
+  // CHECK:      %[[BROADCAST:.+]] = linalg.generic
+  // CHECK-SAME: {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
+  // CHECK-SAME: ins(%arg2 : tensor<28xi32>)
+  // CHECK-SAME: outs(%[[INIT]] : tensor<1x47x45x43x28xi32>) {
+  // CHECK:      ^bb0(%[[IN:.+]]: i32, %[[OUT:.+]]: i32):
+  // CHECK:        linalg.yield %[[IN]] : i32
+  // CHECK:      } -> tensor<1x47x45x43x28xi32>
+  // CHECK:      %[[IZP:.+]] = arith.constant -128 : i32
+  // CHECK:      %[[FZP:.+]] = arith.constant 42 : i32
+  // CHECK:      linalg.conv_3d_ndhwc_dhwcf_q
   // CHECK-SAME: {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
   // CHECK-SAME: ins(%arg0, %[[TRANSPOSE]], %[[IZP]], %[[FZP]] : tensor<1x49x48x47x27xi8>, tensor<3x4x5x27x28xi8>, i32, i32)
-  // CHECK-SAME: outs(%[[FILL]] : tensor<1x47x45x43x28xi32>) -> tensor<1x47x45x43x28xi32>
-  // CHECK: %[[GENERIC:.+]] = linalg.generic
-  // CHECK-SAME: {indexing_maps = [#map, #map1, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
-  // CHECK-SAME: ins(%arg2, %[[CONV3D]] : tensor<28xi32>, tensor<1x47x45x43x28xi32>)
-  // CHECK-SAME: outs(%[[EMPTY]] : tensor<1x47x45x43x28xi32>) {
-  // CHECK: ^bb0(%[[A1:.+]]: i32, %[[A2:.+]]: i32, %{{.+}}: i32):
-  // CHECK: %[[ADD:.+]] = arith.addi %[[A1]], %[[A2]] : i32
-  // CHECK: linalg.yield %[[ADD]]
+  // CHECK-SAME: outs(%[[BROADCAST]] : tensor<1x47x45x43x28xi32>) -> tensor<1x47x45x43x28xi32>
+
   %0 = tosa.conv3d %input, %weights, %bias {pad = array<i64: 0, 0, 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = -128, weight_zp = 42>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xi8>, tensor<28x3x4x5x27xi8>, tensor<28xi32>)  -> tensor<1x47x45x43x28xi32>
   return
 }
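Both the float and quantized paths now go through the linalgBroadcastAndMaybeExtSI helper shown in the diff above; on the quantized path the broadcast body also sign-extends the bias to the i32 accumulator type, which is the arith.extsi visible in the conv2d_i8 CHECK lines. A rough sketch of that IR, with illustrative shapes and value names rather than ones copied from the tests:

    #bias_map   = affine_map<(d0, d1, d2, d3) -> (d3)>
    #result_map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>

    // Broadcast and sign-extend the i8 bias into the i32 accumulator tensor.
    %init = tensor.empty() : tensor<1x45x40x28xi32>
    %broadcast = linalg.generic
        {indexing_maps = [#bias_map, #result_map],
         iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
        ins(%bias : tensor<28xi8>) outs(%init : tensor<1x45x40x28xi32>) {
      ^bb0(%in: i8, %out: i32):
        %ext = arith.extsi %in : i8 to i32
        linalg.yield %ext : i32
    } -> tensor<1x45x40x28xi32>

    // Zero points taken from the op's quantization_info.
    %izp = arith.constant 0 : i32
    %wzp = arith.constant 0 : i32

    // The quantized convolution accumulates into the extended bias.
    %conv = linalg.conv_2d_nhwc_fhwc_q
        {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
        ins(%input, %weights, %izp, %wzp : tensor<1x49x42x27xi8>, tensor<28x3x3x27xi8>, i32, i32)
        outs(%broadcast : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>

Since the linalg convolution ops accumulate into their outs operand, no trailing bias-add linalg.generic is needed on either path.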

@eric-k256 eric-k256 (Contributor) left a comment


Looks like a nice simplification!

@sabauma sabauma merged commit 293c21d into llvm:main Dec 2, 2023