Skip to content

Commit 608e8f7

Browse files
committed
feat(TosaToLinalg): use linalg.matmul instead of linalg.batch_matmul when converting tosa.matmul with 1x batch.
1 parent 20684a4 commit 608e8f7

File tree

5 files changed

+170
-13
lines changed

5 files changed

+170
-13
lines changed

mlir/include/mlir/Conversion/Passes.td

Lines changed: 4 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -974,7 +974,10 @@ def TosaToLinalgNamed
974974
Pass that converts TOSA operations to the equivalent operations using the
975975
Linalg named operations.
976976
}];
977-
977+
let options = [
978+
Option<"useMatmulForSingleBatch", "use-matmul-for-single-batch", "bool", /*default=*/"false",
979+
"Use linalg.matmul for 1x batch size instead of linalg.batch_matmul.">
980+
];
978981
let constructor = "tosa::createTosaToLinalgNamed()";
979982
}
980983

mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h

Lines changed: 2 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -38,7 +38,8 @@ void addTosaToLinalgPasses(OpPassManager &pm,
3838
void populateTosaToLinalgConversionPatterns(RewritePatternSet *patterns);
3939

4040
/// Populates conversion passes from TOSA dialect to Linalg named operations.
41-
void populateTosaToLinalgNamedConversionPatterns(RewritePatternSet *patterns);
41+
void populateTosaToLinalgNamedConversionPatterns(
42+
RewritePatternSet *patterns, bool useMatmulForSingleBatch = false);
4243

4344
} // namespace tosa
4445
} // namespace mlir

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp

Lines changed: 88 additions & 10 deletions
Original file line number · Diff line number · Diff line change
@@ -173,6 +173,35 @@ static void createDepthwiseConvCollapseMap(
173173
rewriter.getAffineDimExpr(outputRank));
174174
}
175175

176+
static FailureOr<Value> collapseValue(OpBuilder &rewriter, Location loc,
177+
Value value, ShapedType type) {
178+
auto reassociationMap = getReassociationIndicesForReshape(
179+
cast<ShapedType>(value.getType()), type);
180+
if (!reassociationMap.has_value())
181+
return failure();
182+
183+
return Value(rewriter.create<tensor::CollapseShapeOp>(
184+
loc, type, value, reassociationMap.value()));
185+
}
186+
187+
static FailureOr<SmallVector<Value>>
188+
collapseValues(OpBuilder &rewriter, Location loc, SmallVector<Value> values,
189+
SmallVector<ShapedType> newTys, bool useMatmulForBatchOne) {
190+
if (!useMatmulForBatchOne)
191+
return values;
192+
193+
SmallVector<Value> newValues;
194+
for (auto [idx, value] : llvm::enumerate(values)) {
195+
196+
auto newValue = collapseValue(rewriter, loc, value, newTys[idx]);
197+
if (failed(newValue))
198+
return failure();
199+
200+
newValues.push_back(*newValue);
201+
}
202+
return newValues;
203+
}
204+
176205
namespace {
177206

178207
template <typename TosaConvOp, typename LinalgConvOp, typename LinalgConvQOp>
@@ -498,6 +527,9 @@ class DepthwiseConvConverter
498527

499528
class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
500529
public:
530+
MatMulConverter(MLIRContext *ctx, bool useMatmulForSingleBatch)
531+
: OpConversionPattern<tosa::MatMulOp>(ctx),
532+
useMatmulForSingleBatch(useMatmulForSingleBatch) {}
501533
using OpConversionPattern<tosa::MatMulOp>::OpConversionPattern;
502534
LogicalResult
503535
matchAndRewrite(tosa::MatMulOp op, OpAdaptor adaptor,
@@ -525,20 +557,55 @@ class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
525557
dynDims[2] = rewriter.create<tensor::DimOp>(loc, op->getOperand(1), 2);
526558
}
527559

560+
auto getTypeWithoutBatch = [&](ShapedType ty) {
561+
auto shape2D = {ty.getDimSize(1), ty.getDimSize(2)};
562+
return RankedTensorType::get(shape2D, ty.getElementType());
563+
};
564+
528565
SmallVector<Value> filteredDims = condenseValues(dynDims);
529566

567+
bool useMatmulForBatchOne =
568+
outputTy.getDimSize(0) == 1 && this->useMatmulForSingleBatch;
569+
570+
auto newInput1Type = getTypeWithoutBatch(firstOperandTy);
571+
auto newInput2Type = getTypeWithoutBatch(secondOperandTy);
572+
auto newOutputType = getTypeWithoutBatch(outputTy);
573+
574+
SmallVector<Value> inputs = {adaptor.getA(), adaptor.getB()};
575+
auto inputsOrFailure =
576+
collapseValues(rewriter, loc, inputs, {newInput1Type, newInput2Type},
577+
useMatmulForBatchOne);
578+
auto matmulMap = getReassociationIndicesForReshape(newOutputType, outputTy);
579+
580+
// If any of the reassociations of indices failed, don't use matmul.
581+
if (failed(inputsOrFailure) || !matmulMap.has_value()) {
582+
useMatmulForBatchOne = false;
583+
} else {
584+
inputs = *inputsOrFailure;
585+
}
586+
530587
auto zeroAttr = rewriter.getZeroAttr(outputElementTy);
531588
Value zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
532-
auto emptyTensor = rewriter.create<tensor::EmptyOp>(
533-
loc, outputTy.getShape(), outputTy.getElementType(), filteredDims);
589+
590+
Value emptyTensor = rewriter.create<tensor::EmptyOp>(
591+
loc,
592+
useMatmulForBatchOne ? newOutputType.getShape() : outputTy.getShape(),
593+
outputElementTy, filteredDims);
594+
534595
Value zeroTensor = rewriter
535596
.create<linalg::FillOp>(loc, ValueRange{zero},
536597
ValueRange{emptyTensor})
537598
.result();
599+
538600
if (!op.getQuantizationInfo()) {
539-
rewriter.replaceOpWithNewOp<linalg::BatchMatmulOp>(
540-
op, TypeRange{op.getType()},
541-
ValueRange{adaptor.getA(), adaptor.getB()}, ValueRange{zeroTensor});
601+
if (useMatmulForBatchOne) {
602+
auto matmul = rewriter.create<linalg::MatmulOp>(
603+
loc, TypeRange{newOutputType}, inputs, ValueRange{zeroTensor});
604+
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
605+
op, outputTy, matmul->getResult(0), matmulMap.value());
606+
} else
607+
rewriter.replaceOpWithNewOp<linalg::BatchMatmulOp>(
608+
op, TypeRange{op.getType()}, inputs, ValueRange{zeroTensor});
542609
return success();
543610
}
544611

