Skip to content

Commit 0d87e25

Browse files
authored
[mlir][tosa] Improve lowering of tosa.fully_connected (#73049)
The current lowering of tosa.fully_connected produces a linalg.matmul followed by a linalg.generic to add the bias. The IR looks like the following:

    %init = tensor.empty()
    %zero = linalg.fill ins(0 : f32) outs(%init)
    %prod = linalg.matmul ins(%A, %B) outs(%zero)
    // Add the bias
    %initB = tensor.empty()
    %result = linalg.generic ins(%prod, %bias) outs(%initB) {
      // add bias and product
    }

This has two downsides:
1. The tensor.empty operations typically result in additional allocations after bufferization.
2. There is a redundant traversal of the data to add the bias to the matrix product.

This extra work can be avoided by leveraging the out-param of linalg.matmul. The new IR sequence is:

    %init = tensor.empty()
    %broadcast = linalg.broadcast ins(%bias) outs(%init)
    %prod = linalg.matmul ins(%A, %B) outs(%broadcast)

In my experiments, this eliminates one loop and one allocation (post bufferization) from the generated code.
1 parent faebb1b commit 0d87e25

File tree

3 files changed

+128
-83
lines changed

3 files changed

+128
-83
lines changed

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp

Lines changed: 51 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,49 @@ linalgIntBroadcastExtSIAdd(PatternRewriter &rewriter, Location loc, Value bias,
8585
.getResult(0);
8686
}
8787

88+
// Broadcast `source` into the shape of `result`: the source dimensions are aligned with the trailing result dimensions and repeated along the remaining outer dimensions.
89+
// If the source and result element types differ, each element is widened with an arith.extsi operation. Returns the result of the generated linalg.generic.
// NOTE(review): any element-type mismatch is treated as a signed integer extension; presumably callers only mix integer types here -- confirm.
90+
static mlir::Value linalgBroadcastAndMaybeExtSI(PatternRewriter &rewriter,
91+
Location loc, Value source,
92+
Value result) {
93+
ShapedType resultTy = cast<ShapedType>(result.getType());
94+
ShapedType sourceTy = cast<ShapedType>(source.getType());
95+
int64_t resultRank = resultTy.getRank();
96+
int64_t sourceRank = sourceTy.getRank();
97+
98+
// Source dimension `dim` reads result dimension `dim + resultRank - sourceRank`,
99+
// i.e. the source lines up with the trailing dimensions of the result tensor.
100+
SmallVector<AffineExpr> sourceDims;
101+
for (auto dim : llvm::seq<int64_t>(0, sourceRank)) {
102+
auto expr = rewriter.getAffineDimExpr(dim + resultRank - sourceRank);
103+
sourceDims.push_back(expr);
104+
}
105+
106+
// Create the indexing maps for the input and output of the broadcast-like generic op.
107+
SmallVector<AffineMap, 2> indexingMaps = {
108+
// Input indexing map: the source is indexed by the trailing result dimensions only.
109+
AffineMap::get(/*dimCount=*/resultRank,
110+
/*symbolCount=*/0, sourceDims, rewriter.getContext()),
111+
112+
// Output indexing map: identity over all result dimensions.
113+
rewriter.getMultiDimIdentityMap(resultRank)};
114+
115+
// Build the broadcast-like operation as a linalg.generic that yields the (possibly extended) source element.
116+
return rewriter
117+
.create<linalg::GenericOp>(
118+
loc, resultTy, ValueRange({source}), result, indexingMaps,
119+
getNParallelLoopsAttrs(resultTy.getRank()),
120+
[](OpBuilder &builder, Location loc, ValueRange args) {
121+
Value biasVal = args[0];
122+
Type resType = args[1].getType();
123+
if (resType != biasVal.getType()) {
124+
biasVal = builder.create<arith::ExtSIOp>(loc, resType, biasVal);
125+
}
126+
builder.create<linalg::YieldOp>(loc, biasVal);
127+
})
128+
.getResult(0);
129+
}
130+
88131
static mlir::Value reifyConstantDim(int64_t attr,
89132
ImplicitLocOpBuilder &builder) {
90133
return builder.createOrFold<arith::IndexCastOp>(
@@ -618,28 +661,6 @@ class FullyConnectedConverter
618661

619662
SmallVector<Value> filteredDims = condenseValues(dynDims);
620663

621-
// Creating maps for the output of MatMul and the bias
622-
SmallVector<AffineMap, 4> indexingMaps;
623-
624-
// Broadcast the bias.
625-
indexingMaps.push_back(AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0,
626-
{rewriter.getAffineDimExpr(1)},
627-
rewriter.getContext()));
628-
629-
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(outputTy.getRank()));
630-
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(outputTy.getRank()));
631-
632-
auto emptyTensor = rewriter.create<tensor::EmptyOp>(
633-
loc, outputTy.getShape(), outputTy.getElementType(), filteredDims);
634-
635-
// When quantized, the input element type is not the same as the output
636-
auto resultZeroAttr = rewriter.getZeroAttr(outputETy);
637-
Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
638-
Value zeroTensor = rewriter
639-
.create<linalg::FillOp>(loc, ValueRange{zero},
640-
ValueRange{emptyTensor})
641-
.result();
642-
643664
SmallVector<int64_t> permutation{1, 0};
644665
auto permutationAttr = rewriter.getI64TensorAttr(permutation);
645666
Value permutationValue =
@@ -655,26 +676,17 @@ class FullyConnectedConverter
655676
Value biasEmptyTensor = rewriter.create<tensor::EmptyOp>(
656677
loc, outputTy.getShape(), outputETy, filteredDims);
657678

