
[mlir][tosa] Improve lowering to tosa.fully_connected #73049

Merged — 1 commit, Dec 1, 2023

91 changes: 51 additions & 40 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -85,6 +85,49 @@ linalgIntBroadcastExtSIAdd(PatternRewriter &rewriter, Location loc, Value bias,
.getResult(0);
}

// Broadcast the source value to all the outer dimensions of the result value.
// If required, the element type is expanded using an arith.extsi operation.
static mlir::Value linalgBroadcastAndMaybeExtSI(PatternRewriter &rewriter,
                                                Location loc, Value source,
                                                Value result) {
  ShapedType resultTy = cast<ShapedType>(result.getType());
  ShapedType sourceTy = cast<ShapedType>(source.getType());
  int64_t resultRank = resultTy.getRank();
  int64_t sourceRank = sourceTy.getRank();

  // The source tensor is broadcast to all the outer dimensions of the
  // result tensor.
  SmallVector<AffineExpr> sourceDims;
  for (auto dim : llvm::seq<int64_t>(0, sourceRank)) {
    auto expr = rewriter.getAffineDimExpr(dim + resultRank - sourceRank);
    sourceDims.push_back(expr);
  }

  // Creating maps for the input and output of the broadcast-like generic op.
  SmallVector<AffineMap, 2> indexingMaps = {
      // Broadcast the last dimension of the bias to all output dimensions.
      AffineMap::get(/*dimCount=*/resultRank,
                     /*symbolCount=*/0, sourceDims, rewriter.getContext()),

      // Output indexing map.
      rewriter.getMultiDimIdentityMap(resultRank)};

  // Build the broadcast-like operation as a linalg.generic.
  return rewriter
      .create<linalg::GenericOp>(
          loc, resultTy, ValueRange({source}), result, indexingMaps,
          getNParallelLoopsAttrs(resultTy.getRank()),
          [](OpBuilder &builder, Location loc, ValueRange args) {
            Value biasVal = args[0];
            Type resType = args[1].getType();
            if (resType != biasVal.getType()) {
              biasVal = builder.create<arith::ExtSIOp>(loc, resType, biasVal);
            }
            builder.create<linalg::YieldOp>(loc, biasVal);
          })
      .getResult(0);
}

static mlir::Value reifyConstantDim(int64_t attr,
ImplicitLocOpBuilder &builder) {
return builder.createOrFold<arith::IndexCastOp>(
@@ -618,28 +661,6 @@ class FullyConnectedConverter

SmallVector<Value> filteredDims = condenseValues(dynDims);

// Creating maps for the output of MatMul and the bias
SmallVector<AffineMap, 4> indexingMaps;

// Broadcast the bias.
indexingMaps.push_back(AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0,
{rewriter.getAffineDimExpr(1)},
rewriter.getContext()));

indexingMaps.push_back(rewriter.getMultiDimIdentityMap(outputTy.getRank()));
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(outputTy.getRank()));

auto emptyTensor = rewriter.create<tensor::EmptyOp>(
loc, outputTy.getShape(), outputTy.getElementType(), filteredDims);

// When quantized, the input element type is not the same as the output
auto resultZeroAttr = rewriter.getZeroAttr(outputETy);
Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
Value zeroTensor = rewriter
.create<linalg::FillOp>(loc, ValueRange{zero},
ValueRange{emptyTensor})
.result();

SmallVector<int64_t> permutation{1, 0};
auto permutationAttr = rewriter.getI64TensorAttr(permutation);
Value permutationValue =
@@ -655,26 +676,17 @@ class FullyConnectedConverter
Value biasEmptyTensor = rewriter.create<tensor::EmptyOp>(
loc, outputTy.getShape(), outputETy, filteredDims);

Value broadcastBias =
linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor);

if (!op.getQuantizationInfo()) {
Value matmul = rewriter
.create<linalg::MatmulOp>(
loc, TypeRange{op.getType()},
ValueRange{input, transposedWeight}, zeroTensor)
ValueRange{input, transposedWeight}, broadcastBias)
->getResult(0);

Value result =
rewriter
.create<linalg::GenericOp>(
loc, outputTy, ValueRange({bias, matmul}), biasEmptyTensor,
indexingMaps, getNParallelLoopsAttrs(outputTy.getRank()),
[&](OpBuilder &nestedBuilder, Location nestedLoc,
ValueRange args) {
Value added = nestedBuilder.create<arith::AddFOp>(
loc, args[0], args[1]);
nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
})
.getResult(0);
rewriter.replaceOp(op, result);
rewriter.replaceOp(op, matmul);
return success();
}

Expand All @@ -688,11 +700,10 @@ class FullyConnectedConverter
.create<linalg::QuantizedMatmulOp>(
loc, TypeRange{op.getType()},
ValueRange{input, transposedWeight, inputZp, outputZp},
zeroTensor)
broadcastBias)
->getResult(0);
Value result = linalgIntBroadcastExtSIAdd(rewriter, loc, bias, matmul,
biasEmptyTensor, indexingMaps);
rewriter.replaceOp(op, result);

rewriter.replaceOp(op, matmul);
return success();
}
};
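
For context before the test changes below, here is a minimal MLIR sketch of the IR shape this patch aims for on a static f32 fully_connected (shapes borrowed from the @fully_connected test below; the function and value names are illustrative, and this is not a verbatim dump from the pass). Instead of zero-filling the accumulator, running linalg.matmul, and adding the bias in a trailing linalg.generic, the bias is broadcast into the init tensor and the matmul accumulates directly on top of it:

#bias_map = affine_map<(d0, d1) -> (d1)>
#out_map = affine_map<(d0, d1) -> (d0, d1)>

func.func @fully_connected_sketch(%input: tensor<5x3xf32>, %weight_t: tensor<3x6xf32>,
                                  %bias: tensor<6xf32>) -> tensor<5x6xf32> {
  // Broadcast the 1-D bias across the outer (row) dimension of the accumulator.
  %init = tensor.empty() : tensor<5x6xf32>
  %bcast = linalg.generic {indexing_maps = [#bias_map, #out_map],
                           iterator_types = ["parallel", "parallel"]}
      ins(%bias : tensor<6xf32>) outs(%init : tensor<5x6xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32
  } -> tensor<5x6xf32>
  // The matmul accumulates on top of the broadcast bias, so the old zero-fill
  // and the trailing bias-add linalg.generic are no longer needed.
  %result = linalg.matmul ins(%input, %weight_t : tensor<5x3xf32>, tensor<3x6xf32>)
      outs(%bcast : tensor<5x6xf32>) -> tensor<5x6xf32>
  return %result : tensor<5x6xf32>
}
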
84 changes: 41 additions & 43 deletions mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -82,71 +82,69 @@ func.func @matmul_dyn_output(%arg0: tensor<1x1x8xf32>, %arg1: tensor<1x8x1xf32>)

// -----

// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1) -> (d1)>
// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>

// CHECK-LABEL: @fully_connected
func.func @fully_connected(%arg0: tensor<5x3xf32>, %arg1: tensor<6x3xf32>, %arg2: tensor<6xf32>) -> (tensor<5x6xf32>) {
// CHECK: [[INITT:%.+]] = tensor.empty()
// CHECK: [[ZERO:%.+]] = arith.constant 0
// CHECK: [[FILL:%.+]] = linalg.fill ins([[ZERO]]{{.*}}outs([[INITT]]
// CHECK: [[PERM:%.+]] = arith.constant dense<[1, 0]>
// CHECK: [[TRANSPOSE:%.+]] = tosa.transpose %arg1, [[PERM]]
// CHECK: [[INITB:%.+]] = tensor.empty()
// CHECK: [[MATMUL:%.+]] = linalg.matmul ins(%arg0, [[TRANSPOSE]] : tensor<5x3xf32>, tensor<3x6xf32>) outs([[FILL]] : tensor<5x6xf32>) -> tensor<5x6xf32>
// CHECK: [[ADDED:%.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%arg2, [[MATMUL]] : tensor<6xf32>, tensor<5x6xf32>) outs([[INITB]] : tensor<5x6xf32>) {
// CHECK: ^bb0(%[[ARG3:[0-9a-zA-Z_]+]]: f32, %[[ARG4:[0-9a-zA-Z_]+]]: f32, %[[ARG5:[0-9a-zA-Z_]+]]: f32):
// CHECK: [[ADD:%.+]] = arith.addf %[[ARG3]], %[[ARG4]] : f32
// CHECK: linalg.yield [[ADD]] : f32
// CHECK: %[[PERM:.+]] = arith.constant dense<[1, 0]> : tensor<2xi64>
// CHECK: %[[TRANSPOSED:.+]] = tosa.transpose %arg1, %[[PERM]] : (tensor<6x3xf32>, tensor<2xi64>) -> tensor<3x6xf32>
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<5x6xf32>

// CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg2 : tensor<6xf32>) outs(%[[INIT]] : tensor<5x6xf32>) {
// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
// CHECK: linalg.yield %[[IN]] : f32
// CHECK: } -> tensor<5x6xf32>

// CHECK: linalg.matmul ins(%arg0, %[[TRANSPOSED]] : tensor<5x3xf32>, tensor<3x6xf32>) outs(%[[BROADCAST]] : tensor<5x6xf32>) -> tensor<5x6xf32>

%0 = tosa.fully_connected %arg0, %arg1, %arg2 : (tensor<5x3xf32>, tensor<6x3xf32>, tensor<6xf32>) -> tensor<5x6xf32>
return %0 : tensor<5x6xf32>
}

// -----

// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1) -> (d1)>
// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>

// CHECK-LABEL: @quantized_fully_connected
func.func @quantized_fully_connected(%arg0: tensor<5x3xi8>, %arg1: tensor<6x3xi8>, %arg2: tensor<6xi32>) -> (tensor<5x6xi32>) {
// CHECK: [[INITT:%.+]] = tensor.empty()
// CHECK: [[ZERO:%.+]] = arith.constant 0
// CHECK: [[FILL:%.+]] = linalg.fill ins([[ZERO]]{{.*}}outs([[INITT]]
// CHECK: [[PERM:%.+]] = arith.constant dense<[1, 0]>
// CHECK: [[TRANSPOSE:%.+]] = tosa.transpose %arg1, [[PERM]]
// CHECK: [[INITB:%.+]] = tensor.empty()
// CHECK: [[ONE:%.+]] = arith.constant 1
// CHECK: [[TWO:%.+]] = arith.constant 2
// CHECK: [[MATMUL:%.+]] = linalg.quantized_matmul ins(%arg0, [[TRANSPOSE]], [[ONE]], [[TWO]] : tensor<5x3xi8>, tensor<3x6xi8>, i32, i32) outs([[FILL]] : tensor<5x6xi32>) -> tensor<5x6xi32>
// CHECK: [[ADDED:%.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%arg2, [[MATMUL]] : tensor<6xi32>, tensor<5x6xi32>) outs([[INITB]]
// CHECK: ^bb0([[IN1:%.+]]: i32, [[IN2:%.+]]: i32, [[UNUSED:%.+]]: i32):
// CHECK: [[ADD:%.+]] = arith.addi
// CHECK: linalg.yield [[ADD]] : i32
// CHECK: %[[PERM:.+]] = arith.constant dense<[1, 0]> : tensor<2xi64>
// CHECK: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[PERM]] : (tensor<6x3xi8>, tensor<2xi64>) -> tensor<3x6xi8>
// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<5x6xi32>

// CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg2 : tensor<6xi32>) outs(%[[INIT]] : tensor<5x6xi32>) {
// CHECK: ^bb0(%[[IN:.+]]: i32, %[[OUT:.+]]: i32):
// CHECK: linalg.yield %[[IN]] : i32
// CHECK: } -> tensor<5x6xi32>

// CHECK: %[[C1:.+]] = arith.constant 1 : i32
// CHECK: %[[C2:.+]] = arith.constant 2 : i32
// CHECK: linalg.quantized_matmul ins(%arg0, %[[TRANSPOSE]], %[[C1]], %[[C2]] : tensor<5x3xi8>, tensor<3x6xi8>, i32, i32) outs(%[[BROADCAST]] : tensor<5x6xi32>) -> tensor<5x6xi32>

%0 = tosa.fully_connected %arg0, %arg1, %arg2 {quantization_info = #tosa.conv_quant<input_zp = 1, weight_zp = 2>} : (tensor<5x3xi8>, tensor<6x3xi8>, tensor<6xi32>) -> tensor<5x6xi32>
return %0 : tensor<5x6xi32>
}

// -----

// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK: #[[$MAP0:.+]] = affine_map<(d0, d1) -> (d1)>
// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>

// CHECK-LABEL: @fully_connected_dyn
func.func @fully_connected_dyn(%arg0: tensor<?x3xf32>, %arg1: tensor<6x3xf32>, %arg2: tensor<6xf32>) -> (tensor<?x6xf32>) {
// CHECK: %[[C0:.+]] = arith.constant 0
// CHECK: %[[DIM:.+]] = tensor.dim %arg0, %[[C0]]
// CHECK: %[[INITT:.+]] = tensor.empty(%[[DIM]])
// CHECK: %[[ZERO:.+]] = arith.constant 0
// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[ZERO]]{{.*}}outs(%[[INITT]]
// CHECK: %[[PERM:.+]] = arith.constant dense<[1, 0]>
// CHECK: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[PERM]]
// CHECK: %[[INITB:.+]] = tensor.empty(%[[DIM]])
// CHECK: %[[MATMUL:.+]] = linalg.matmul ins(%arg0, %[[TRANSPOSE]] : tensor<?x3xf32>, tensor<3x6xf32>) outs(%[[FILL]] : tensor<?x6xf32>) -> tensor<?x6xf32>
// CHECK: %[[ADDED:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel"]} ins(%arg2, %[[MATMUL]] : tensor<6xf32>, tensor<?x6xf32>) outs(%[[INITB]] : tensor<?x6xf32>) {
// CHECK: ^bb0(%[[ARG3:[0-9a-zA-Z_]+]]: f32, %[[ARG4:[0-9a-zA-Z_]+]]: f32, %[[ARG5:[0-9a-zA-Z_]+]]: f32):
// CHECK: %[[ADD:.+]] = arith.addf %[[ARG3]], %[[ARG4]] : f32
// CHECK: linalg.yield %[[ADD]] : f32
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[DIM0:.+]] = tensor.dim %arg0, %c0 : tensor<?x3xf32>
// CHECK: %[[PERM:.+]] = arith.constant dense<[1, 0]> : tensor<2xi64>
// CHECK: %[[TRANSPOSED:.+]] = tosa.transpose %arg1, %[[PERM]] : (tensor<6x3xf32>, tensor<2xi64>) -> tensor<3x6xf32>
// CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM0]]) : tensor<?x6xf32>

// CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%arg2 : tensor<6xf32>) outs(%[[INIT]] : tensor<?x6xf32>) {
// CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
// CHECK: linalg.yield %[[IN]] : f32
// CHECK: } -> tensor<?x6xf32>

// CHECK: linalg.matmul ins(%arg0, %[[TRANSPOSED]] : tensor<?x3xf32>, tensor<3x6xf32>) outs(%[[BROADCAST]] : tensor<?x6xf32>) -> tensor<?x6xf32>

%0 = tosa.fully_connected %arg0, %arg1, %arg2 : (tensor<?x3xf32>, tensor<6x3xf32>, tensor<6xf32>) -> tensor<?x6xf32>
return %0 : tensor<?x6xf32>
36 changes: 36 additions & 0 deletions mlir/test/Integration/Dialect/Tosa/CPU/test-fully-connected.mlir
@@ -0,0 +1,36 @@
// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named,tosa-to-linalg,tosa-to-arith))" | \
// RUN: mlir-opt -one-shot-bufferize -func-bufferize -test-lower-to-llvm | \
// RUN: mlir-cpu-runner -O3 -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils \
// RUN: | FileCheck %s

func.func private @printMemrefF32(tensor<*xf32>)

func.func @main() {
%A = arith.constant dense<[
[8.0, 1.0, 6.0],
[3.0, 5.0, 7.0],
[4.0, 9.0, 2.0]
]> : tensor<3x3xf32>

%B = arith.constant dense<[
[1.0, 1.0, 1.0],
[1.0, 1.0, 1.0],
[1.0, 1.0, 1.0]
]> : tensor<3x3xf32>

%C = arith.constant dense<[0.0, 1.0, 2.0]> : tensor<3xf32>

%result = tosa.fully_connected %A, %B, %C : (tensor<3x3xf32>, tensor<3x3xf32>, tensor<3xf32>) -> tensor<3x3xf32>

%result_unranked = tensor.cast %result : tensor<3x3xf32> to tensor<*xf32>
call @printMemrefF32(%result_unranked) : (tensor<*xf32>) -> ()
return
}

// CHECK: Unranked Memref base@ = {{.*}} rank = 2 offset = 0 sizes = [3, 3] strides = [3, 1] data =
// CHECK-NEXT: [
// CHECK-SAME: [15, 16, 17]
// CHECK-NEXT: [15, 16, 17]
// CHECK-NEXT: [15, 16, 17]
// CHECK-SAME: ]
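
As a hand check of the expected output (this note is not part of the test file): tosa.fully_connected computes output[i][j] = sum_k A[i][k] * B[j][k] + C[j]. Every row of %A sums to 15 and %B is all ones, so the matmul contributes 15 to every entry; adding the per-column bias %C = [0, 1, 2] then gives [15, 16, 17] in each of the three rows, matching the CHECK lines above.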