[mlir][tosa] Improve lowering to tosa.fully_connected #73049
Conversation
@llvm/pr-subscribers-mlir-tosa @llvm/pr-subscribers-mlir

Author: Spenser Bauman (sabauma)

Changes

The current lowering of tosa.fully_connected produces a linalg.matmul followed by a linalg.generic to add the bias. The IR looks like the following:
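```mlir
%init = tensor.empty()
%zero = linalg.fill ins(0 : f32) outs(%init)
%prod = linalg.matmul ins(%A, %B) outs(%zero)

// Add the bias in a separate traversal.
%initB = tensor.empty()
%result = linalg.generic ins(%prod, %bias) outs(%initB)
```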
This has two downsides:

1. The tensor.empty operations typically result in additional allocations after bufferization.
2. There is a redundant traversal of the data to add the bias to the matrix product.
This extra work can be avoided by leveraging the out-param of linalg.matmul. The new IR sequence is:
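```mlir
%init = tensor.empty()
%broadcast = linalg.broadcast ins(%bias) outs(%init)
%prod = linalg.matmul ins(%A, %B) outs(%broadcast)
```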
In my experiments, this eliminates one loop and one allocation (post bufferization) from the generated code.

Full diff: https://github.com/llvm/llvm-project/pull/73049.diff

2 Files Affected:

- mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
- mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
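The updated CHECK lines are exercised through the tosa-to-linalg-named pipeline; a minimal sketch of a RUN line, assuming the pass name registered by the TosaToLinalg conversion (exact flags may differ from the checked-in test header):

```mlir
// RUN: mlir-opt --split-input-file \
// RUN:   -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named))" %s \
// RUN:   | FileCheck %s
```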
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 99a65f63038a43f..b9a7b778ce4017d 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -632,17 +632,6 @@ class FullyConnectedConverter
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(outputTy.getRank()));
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(outputTy.getRank()));
- auto emptyTensor = rewriter.create<tensor::EmptyOp>(
- loc, outputTy.getShape(), outputTy.getElementType(), filteredDims);
-
- // When quantized, the input elemeny type is not the same as the output
- auto resultZeroAttr = rewriter.getZeroAttr(outputETy);
- Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
- Value zeroTensor = rewriter
- .create<linalg::FillOp>(loc, ValueRange{zero},
- ValueRange{emptyTensor})
- .result();
-
SmallVector<int64_t> permutation{1, 0};
auto permutationAttr = rewriter.getI64TensorAttr(permutation);
Value permutationValue =
@@ -658,26 +647,18 @@ class FullyConnectedConverter
Value biasEmptyTensor = rewriter.create<tensor::EmptyOp>(
loc, outputTy.getShape(), outputETy, filteredDims);
+ auto broadcastDims = DenseI64ArrayAttr::get(getContext(), {0});
+ Value biasInitTensor = rewriter.create<linalg::BroadcastOp>(
+ loc, bias, biasEmptyTensor, broadcastDims)->getResult(0);
+
if (!op.getQuantizationInfo()) {
Value matmul = rewriter
.create<linalg::MatmulOp>(
loc, TypeRange{op.getType()},
- ValueRange{input, transposedWeight}, zeroTensor)
+ ValueRange{input, transposedWeight}, biasInitTensor)
->getResult(0);
- Value result =
- rewriter
- .create<linalg::GenericOp>(
- loc, outputTy, ValueRange({bias, matmul}), biasEmptyTensor,
- indexingMaps, getNParallelLoopsAttrs(outputTy.getRank()),
- [&](OpBuilder &nestedBuilder, Location nestedLoc,
- ValueRange args) {
- Value added = nestedBuilder.create<arith::AddFOp>(
- loc, args[0], args[1]);
- nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
- })
- .getResult(0);
- rewriter.replaceOp(op, result);
+ rewriter.replaceOp(op, matmul);
return success();
}
@@ -691,11 +672,10 @@ class FullyConnectedConverter
.create<linalg::QuantizedMatmulOp>(
loc, TypeRange{op.getType()},
ValueRange{input, transposedWeight, inputZp, outputZp},
- zeroTensor)
+ biasInitTensor)
->getResult(0);
- Value result = linalgIntBroadcastExtSIAdd(rewriter, loc, bias, matmul,
- biasEmptyTensor, indexingMaps);
- rewriter.replaceOp(op, result);
+
+ rewriter.replaceOp(op, matmul);
return success();
}
};
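The linalg.broadcast op used here replicates its input along the listed output dimensions (the output dimensions not present in the input). A minimal sketch of the bias expansion the new pattern emits, with illustrative names and static shapes:

```mlir
// Replicate the rank-1 bias across the rows (output dimension 0)
// of the rank-2 matmul result.
%empty = tensor.empty() : tensor<5x6xf32>
%bcast = linalg.broadcast
    ins(%bias : tensor<6xf32>)
    outs(%empty : tensor<5x6xf32>)
    dimensions = [0]
// The matmul accumulates directly on top of the broadcast bias,
// so no separate bias-add loop is needed.
%prod = linalg.matmul
    ins(%lhs, %rhs : tensor<5x3xf32>, tensor<3x6xf32>)
    outs(%bcast : tensor<5x6xf32>) -> tensor<5x6xf32>
```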
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 1cf7c8dee606899..3b6d574b73b1ab6 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -68,22 +68,13 @@ func.func @matmul_dyn_independent_dim(%arg0: tensor<1x5x?xf32>, %arg1: tensor<1x
// -----
-// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
-// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-
// CHECK-LABEL: @fully_connected
func.func @fully_connected(%arg0: tensor<5x3xf32>, %arg1: tensor<6x3xf32>, %arg2: tensor<6xf32>) -> (tensor<5x6xf32>) {
- // CHECK: [[INITT:%.+]] = tensor.empty()
- // CHECK: [[ZERO:%.+]] = arith.constant 0
- // CHECK: [[FILL:%.+]] = linalg.fill ins([[ZERO]]{{.*}}outs([[INITT]]
- // CHECK: [[PERM:%.+]] = arith.constant dense<[1, 0]>
- // CHECK: [[TRANSPOSE:%.+]] = tosa.transpose %arg1, [[PERM]]
- // CHECK: [[INITB:%.+]] = tensor.empty()
- // CHECK: [[MATMUL:%.+]] = linalg.matmul ins(%arg0, [[TRANSPOSE]] : tensor<5x3xf32>, tensor<3x6xf32>) outs([[FILL]] : tensor<5x6xf32>) -> tensor<5x6xf32>
- // CHECK: [[ADDED:%.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%arg2, [[MATMUL]] : tensor<6xf32>, tensor<5x6xf32>) outs([[INITB]] : tensor<5x6xf32>) {
- // CHECK: ^bb0(%[[ARG3:[0-9a-zA-Z_]+]]: f32, %[[ARG4:[0-9a-zA-Z_]+]]: f32, %[[ARG5:[0-9a-zA-Z_]+]]: f32):
- // CHECK: [[ADD:%.+]] = arith.addf %[[ARG3]], %[[ARG4]] : f32
- // CHECK: linalg.yield [[ADD]] : f32
+  // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 0]> : tensor<2xi64>
+  // CHECK: %[[TRANSPOSED:.+]] = tosa.transpose %arg1, %[[PERM]] : (tensor<6x3xf32>, tensor<2xi64>) -> tensor<3x6xf32>
+  // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<5x6xf32>
+  // CHECK: %[[BROADCASTED:.+]] = linalg.broadcast ins(%arg2 : tensor<6xf32>) outs(%[[INIT]] : tensor<5x6xf32>) dimensions = [0]
+  // CHECK: linalg.matmul ins(%arg0, %[[TRANSPOSED]] : tensor<5x3xf32>, tensor<3x6xf32>) outs(%[[BROADCASTED]] : tensor<5x6xf32>) -> tensor<5x6xf32>
%0 = tosa.fully_connected %arg0, %arg1, %arg2 : (tensor<5x3xf32>, tensor<6x3xf32>, tensor<6xf32>) -> tensor<5x6xf32>
return %0 : tensor<5x6xf32>
@@ -91,48 +82,31 @@ func.func @fully_connected(%arg0: tensor<5x3xf32>, %arg1: tensor<6x3xf32>, %arg2
// -----
-// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
-// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-
// CHECK-LABEL: @quantized_fully_connected
func.func @quantized_fully_connected(%arg0: tensor<5x3xi8>, %arg1: tensor<6x3xi8>, %arg2: tensor<6xi32>) -> (tensor<5x6xi32>) {
- // CHECK: [[INITT:%.+]] = tensor.empty()
- // CHECK: [[ZERO:%.+]] = arith.constant 0
- // CHECK: [[FILL:%.+]] = linalg.fill ins([[ZERO]]{{.*}}outs([[INITT]]
- // CHECK: [[PERM:%.+]] = arith.constant dense<[1, 0]>
- // CHECK: [[TRANSPOSE:%.+]] = tosa.transpose %arg1, [[PERM]]
- // CHECK: [[INITB:%.+]] = tensor.empty()
- // CHECK: [[ONE:%.+]] = arith.constant 1
- // CHECK: [[TWO:%.+]] = arith.constant 2
- // CHECK: [[MATMUL:%.+]] = linalg.quantized_matmul ins(%arg0, [[TRANSPOSE]], [[ONE]], [[TWO]] : tensor<5x3xi8>, tensor<3x6xi8>, i32, i32) outs([[FILL]] : tensor<5x6xi32>) -> tensor<5x6xi32>
- // CHECK: [[ADDED:%.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%arg2, [[MATMUL]] : tensor<6xi32>, tensor<5x6xi32>) outs([[INITB]]
- // CHECK: ^bb0([[IN1:%.+]]: i32, [[IN2:%.+]]: i32, [[UNUSED:%.+]]: i32):
- // CHECK: [[ADD:%.+]] = arith.addi
- // CHECK: linalg.yield [[ADD]] : i32
+  // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 0]> : tensor<2xi64>
+  // CHECK: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[PERM]] : (tensor<6x3xi8>, tensor<2xi64>) -> tensor<3x6xi8>
+  // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<5x6xi32>
+  // CHECK: %[[BROADCASTED:.+]] = linalg.broadcast ins(%arg2 : tensor<6xi32>) outs(%[[INIT]] : tensor<5x6xi32>) dimensions = [0]
+  // CHECK: %[[C1:.+]] = arith.constant 1 : i32
+  // CHECK: %[[C2:.+]] = arith.constant 2 : i32
+  // CHECK: linalg.quantized_matmul ins(%arg0, %[[TRANSPOSE]], %[[C1]], %[[C2]] : tensor<5x3xi8>, tensor<3x6xi8>, i32, i32) outs(%[[BROADCASTED]] : tensor<5x6xi32>) -> tensor<5x6xi32>
+
%0 = tosa.fully_connected %arg0, %arg1, %arg2 {quantization_info = #tosa.conv_quant<input_zp = 1, weight_zp = 2>} : (tensor<5x3xi8>, tensor<6x3xi8>, tensor<6xi32>) -> tensor<5x6xi32>
return %0 : tensor<5x6xi32>
}
// -----
-// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
-// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
-
// CHECK-LABEL: @fully_connected_dyn
func.func @fully_connected_dyn(%arg0: tensor<?x3xf32>, %arg1: tensor<6x3xf32>, %arg2: tensor<6xf32>) -> (tensor<?x6xf32>) {
- // CHECK: %[[C0:.+]] = arith.constant 0
- // CHECK: %[[DIM:.+]] = tensor.dim %arg0, %[[C0]]
- // CHECK: %[[INITT:.+]] = tensor.empty(%[[DIM]])
- // CHECK: %[[ZERO:.+]] = arith.constant 0
- // CHECK: %[[FILL:.+]] = linalg.fill ins(%[[ZERO]]{{.*}}outs(%[[INITT]]
- // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 0]>
- // CHECK: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[PERM]]
- // CHECK: %[[INITB:.+]] = tensor.empty(%[[DIM]])
- // CHECK: %[[MATMUL:.+]] = linalg.matmul ins(%arg0, %[[TRANSPOSE]] : tensor<?x3xf32>, tensor<3x6xf32>) outs(%[[FILL]] : tensor<?x6xf32>) -> tensor<?x6xf32>
- // CHECK: %[[ADDED:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%arg2, %[[MATMUL]] : tensor<6xf32>, tensor<?x6xf32>) outs(%[[INITB]] : tensor<?x6xf32>) {
- // CHECK: ^bb0(%[[ARG3:[0-9a-zA-Z_]+]]: f32, %[[ARG4:[0-9a-zA-Z_]+]]: f32, %[[ARG5:[0-9a-zA-Z_]+]]: f32):
- // CHECK: %[[ADD:.+]] = arith.addf %[[ARG3]], %[[ARG4]] : f32
- // CHECK: linalg.yield %[[ADD]] : f32
+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
+ // CHECK: %[[DIM0:.+]] = tensor.dim %arg0, %[[C0]] : tensor<?x3xf32>
+ // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 0]> : tensor<2xi64>
+ // CHECK: %[[TRANSPOSED:.+]] = tosa.transpose %arg1, %[[PERM]] : (tensor<6x3xf32>, tensor<2xi64>) -> tensor<3x6xf32>
+ // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM0]]) : tensor<?x6xf32>
+ // CHECK: %[[BROADCASTED:.+]] = linalg.broadcast ins(%arg2 : tensor<6xf32>) outs(%[[INIT]] : tensor<?x6xf32>) dimensions = [0]
+ // CHECK: linalg.matmul ins(%arg0, %[[TRANSPOSED]] : tensor<?x3xf32>, tensor<3x6xf32>) outs(%[[BROADCASTED]] : tensor<?x6xf32>) -> tensor<?x6xf32>
%0 = tosa.fully_connected %arg0, %arg1, %arg2 : (tensor<?x3xf32>, tensor<6x3xf32>, tensor<6xf32>) -> tensor<?x6xf32>
return %0 : tensor<?x6xf32>
✅ With the latest revision this PR passed the C/C++ code formatter.
The existing lowering of tosa.conv2d emits a separate linalg.generic operator to add the bias after computing the convolution. This change eliminates that additional step by using the bias value as the output operand of the generated linalg.conv_2d_* operation. Rather than:

```mlir
%init = tensor.empty()
%conv = linalg.conv_2d ins(%A, %B) outs(%init)
%init2 = tensor.empty()
%result = linalg.generic ins(%conv, %bias) outs(%init2) {
  // perform add operation
}
```

the lowering now produces:

```mlir
%init = tensor.empty()
%bias_expanded = linalg.broadcast ins(%bias) outs(%init)
%conv = linalg.conv_2d ins(%A, %B) outs(%bias_expanded)
```

This is the same strategy as #73049 applied to convolutions.
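For a concrete picture, a hypothetical NHWC convolution would broadcast the per-output-channel bias across the batch and spatial dimensions before accumulating into it; all names and types below are illustrative, not taken from the patch:

```mlir
// Expand the per-output-channel bias across N, H and W
// (output dimensions 0, 1 and 2).
%empty = tensor.empty() : tensor<1x112x112x64xf32>
%bias4d = linalg.broadcast
    ins(%bias : tensor<64xf32>)
    outs(%empty : tensor<1x112x112x64xf32>)
    dimensions = [0, 1, 2]
// The convolution accumulates directly into the broadcast bias.
%conv = linalg.conv_2d_nhwc_hwcf
    {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>}
    ins(%input, %filter : tensor<1x225x225x3xf32>, tensor<3x3x3x64xf32>)
    outs(%bias4d : tensor<1x112x112x64xf32>) -> tensor<1x112x112x64xf32>
```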