679+
Value broadcastBias =
680+
linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor);
681+
658682
if (!op.getQuantizationInfo()) {
659683
Value matmul = rewriter
660684
.create<linalg::MatmulOp>(
661685
loc, TypeRange{op.getType()},
662-
ValueRange{input, transposedWeight}, zeroTensor)
686+
ValueRange{input, transposedWeight}, broadcastBias)
663687
->getResult(0);
664688

665-
Value result =
666-
rewriter
667-
.create<linalg::GenericOp>(
668-
loc, outputTy, ValueRange({bias, matmul}), biasEmptyTensor,
669-
indexingMaps, getNParallelLoopsAttrs(outputTy.getRank()),
670-
[&](OpBuilder &nestedBuilder, Location nestedLoc,
671-
ValueRange args) {
672-
Value added = nestedBuilder.create<arith::AddFOp>(
673-
loc, args[0], args[1]);
674-
nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
675-
})
676-
.getResult(0);
677-
rewriter.replaceOp(op, result);
689+
rewriter.replaceOp(op, matmul);
678690
return success();
679691
}
680692

@@ -688,11 +700,10 @@ class FullyConnectedConverter
688700
.create<linalg::QuantizedMatmulOp>(
689701
loc, TypeRange{op.getType()},
690702
ValueRange{input, transposedWeight, inputZp, outputZp},
691-
zeroTensor)
703+
broadcastBias)
692704
->getResult(0);
693-
Value result = linalgIntBroadcastExtSIAdd(rewriter, loc, bias, matmul,
694-
biasEmptyTensor, indexingMaps);
695-
rewriter.replaceOp(op, result);
705+
706+
rewriter.replaceOp(op, matmul);
696707
return success();
697708
}
698709
};

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir

Lines changed: 41 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -82,71 +82,69 @@ func.func @matmul_dyn_output(%arg0: tensor<1x1x8xf32>, %arg1: tensor<1x8x1xf32>)
8282

8383
// -----
8484

85-
// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
86-
// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
85+
// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1) -> (d1)>
86+
// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
8787

8888
// CHECK-LABEL: @fully_connected
8989
func.func @fully_connected(%arg0: tensor<5x3xf32>, %arg1: tensor<6x3xf32>, %arg2: tensor<6xf32>) -> (tensor<5x6xf32>) {
90-
// CHECK: [[INITT:%.+]] = tensor.empty()
91-
// CHECK: [[ZERO:%.+]] = arith.constant 0
92-
// CHECK: [[FILL:%.+]] = linalg.fill ins([[ZERO]]{{.*}}outs([[INITT]]
93-
// CHECK: [[PERM:%.+]] = arith.constant dense<[1, 0]>
94-
// CHECK: [[TRANSPOSE:%.+]] = tosa.transpose %arg1, [[PERM]]
95-
// CHECK: [[INITB:%.+]] = tensor.empty()
96-
// CHECK: [[MATMUL:%.+]] = linalg.matmul ins(%arg0, [[TRANSPOSE]] : tensor<5x3xf32>, tensor<3x6xf32>) outs([[FILL]] : tensor<5x6xf32>) -> tensor<5x6xf32>
97-
// CHECK: [[ADDED:%.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%arg2, [[MATMUL]] : tensor<6xf32>, tensor<5x6xf32>) outs([[INITB]] : tensor<5x6xf32>) {
98-
// CHECK: ^bb0(%[[ARG3:[0-9a-zA-Z_]+]]: f32, %[[ARG4:[0-9a-zA-Z_]+]]: f32, %[[ARG5:[0-9a-zA-Z_]+]]: f32):
99-
// CHECK: [[ADD:%.+]] = arith.addf %[[ARG3]], %[[ARG4]] : f32
100-
// CHECK: linalg.yield [[ADD]] : f32
90+
// CHECK: %[[PERM:.+]] = arith.constant dense<[1, 0]> : tensor<2xi64>
91+
// CHECK: %[[TRANSPOSED:.+]] = tosa.transpose %arg1, %[[PERM]] : (tensor<6x3xf32>, tensor<2xi64>) -> tensor<3x6xf32>
92+
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<5x6xf32>
93+
94+
// CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg2 : tensor<6xf32>) outs(%[[INIT]] : tensor<5x6xf32>) {
95+
// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
96+
// CHECK: linalg.yield %[[IN]] : f32
97+
// CHECK: } -> tensor<5x6xf32>
98+
99+
// CHECK: linalg.matmul ins(%arg0, %[[TRANSPOSED]] : tensor<5x3xf32>, tensor<3x6xf32>) outs(%[[BROADCAST]] : tensor<5x6xf32>) -> tensor<5x6xf32>
101100

102101
%0 = tosa.fully_connected %arg0, %arg1, %arg2 : (tensor<5x3xf32>, tensor<6x3xf32>, tensor<6xf32>) -> tensor<5x6xf32>
103102
return %0 : tensor<5x6xf32>
104103
}
105104

