
[mlir][tosa] Improve lowering to tosa.fully_connected #73049


Merged
sabauma merged 1 commit into llvm:main from the fully-connected-fix branch on Dec 1, 2023

Conversation


@sabauma sabauma commented Nov 21, 2023

The current lowering of tosa.fully_connected produces a linalg.matmul followed by a linalg.generic to add the bias. The IR looks like the following:

%init = tensor.empty()
%zero = linalg.fill ins(0 : f32) outs(%init)
%prod = linalg.matmul ins(%A, %B) outs(%zero)

// Add the bias
%initB = tensor.empty()
%result = linalg.generic ins(%prod, %bias) outs(%initB) {
   // add bias and product
}

This has two downsides:

  1. The tensor.empty operations typically result in additional allocations after bufferization.
  2. There is a redundant traversal of the data to add the bias to the matrix product.

This extra work can be avoided by seeding the outs operand (the accumulator) of linalg.matmul with the broadcast bias. The new IR sequence is:

%init = tensor.empty()
%broadcast = linalg.broadcast ins(%bias) outs(%init)
%prod = linalg.matmul ins(%A, %B) outs(%broadcast)
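
This works because linalg.matmul accumulates into its outs operand rather than overwriting it. As a rough sketch (indexing maps and iterator types elided, in the same shorthand as above), the matmul behaves like:

%prod = linalg.generic ins(%A, %B) outs(%broadcast) {
^bb0(%a: f32, %b: f32, %out: f32):
  // %out starts as the broadcast bias value, so the bias
  // addition is fused into the accumulation.
  %mul = arith.mulf %a, %b : f32
  %acc = arith.addf %out, %mul : f32
  linalg.yield %acc : f32
}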

In my experiments, this eliminates one loop and one allocation (post bufferization) from the generated code.
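
For reference, a concrete lowering with static shapes (mirroring the updated lit test; value names are illustrative) looks roughly like:

// Assumes %input: tensor<5x3xf32>, %weight: tensor<6x3xf32>, %bias: tensor<6xf32>.
%perm = arith.constant dense<[1, 0]> : tensor<2xi64>
%wT = tosa.transpose %weight, %perm : (tensor<6x3xf32>, tensor<2xi64>) -> tensor<3x6xf32>
%init = tensor.empty() : tensor<5x6xf32>
%bcast = linalg.broadcast ins(%bias : tensor<6xf32>) outs(%init : tensor<5x6xf32>) dimensions = [0]
%prod = linalg.matmul ins(%input, %wT : tensor<5x3xf32>, tensor<3x6xf32>) outs(%bcast : tensor<5x6xf32>) -> tensor<5x6xf32>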


llvmbot commented Nov 21, 2023

@llvm/pr-subscribers-mlir-tosa
@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Spenser Bauman (sabauma)



Full diff: https://github.com/llvm/llvm-project/pull/73049.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp (+9-29)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir (+20-46)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 99a65f63038a43f..b9a7b778ce4017d 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -632,17 +632,6 @@ class FullyConnectedConverter
     indexingMaps.push_back(rewriter.getMultiDimIdentityMap(outputTy.getRank()));
     indexingMaps.push_back(rewriter.getMultiDimIdentityMap(outputTy.getRank()));
 
-    auto emptyTensor = rewriter.create<tensor::EmptyOp>(
-        loc, outputTy.getShape(), outputTy.getElementType(), filteredDims);
-
-    // When quantized, the input elemeny type is not the same as the output
-    auto resultZeroAttr = rewriter.getZeroAttr(outputETy);
-    Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
-    Value zeroTensor = rewriter
-                           .create<linalg::FillOp>(loc, ValueRange{zero},
-                                                   ValueRange{emptyTensor})
-                           .result();
-
     SmallVector<int64_t> permutation{1, 0};
     auto permutationAttr = rewriter.getI64TensorAttr(permutation);
     Value permutationValue =
@@ -658,26 +647,18 @@ class FullyConnectedConverter
     Value biasEmptyTensor = rewriter.create<tensor::EmptyOp>(
         loc, outputTy.getShape(), outputETy, filteredDims);
 
+    auto broadcastDims = DenseI64ArrayAttr::get(getContext(), {0});
+    Value biasInitTensor = rewriter.create<linalg::BroadcastOp>(
+        loc, bias, biasEmptyTensor, broadcastDims)->getResult(0);
+
     if (!op.getQuantizationInfo()) {
       Value matmul = rewriter
                          .create<linalg::MatmulOp>(
                              loc, TypeRange{op.getType()},
-                             ValueRange{input, transposedWeight}, zeroTensor)
+                             ValueRange{input, transposedWeight}, biasInitTensor)
                          ->getResult(0);
 
-      Value result =
-          rewriter
-              .create<linalg::GenericOp>(
-                  loc, outputTy, ValueRange({bias, matmul}), biasEmptyTensor,
-                  indexingMaps, getNParallelLoopsAttrs(outputTy.getRank()),
-                  [&](OpBuilder &nestedBuilder, Location nestedLoc,
-                      ValueRange args) {
-                    Value added = nestedBuilder.create<arith::AddFOp>(
-                        loc, args[0], args[1]);
-                    nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
-                  })
-              .getResult(0);
-      rewriter.replaceOp(op, result);
+      rewriter.replaceOp(op, matmul);
       return success();
     }
 
@@ -691,11 +672,10 @@ class FullyConnectedConverter
             .create<linalg::QuantizedMatmulOp>(
                 loc, TypeRange{op.getType()},
                 ValueRange{input, transposedWeight, inputZp, outputZp},
-                zeroTensor)
+                biasInitTensor)
             ->getResult(0);
-    Value result = linalgIntBroadcastExtSIAdd(rewriter, loc, bias, matmul,
-                                              biasEmptyTensor, indexingMaps);
-    rewriter.replaceOp(op, result);
+
+    rewriter.replaceOp(op, matmul);
     return success();
   }
 };
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 1cf7c8dee606899..3b6d574b73b1ab6 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -68,22 +68,13 @@ func.func @matmul_dyn_independent_dim(%arg0: tensor<1x5x?xf32>, %arg1: tensor<1x
 
 // -----
 
-// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
-// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-
 // CHECK-LABEL: @fully_connected
 func.func @fully_connected(%arg0: tensor<5x3xf32>, %arg1: tensor<6x3xf32>, %arg2: tensor<6xf32>) -> (tensor<5x6xf32>) {
-  // CHECK: [[INITT:%.+]] = tensor.empty()
-  // CHECK: [[ZERO:%.+]] = arith.constant 0
-  // CHECK: [[FILL:%.+]] = linalg.fill ins([[ZERO]]{{.*}}outs([[INITT]]
-  // CHECK: [[PERM:%.+]] = arith.constant dense<[1, 0]>
-  // CHECK: [[TRANSPOSE:%.+]] = tosa.transpose %arg1, [[PERM]]
-  // CHECK: [[INITB:%.+]] = tensor.empty()
-  // CHECK: [[MATMUL:%.+]] = linalg.matmul ins(%arg0, [[TRANSPOSE]] : tensor<5x3xf32>, tensor<3x6xf32>) outs([[FILL]] : tensor<5x6xf32>) -> tensor<5x6xf32>
-  // CHECK: [[ADDED:%.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%arg2, [[MATMUL]] : tensor<6xf32>, tensor<5x6xf32>) outs([[INITB]] : tensor<5x6xf32>) {
-  // CHECK: ^bb0(%[[ARG3:[0-9a-zA-Z_]+]]: f32, %[[ARG4:[0-9a-zA-Z_]+]]: f32, %[[ARG5:[0-9a-zA-Z_]+]]: f32):
-  // CHECK:   [[ADD:%.+]] = arith.addf %[[ARG3]], %[[ARG4]] : f32
-  // CHECK:   linalg.yield [[ADD]] : f32
+  // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 0]> : tensor<2xi64>
+  // CHECK: %[[TRANSPOSED:.+]] = tosa.transpose %arg1, %[[PERM]] : (tensor<6x3xf32>, tensor<2xi64>) -> tensor<3x6xf32>
+  // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<5x6xf32>
+  // CHECK: %[[BROADCASTED:.+]] = linalg.broadcast ins(%arg2 : tensor<6xf32>) outs(%[[INIT]] : tensor<5x6xf32>) dimensions = [0]
+  // CHECK: linalg.matmul ins(%arg0, %[[TRANSPOSED]] : tensor<5x3xf32>, tensor<3x6xf32>) outs(%[[BROADCASTED]] : tensor<5x6xf32>) -> tensor<5x6xf32>
 
   %0 = tosa.fully_connected %arg0, %arg1, %arg2 : (tensor<5x3xf32>, tensor<6x3xf32>, tensor<6xf32>) -> tensor<5x6xf32>
   return %0 : tensor<5x6xf32>
@@ -91,48 +82,31 @@ func.func @fully_connected(%arg0: tensor<5x3xf32>, %arg1: tensor<6x3xf32>, %arg2
 
 // -----
 
-// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
-// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-
 // CHECK-LABEL: @quantized_fully_connected
 func.func @quantized_fully_connected(%arg0: tensor<5x3xi8>, %arg1: tensor<6x3xi8>, %arg2: tensor<6xi32>) -> (tensor<5x6xi32>) {
-  // CHECK: [[INITT:%.+]] = tensor.empty()
-  // CHECK: [[ZERO:%.+]] = arith.constant 0
-  // CHECK: [[FILL:%.+]] = linalg.fill ins([[ZERO]]{{.*}}outs([[INITT]]
-  // CHECK: [[PERM:%.+]] = arith.constant dense<[1, 0]>
-  // CHECK: [[TRANSPOSE:%.+]] = tosa.transpose %arg1, [[PERM]]
-  // CHECK: [[INITB:%.+]] = tensor.empty()
-  // CHECK: [[ONE:%.+]] = arith.constant 1
-  // CHECK: [[TWO:%.+]] = arith.constant 2
-  // CHECK: [[MATMUL:%.+]] = linalg.quantized_matmul ins(%arg0, [[TRANSPOSE]], [[ONE]], [[TWO]] : tensor<5x3xi8>, tensor<3x6xi8>, i32, i32) outs([[FILL]] : tensor<5x6xi32>) -> tensor<5x6xi32>
-  // CHECK: [[ADDED:%.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%arg2, [[MATMUL]] : tensor<6xi32>, tensor<5x6xi32>) outs([[INITB]]
-  // CHECK: ^bb0([[IN1:%.+]]: i32, [[IN2:%.+]]: i32, [[UNUSED:%.+]]: i32):
-  // CHECK:   [[ADD:%.+]] = arith.addi
-  // CHECK:   linalg.yield [[ADD]] : i32
+  // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 0]> : tensor<2xi64>
+  // CHECK: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[PERM]] : (tensor<6x3xi8>, tensor<2xi64>) -> tensor<3x6xi8>
+  // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<5x6xi32>
+  // CHECK: %[[BROADCASTED:.+]] = linalg.broadcast ins(%arg2 : tensor<6xi32>) outs(%[[INIT]] : tensor<5x6xi32>) dimensions = [0]
+  // CHECK: %[[C1:.+]] = arith.constant 1 : i32
+  // CHECK: %[[C2:.+]] = arith.constant 2 : i32
+  // CHECK: linalg.quantized_matmul ins(%arg0, %[[TRANSPOSE]], %[[C1]], %[[C2]] : tensor<5x3xi8>, tensor<3x6xi8>, i32, i32) outs(%[[BROADCASTED]] : tensor<5x6xi32>) -> tensor<5x6xi32>
+
   %0 = tosa.fully_connected %arg0, %arg1, %arg2 {quantization_info = #tosa.conv_quant<input_zp = 1, weight_zp = 2>} : (tensor<5x3xi8>, tensor<6x3xi8>, tensor<6xi32>) -> tensor<5x6xi32>
   return %0 : tensor<5x6xi32>
 }
 
 // -----
 
-// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
-// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-
 // CHECK-LABEL: @fully_connected_dyn
 func.func @fully_connected_dyn(%arg0: tensor<?x3xf32>, %arg1: tensor<6x3xf32>, %arg2: tensor<6xf32>) -> (tensor<?x6xf32>) {
-  // CHECK: %[[C0:.+]] = arith.constant 0
-  // CHECK: %[[DIM:.+]] = tensor.dim %arg0, %[[C0]]
-  // CHECK: %[[INITT:.+]] = tensor.empty(%[[DIM]])
-  // CHECK: %[[ZERO:.+]] = arith.constant 0
-  // CHECK: %[[FILL:.+]] = linalg.fill ins(%[[ZERO]]{{.*}}outs(%[[INITT]]
-  // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 0]>
-  // CHECK: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[PERM]]
-  // CHECK: %[[INITB:.+]] = tensor.empty(%[[DIM]])
-  // CHECK: %[[MATMUL:.+]] = linalg.matmul ins(%arg0, %[[TRANSPOSE]] : tensor<?x3xf32>, tensor<3x6xf32>) outs(%[[FILL]] : tensor<?x6xf32>) -> tensor<?x6xf32>
-  // CHECK: %[[ADDED:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%arg2, %[[MATMUL]] : tensor<6xf32>, tensor<?x6xf32>) outs(%[[INITB]] : tensor<?x6xf32>) {
-  // CHECK: ^bb0(%[[ARG3:[0-9a-zA-Z_]+]]: f32, %[[ARG4:[0-9a-zA-Z_]+]]: f32, %[[ARG5:[0-9a-zA-Z_]+]]: f32):
-  // CHECK:   %[[ADD:.+]] = arith.addf %[[ARG3]], %[[ARG4]] : f32
-  // CHECK:   linalg.yield %[[ADD]] : f32
+  // CHECK: %[[C0:.+]] = arith.constant 0 : index
+  // CHECK: %[[DIM0:.+]] = tensor.dim %arg0, %c0 : tensor<?x3xf32>
+  // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 0]> : tensor<2xi64>
+  // CHECK: %[[TRANSPOSED:.+]] = tosa.transpose %arg1, %[[PERM]] : (tensor<6x3xf32>, tensor<2xi64>) -> tensor<3x6xf32>
+  // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM0]]) : tensor<?x6xf32>
+  // CHECK: %[[BROADCASTED:.+]] = linalg.broadcast ins(%arg2 : tensor<6xf32>) outs(%[[INIT]] : tensor<?x6xf32>) dimensions = [0]
+  // CHECK: linalg.matmul ins(%arg0, %[[TRANSPOSED]] : tensor<?x3xf32>, tensor<3x6xf32>) outs(%[[BROADCASTED]] : tensor<?x6xf32>) -> tensor<?x6xf32>
 
   %0 = tosa.fully_connected %arg0, %arg1, %arg2 : (tensor<?x3xf32>, tensor<6x3xf32>, tensor<6xf32>) -> tensor<?x6xf32>
   return %0 : tensor<?x6xf32>
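
The quantized path benefits in the same way: linalg.quantized_matmul sign-extends its operands, subtracts the zero points, and accumulates into its outs operand, so seeding it with the broadcast i32 bias is equally valid. A sketch of the resulting IR (value names illustrative, shapes as in the test above):

// Assumes %input: tensor<5x3xi8>, %wT: tensor<3x6xi8>, %bias: tensor<6xi32>,
// and %inputZp/%weightZp holding the op's quantization_info zero points.
%init = tensor.empty() : tensor<5x6xi32>
%bcast = linalg.broadcast ins(%bias : tensor<6xi32>) outs(%init : tensor<5x6xi32>) dimensions = [0]
%prod = linalg.quantized_matmul ins(%input, %wT, %inputZp, %weightZp : tensor<5x3xi8>, tensor<3x6xi8>, i32, i32) outs(%bcast : tensor<5x6xi32>) -> tensor<5x6xi32>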

@sabauma sabauma force-pushed the fully-connected-fix branch 2 times, most recently from 9b5c5b2 to 1d97a44 on November 21, 2023 23:31
@sabauma sabauma force-pushed the fully-connected-fix branch 2 times, most recently from c59b55a to 654ca3e on November 23, 2023 20:38

github-actions bot commented Nov 23, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

@sabauma sabauma force-pushed the fully-connected-fix branch 2 times, most recently from 24645dd to 0d59420 on November 23, 2023 22:00
@sabauma sabauma force-pushed the fully-connected-fix branch from 0d59420 to ec8cd7a on December 1, 2023 14:17
@sabauma sabauma force-pushed the fully-connected-fix branch from ec8cd7a to 767ab22 on December 1, 2023 14:53
@sabauma sabauma merged commit 0d87e25 into llvm:main Dec 1, 2023
sabauma added a commit that referenced this pull request Dec 2, 2023
The existing lowering of tosa.conv2d emits a separate linalg.generic
operator to add the bias after computing the convolution.

This change eliminates that additional step by using the broadcast bias
value as the output (accumulator) operand of the generated
linalg.conv_2d_* operation.

Rather than:

    %init = tensor.empty()
    %conv = linalg.conv_2d ins(%A, %B) outs(%init)

    %init2 = tensor.empty()
    %result = linalg.generic ins(%conv, %bias) outs(%init2) {
      // perform add operation
    }

The lowering now produces:

    %init = tensor.empty()
    %bias_expanded = linalg.broadcast ins(%bias) outs(%init)
    %conv = linalg.conv_2d ins(%A, %B) outs(%bias_expanded)

This is the same strategy as
#73049 applied to convolutions.
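
For a concrete picture, here is a hypothetical NHWC convolution (shapes and value names made up for illustration, assuming the linalg.conv_2d_nhwc_fhwc form the TOSA lowering uses for NHWC inputs with FHWC weights):

    // Hypothetical shapes: 1x225x225x3 input, 64x3x3x3 (FHWC) weights,
    // stride 2, giving a 1x112x112x64 result seeded with the bias.
    %init = tensor.empty() : tensor<1x112x112x64xf32>
    %bcast = linalg.broadcast ins(%bias : tensor<64xf32>) outs(%init : tensor<1x112x112x64xf32>) dimensions = [0, 1, 2]
    %conv = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%A, %B : tensor<1x225x225x3xf32>, tensor<64x3x3x3xf32>) outs(%bcast : tensor<1x112x112x64xf32>) -> tensor<1x112x112x64xf32>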
Guzhu-AMD pushed a commit to GPUOpen-Drivers/llvm-project that referenced this pull request Dec 7, 2023
…e10e4323c

Local branch amd-gfx a22e10e Revert "Revert "[AMDGPU] Do not assume stack size for PAL code object indirect calls""
Remote branch main 0d87e25 [mlir][tosa] Improve lowering to tosa.fully_connected (llvm#73049)