@@ -547,12 +614,22 @@ class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
547614
loc, rewriter.getI32IntegerAttr(quantizationInfo.getAZp()));
548615
auto bZp = rewriter.create<arith::ConstantOp>(
549616
loc, rewriter.getI32IntegerAttr(quantizationInfo.getBZp()));
550-
rewriter.replaceOpWithNewOp<linalg::QuantizedBatchMatmulOp>(
551-
op, TypeRange{op.getType()},
552-
ValueRange{adaptor.getA(), adaptor.getB(), aZp, bZp}, zeroTensor);
617+
if (useMatmulForBatchOne) {
618+
auto matmul = rewriter.create<linalg::QuantizedMatmulOp>(
619+
loc, TypeRange{newOutputType},
620+
ValueRange{inputs[0], inputs[1], aZp, bZp}, zeroTensor);
621+
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
622+
op, outputTy, matmul->getResult(0), matmulMap.value());
623+
} else
624+
rewriter.replaceOpWithNewOp<linalg::QuantizedBatchMatmulOp>(
625+
op, TypeRange{op.getType()},
626+
ValueRange{inputs[0], inputs[1], aZp, bZp}, zeroTensor);
553627

554628
return success();
555629
}
630+
631+
private:
632+
bool useMatmulForSingleBatch;
556633
};
557634

558635
class FullyConnectedConverter
@@ -974,15 +1051,16 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
9741051
} // namespace
9751052

9761053
void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
977-
RewritePatternSet *patterns) {
1054+
RewritePatternSet *patterns, bool useMatmulForSingleBatch) {
9781055
patterns->add<
9791056
// clang-format off
9801057
ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcHwcfOp, linalg::Conv2DNhwcHwcfQOp>,
9811058
ConvConverter<tosa::Conv3DOp, linalg::Conv3DNdhwcDhwcfOp, linalg::Conv3DNdhwcDhwcfQOp>,
9821059
DepthwiseConvConverter,
983-
MatMulConverter,
9841060
MaxPool2dConverter,
9851061
AvgPool2dConverter,
9861062
FullyConnectedConverter>(patterns->getContext());
1063+
patterns->add<
1064+
MatMulConverter>(patterns->getContext(), useMatmulForSingleBatch);
9871065
// clang-format on
9881066
}

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp

Lines changed: 2 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -61,7 +61,8 @@ struct TosaToLinalgNamed
6161
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
6262

6363
FunctionOpInterface func = getOperation();
64-
mlir::tosa::populateTosaToLinalgNamedConversionPatterns(&patterns);
64+
mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
65+
&patterns, this->useMatmulForSingleBatch);
6566
if (failed(applyFullConversion(func, target, std::move(patterns))))
6667
signalPassFailure();
6768
}
Lines changed: 74 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -0,0 +1,74 @@
1+
// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named{use-matmul-for-single-batch=true},cse))" %s -verify-diagnostics -o -| FileCheck %s
2+
3+
// CHECK-LABEL: @matmul
4+
func.func @matmul(%arg0: tensor<1x5x3xf32>, %arg1: tensor<1x3x6xf32>) -> (tensor<1x5x6xf32>) {
5+
// CHECK: %[[COLLAPSE1:.*]] = tensor.collapse_shape %arg0 {{\[\[}}0, 1], [2]] : tensor<1x5x3xf32> into tensor<5x3xf32>
6+
// CHECK: %[[COLLAPSE2:.*]] = tensor.collapse_shape %arg1 {{\[\[}}0, 1], [2]] : tensor<1x3x6xf32> into tensor<3x6xf32>
7+
// CHECK: %[[CONST:.*]] = arith.constant 0.000000e+00
8+
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<5x6xf32>
9+
// CHECK: %[[FILL:.*]] = linalg.fill ins(%[[CONST]] : f32) outs(%[[EMPTY]] : tensor<5x6xf32>) -> tensor<5x6xf32>
10+
// CHECK: %[[MATMUL:.*]] = linalg.matmul ins(%[[COLLAPSE1]], %[[COLLAPSE2]] : tensor<5x3xf32>, tensor<3x6xf32>) outs(%[[FILL]] : tensor<5x6xf32>) -> tensor<5x6xf32>
11+
// CHECK: tensor.expand_shape %[[MATMUL]] {{\[\[}}0, 1], [2]] : tensor<5x6xf32> into tensor<1x5x6xf32>
12+
%0 = "tosa.matmul"(%arg0, %arg1) : (tensor<1x5x3xf32>, tensor<1x3x6xf32>) -> (tensor<1x5x6xf32>)
13+
return %0 : tensor<1x5x6xf32>
14+
}
15+
16+
// -----
17+
18+
// CHECK-LABEL: @matmul_quantized
19+
func.func @matmul_quantized(%arg0: tensor<1x5x3xi8>, %arg1: tensor<1x3x6xi8>) -> (tensor<1x5x6xi32>) {
20+
// CHECK: %[[COLLAPSE1:.*]] = tensor.collapse_shape %arg0 {{\[\[}}0, 1], [2]] : tensor<1x5x3xi8> into tensor<5x3xi8>
21+
// CHECK: %[[COLLAPSE2:.*]] = tensor.collapse_shape %arg1 {{\[\[}}0, 1], [2]] : tensor<1x3x6xi8> into tensor<3x6xi8>
22+
// CHECK: %[[VAL_4:.*]] = arith.constant 0
23+
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<5x6xi32>
24+
// CHECK: %[[FILL:.*]] = linalg.fill ins(%[[VAL_4]] : i32) outs(%[[EMPTY]] : tensor<5x6xi32>) -> tensor<5x6xi32>
25+
// CHECK: %[[CONST1:.*]] = arith.constant 1
26+
// CHECK: %[[CONST2:.*]] = arith.constant 2
27+
// CHECK: %[[VAL_9:.*]] = linalg.quantized_matmul ins(%[[COLLAPSE1]], %[[COLLAPSE2]], %[[CONST1]], %[[CONST2]] : tensor<5x3xi8>, tensor<3x6xi8>, i32, i32) outs(%[[FILL]] : tensor<5x6xi32>) -> tensor<5x6xi32>
28+
// CHECK: tensor.expand_shape %[[VAL_9]] {{\[\[}}0, 1], [2]] : tensor<5x6xi32> into tensor<1x5x6xi32>
29+
%0 = "tosa.matmul"(%arg0, %arg1) {quantization_info = #tosa.matmul_quant<a_zp = 1, b_zp = 2>} : (tensor<1x5x3xi8>, tensor<1x3x6xi8>) -> (tensor<1x5x6xi32>)
30+
return %0 : tensor<1x5x6xi32>
31+
}
32+
33+
// -----
34+
35+
// CHECK-LABEL: @matmul_dyn_batch_no_matmul
36+
func.func @matmul_dyn_batch_no_matmul(%arg0: tensor<?x5x3xf32>, %arg1: tensor<?x3x6xf32>) -> (tensor<?x5x6xf32>) {
37+
// CHECK: %[[C0:.+]] = arith.constant 0
38+
// CHECK: %[[DIM:.+]] = tensor.dim %arg0, %[[C0]]
39+
// CHECK: %[[C0_0:.+]] = arith.constant 0
40+
// CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM]])
41+
// CHECK: %[[FILLED:.+]] = linalg.fill ins(%[[C0_0]] : f32) outs(%[[INIT]] : tensor<?x5x6xf32>) -> tensor<?x5x6xf32>
42+
// CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<?x5x3xf32>, tensor<?x3x6xf32>) outs(%[[FILLED]] : tensor<?x5x6xf32>) -> tensor<?x5x6xf32>
43+
%0 = "tosa.matmul"(%arg0, %arg1) : (tensor<?x5x3xf32>, tensor<?x3x6xf32>) -> (tensor<?x5x6xf32>)
44+
return %0 : tensor<?x5x6xf32>
45+
}
46+
47+
// -----
48+
49+
// CHECK-LABEL: @matmul_dyn_independent_dim
50+
func.func @matmul_dyn_independent_dim(%arg0: tensor<1x5x3xf32>, %arg1: tensor<1x3x?xf32>) -> (tensor<1x5x?xf32>) {
51+
// CHECK: %[[CONST1:.*]] = arith.constant 2
52+
// CHECK: %[[DIM:.*]] = tensor.dim %arg1, %[[CONST1]] : tensor<1x3x?xf32>
53+
// CHECK: %[[COLLAPSE1:.*]] = tensor.collapse_shape %arg0 {{\[\[}}0, 1], [2]] : tensor<1x5x3xf32> into tensor<5x3xf32>
54+
// CHECK: %[[COLLAPSE2:.*]] = tensor.collapse_shape %arg1 {{\[\[}}0, 1], [2]] : tensor<1x3x?xf32> into tensor<3x?xf32>
55+
// CHECK: %[[CONST2:.*]] = arith.constant 0.000000e+00
56+
// CHECK: %[[EMPTY:.*]] = tensor.empty(%[[DIM]]) : tensor<5x?xf32>
57+
// CHECK: %[[FILL:.*]] = linalg.fill ins(%[[CONST2]] : f32) outs(%[[EMPTY]] : tensor<5x?xf32>) -> tensor<5x?xf32>
58+
// CHECK: %[[MATMUL:.*]] = linalg.matmul ins(%[[COLLAPSE1]], %[[COLLAPSE2]] : tensor<5x3xf32>, tensor<3x?xf32>) outs(%[[FILL]] : tensor<5x?xf32>) -> tensor<5x?xf32>
59+
// CHECK: %[[VAL_10:.*]] = tensor.expand_shape %[[MATMUL]] {{\[\[}}0, 1], [2]] : tensor<5x?xf32> into tensor<1x5x?xf32>
60+
%0 = "tosa.matmul"(%arg0, %arg1) : (tensor<1x5x3xf32>, tensor<1x3x?xf32>) -> (tensor<1x5x?xf32>)
61+
return %0 : tensor<1x5x?xf32>
62+
}
63+
64+
// -----
65+
66+
// CHECK-LABEL: @matmul_dyn_independent_dim_no_matmul
67+
func.func @matmul_dyn_independent_dim_no_matmul(%arg0: tensor<1x5x?xf32>, %arg1: tensor<1x?x6xf32>) -> (tensor<1x5x6xf32>) {
68+
// CHECK: %[[C0:.+]] = arith.constant 0
69+
// CHECK: %[[INIT:.+]] = tensor.empty()
70+
// CHECK: %[[FILLED:.+]] = linalg.fill ins(%[[C0]] : f32) outs(%[[INIT]] : tensor<1x5x6xf32>) -> tensor<1x5x6xf32>
71+
// CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x5x?xf32>, tensor<1x?x6xf32>) outs(%[[FILLED]] : tensor<1x5x6xf32>) -> tensor<1x5x6xf32>
72+
%0 = "tosa.matmul"(%arg0, %arg1) : (tensor<1x5x?xf32>, tensor<1x?x6xf32>) -> (tensor<1x5x6xf32>)
73+
return %0 : tensor<1x5x6xf32>
74+
}

0 commit comments

Comments (0)