106105
// -----
107106

108-
// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
109-
// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
107+
// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1) -> (d1)>
108+
// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
110109

111110
// CHECK-LABEL: @quantized_fully_connected
112111
func.func @quantized_fully_connected(%arg0: tensor<5x3xi8>, %arg1: tensor<6x3xi8>, %arg2: tensor<6xi32>) -> (tensor<5x6xi32>) {
113-
// CHECK: [[INITT:%.+]] = tensor.empty()
114-
// CHECK: [[ZERO:%.+]] = arith.constant 0
115-
// CHECK: [[FILL:%.+]] = linalg.fill ins([[ZERO]]{{.*}}outs([[INITT]]
116-
// CHECK: [[PERM:%.+]] = arith.constant dense<[1, 0]>
117-
// CHECK: [[TRANSPOSE:%.+]] = tosa.transpose %arg1, [[PERM]]
118-
// CHECK: [[INITB:%.+]] = tensor.empty()
119-
// CHECK: [[ONE:%.+]] = arith.constant 1
120-
// CHECK: [[TWO:%.+]] = arith.constant 2
121-
// CHECK: [[MATMUL:%.+]] = linalg.quantized_matmul ins(%arg0, [[TRANSPOSE]], [[ONE]], [[TWO]] : tensor<5x3xi8>, tensor<3x6xi8>, i32, i32) outs([[FILL]] : tensor<5x6xi32>) -> tensor<5x6xi32>
122-
// CHECK: [[ADDED:%.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%arg2, [[MATMUL]] : tensor<6xi32>, tensor<5x6xi32>) outs([[INITB]]
123-
// CHECK: ^bb0([[IN1:%.+]]: i32, [[IN2:%.+]]: i32, [[UNUSED:%.+]]: i32):
124-
// CHECK: [[ADD:%.+]] = arith.addi
125-
// CHECK: linalg.yield [[ADD]] : i32
112+
// CHECK: %[[PERM:.+]] = arith.constant dense<[1, 0]> : tensor<2xi64>
113+
// CHECK: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[PERM]] : (tensor<6x3xi8>, tensor<2xi64>) -> tensor<3x6xi8>
114+
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<5x6xi32>
115+
116+
// CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg2 : tensor<6xi32>) outs(%[[INIT]] : tensor<5x6xi32>) {
117+
// CHECK: ^bb0(%[[IN:.+]]: i32, %[[OUT:.+]]: i32):
118+
// CHECK: linalg.yield %[[IN]] : i32
119+
// CHECK: } -> tensor<5x6xi32>
120+
121+
// CHECK: %[[C1:.+]] = arith.constant 1 : i32
122+
// CHECK: %[[C2:.+]] = arith.constant 2 : i32
123+
// CHECK: linalg.quantized_matmul ins(%arg0, %[[TRANSPOSE]], %[[C1]], %[[C2]] : tensor<5x3xi8>, tensor<3x6xi8>, i32, i32) outs(%[[BROADCAST]] : tensor<5x6xi32>) -> tensor<5x6xi32>
124+
126125
%0 = tosa.fully_connected %arg0, %arg1, %arg2 {quantization_info = #tosa.conv_quant<input_zp = 1, weight_zp = 2>} : (tensor<5x3xi8>, tensor<6x3xi8>, tensor<6xi32>) -> tensor<5x6xi32>
127126
return %0 : tensor<5x6xi32>
128127
}
129128

130129
// -----
131130

132-
// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
133-
// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
131+
// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1) -> (d1)>
132+
// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
134133

135134
// CHECK-LABEL: @fully_connected_dyn
136135
func.func @fully_connected_dyn(%arg0: tensor<?x3xf32>, %arg1: tensor<6x3xf32>, %arg2: tensor<6xf32>) -> (tensor<?x6xf32>) {
137-
// CHECK: %[[C0:.+]] = arith.constant 0
138-
// CHECK: %[[DIM:.+]] = tensor.dim %arg0, %[[C0]]
139-
// CHECK: %[[INITT:.+]] = tensor.empty(%[[DIM]])
140-
// CHECK: %[[ZERO:.+]] = arith.constant 0
141-
// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[ZERO]]{{.*}}outs(%[[INITT]]
142-
// CHECK: %[[PERM:.+]] = arith.constant dense<[1, 0]>
143-
// CHECK: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[PERM]]
144-
// CHECK: %[[INITB:.+]] = tensor.empty(%[[DIM]])
145-
// CHECK: %[[MATMUL:.+]] = linalg.matmul ins(%arg0, %[[TRANSPOSE]] : tensor<?x3xf32>, tensor<3x6xf32>) outs(%[[FILL]] : tensor<?x6xf32>) -> tensor<?x6xf32>
146-
// CHECK: %[[ADDED:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%arg2, %[[MATMUL]] : tensor<6xf32>, tensor<?x6xf32>) outs(%[[INITB]] : tensor<?x6xf32>) {
147-
// CHECK: ^bb0(%[[ARG3:[0-9a-zA-Z_]+]]: f32, %[[ARG4:[0-9a-zA-Z_]+]]: f32, %[[ARG5:[0-9a-zA-Z_]+]]: f32):
148-
// CHECK: %[[ADD:.+]] = arith.addf %[[ARG3]], %[[ARG4]] : f32
149-
// CHECK: linalg.yield %[[ADD]] : f32
136+
// CHECK: %[[C0:.+]] = arith.constant 0 : index
137+
// CHECK: %[[DIM0:.+]] = tensor.dim %arg0, %c0 : tensor<?x3xf32>
138+
// CHECK: %[[PERM:.+]] = arith.constant dense<[1, 0]> : tensor<2xi64>
139+
// CHECK: %[[TRANSPOSED:.+]] = tosa.transpose %arg1, %[[PERM]] : (tensor<6x3xf32>, tensor<2xi64>) -> tensor<3x6xf32>
140+
// CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM0]]) : tensor<?x6xf32>
141+
142+
// CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg2 : tensor<6xf32>) outs(%[[INIT]] : tensor<?x6xf32>) {
143+
// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
144+
// CHECK: linalg.yield %[[IN]] : f32
145+
// CHECK: } -> tensor<?x6xf32>
146+
147+
// CHECK: linalg.matmul ins(%arg0, %[[TRANSPOSED]] : tensor<?x3xf32>, tensor<3x6xf32>) outs(%[[BROADCAST]] : tensor<?x6xf32>) -> tensor<?x6xf32>
150148

151149
%0 = tosa.fully_connected %arg0, %arg1, %arg2 : (tensor<?x3xf32>, tensor<6x3xf32>, tensor<6xf32>) -> tensor<?x6xf32>
152150
return %0 : tensor<?x6xf32>
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named,tosa-to-linalg,tosa-to-arith))" | \
2+
// RUN: mlir-opt -one-shot-bufferize -func-bufferize -test-lower-to-llvm | \
3+
// RUN: mlir-cpu-runner -O3 -e main -entry-point-result=void \
4+
// RUN: -shared-libs=%mlir_runner_utils \
5+
// RUN: | FileCheck %s
6+
7+
func.func private @printMemrefF32(tensor<*xf32>)
8+
9+
func.func @main() {
10+
%A = arith.constant dense<[
11+
[8.0, 1.0, 6.0],
12+
[3.0, 5.0, 7.0],
13+
[4.0, 9.0, 2.0]
14+
]> : tensor<3x3xf32>
15+
16+
%B = arith.constant dense<[
17+
[1.0, 1.0, 1.0],
18+
[1.0, 1.0, 1.0],
19+
[1.0, 1.0, 1.0]
20+
]> : tensor<3x3xf32>
21+
22+
%C = arith.constant dense<[0.0, 1.0, 2.0]> : tensor<3xf32>
23+
24+
%result = tosa.fully_connected %A, %B, %C : (tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>) -> tensor<3x3xf32>
25+
26+
%result_unranked = tensor.cast %result : tensor<3x3xf32> to tensor<*xf32>
27+
call @printMemrefF32(%result_unranked) : (tensor<*xf32>) -> ()
28+
return
29+
}
30+
31+
// CHECK: Unranked Memref base@ = {{.*}} rank = 2 offset = 0 sizes = [3, 3] strides = [3, 1] data =
32+
// CHECK-NEXT: [
33+
// CHECK-SAME: [15, 16, 17]
34+
// CHECK-NEXT: [15, 16, 17]
35+
// CHECK-NEXT: [15, 16, 17]
36+
// CHECK-SAME: ]

0 commit comments

Comments
 (0)