@@ -173,6 +173,35 @@ static void createDepthwiseConvCollapseMap(
173
173
rewriter.getAffineDimExpr (outputRank));
174
174
}
175
175
176
+ static FailureOr<Value> collapseValue (OpBuilder &rewriter, Location loc,
177
+ Value value, ShapedType type) {
178
+ auto reassociationMap = getReassociationIndicesForReshape (
179
+ cast<ShapedType>(value.getType ()), type);
180
+ if (!reassociationMap.has_value ())
181
+ return failure ();
182
+
183
+ return Value (rewriter.create <tensor::CollapseShapeOp>(
184
+ loc, type, value, reassociationMap.value ()));
185
+ }
186
+
187
+ static FailureOr<SmallVector<Value>>
188
+ collapseValues (OpBuilder &rewriter, Location loc, SmallVector<Value> values,
189
+ SmallVector<ShapedType> newTys, bool useMatmulForBatchOne) {
190
+ if (!useMatmulForBatchOne)
191
+ return values;
192
+
193
+ SmallVector<Value> newValues;
194
+ for (auto [idx, value] : llvm::enumerate (values)) {
195
+
196
+ auto newValue = collapseValue (rewriter, loc, value, newTys[idx]);
197
+ if (failed (newValue))
198
+ return failure ();
199
+
200
+ newValues.push_back (*newValue);
201
+ }
202
+ return newValues;
203
+ }
204
+
176
205
namespace {
177
206
178
207
template <typename TosaConvOp, typename LinalgConvOp, typename LinalgConvQOp>
@@ -498,6 +527,9 @@ class DepthwiseConvConverter
498
527
499
528
class MatMulConverter : public OpConversionPattern <tosa::MatMulOp> {
500
529
public:
530
+ // Constructor capturing whether a single-batch tosa.matmul should lower to
+ // linalg.matmul rather than linalg.batch_matmul; the flag is stored in the
+ // `useMatmulForSingleBatch` member consulted during matchAndRewrite.
+ MatMulConverter (MLIRContext *ctx, bool useMatmulForSingleBatch)
531
+ : OpConversionPattern<tosa::MatMulOp>(ctx),
532
+ useMatmulForSingleBatch (useMatmulForSingleBatch) {}
501
533
using OpConversionPattern<tosa::MatMulOp>::OpConversionPattern;
502
534
LogicalResult
503
535
matchAndRewrite (tosa::MatMulOp op, OpAdaptor adaptor,
@@ -525,20 +557,55 @@ class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
525
557
dynDims[2 ] = rewriter.create <tensor::DimOp>(loc, op->getOperand (1 ), 2 );
526
558
}
527
559
560
+ auto getTypeWithoutBatch = [&](ShapedType ty) {
561
+ auto shape2D = {ty.getDimSize (1 ), ty.getDimSize (2 )};
562
+ return RankedTensorType::get (shape2D, ty.getElementType ());
563
+ };
564
+
528
565
SmallVector<Value> filteredDims = condenseValues (dynDims);
529
566
567
+ bool useMatmulForBatchOne =
568
+ outputTy.getDimSize (0 ) == 1 && this ->useMatmulForSingleBatch ;
569
+
570
+ auto newInput1Type = getTypeWithoutBatch (firstOperandTy);
571
+ auto newInput2Type = getTypeWithoutBatch (secondOperandTy);
572
+ auto newOutputType = getTypeWithoutBatch (outputTy);
573
+
574
+ SmallVector<Value> inputs = {adaptor.getA (), adaptor.getB ()};
575
+ auto inputsOrFailure =
576
+ collapseValues (rewriter, loc, inputs, {newInput1Type, newInput2Type},
577
+ useMatmulForBatchOne);
578
+ auto matmulMap = getReassociationIndicesForReshape (newOutputType, outputTy);
579
+
580
+ // If any of the reassociations of indices failed, don't use matmul.
581
+ if (failed (inputsOrFailure) || !matmulMap.has_value ()) {
582
+ useMatmulForBatchOne = false ;
583
+ } else {
584
+ inputs = *inputsOrFailure;
585
+ }
586
+
530
587
auto zeroAttr = rewriter.getZeroAttr (outputElementTy);
531
588
Value zero = rewriter.create <arith::ConstantOp>(loc, zeroAttr);
532
- auto emptyTensor = rewriter.create <tensor::EmptyOp>(
533
- loc, outputTy.getShape (), outputTy.getElementType (), filteredDims);
589
+
590
+ Value emptyTensor = rewriter.create <tensor::EmptyOp>(
591
+ loc,
592
+ useMatmulForBatchOne ? newOutputType.getShape () : outputTy.getShape (),
593
+ outputElementTy, filteredDims);
594
+
534
595
Value zeroTensor = rewriter
535
596
.create <linalg::FillOp>(loc, ValueRange{zero},
536
597
ValueRange{emptyTensor})
537
598
.result ();
599
+
538
600
if (!op.getQuantizationInfo ()) {
539
- rewriter.replaceOpWithNewOp <linalg::BatchMatmulOp>(
540
- op, TypeRange{op.getType ()},
541
- ValueRange{adaptor.getA (), adaptor.getB ()}, ValueRange{zeroTensor});
601
+ if (useMatmulForBatchOne) {
602
+ auto matmul = rewriter.create <linalg::MatmulOp>(
603
+ loc, TypeRange{newOutputType}, inputs, ValueRange{zeroTensor});
604
+ rewriter.replaceOpWithNewOp <tensor::ExpandShapeOp>(
605
+ op, outputTy, matmul->getResult (0 ), matmulMap.value ());
606
+ } else
607
+ rewriter.replaceOpWithNewOp <linalg::BatchMatmulOp>(
608
+ op, TypeRange{op.getType ()}, inputs, ValueRange{zeroTensor});
542
609
return success ();
543
610
}
544
611
@@ -547,12 +614,22 @@ class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
547
614
loc, rewriter.getI32IntegerAttr (quantizationInfo.getAZp ()));
548
615
auto bZp = rewriter.create <arith::ConstantOp>(
549
616
loc, rewriter.getI32IntegerAttr (quantizationInfo.getBZp ()));
550
- rewriter.replaceOpWithNewOp <linalg::QuantizedBatchMatmulOp>(
551
- op, TypeRange{op.getType ()},
552
- ValueRange{adaptor.getA (), adaptor.getB (), aZp, bZp}, zeroTensor);
617
+ if (useMatmulForBatchOne) {
618
+ auto matmul = rewriter.create <linalg::QuantizedMatmulOp>(
619
+ loc, TypeRange{newOutputType},
620
+ ValueRange{inputs[0 ], inputs[1 ], aZp, bZp}, zeroTensor);
621
+ rewriter.replaceOpWithNewOp <tensor::ExpandShapeOp>(
622
+ op, outputTy, matmul->getResult (0 ), matmulMap.value ());
623
+ } else
624
+ rewriter.replaceOpWithNewOp <linalg::QuantizedBatchMatmulOp>(
625
+ op, TypeRange{op.getType ()},
626
+ ValueRange{inputs[0 ], inputs[1 ], aZp, bZp}, zeroTensor);
553
627
554
628
return success ();
555
629
}
630
+
631
+ private:
632
+ bool useMatmulForSingleBatch;
556
633
};
557
634
558
635
class FullyConnectedConverter
@@ -974,15 +1051,16 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
974
1051
} // namespace
975
1052
976
1053
void mlir::tosa::populateTosaToLinalgNamedConversionPatterns (
977
- RewritePatternSet *patterns) {
1054
+ RewritePatternSet *patterns, bool useMatmulForSingleBatch ) {
978
1055
patterns->add <
979
1056
// clang-format off
980
1057
ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwcHwcfOp, linalg::Conv2DNhwcHwcfQOp>,
981
1058
ConvConverter<tosa::Conv3DOp, linalg::Conv3DNdhwcDhwcfOp, linalg::Conv3DNdhwcDhwcfQOp>,
982
1059
DepthwiseConvConverter,
983
- MatMulConverter,
984
1060
MaxPool2dConverter,
985
1061
AvgPool2dConverter,
986
1062
FullyConnectedConverter>(patterns->getContext ());
1063
+ patterns->add <
1064
+ MatMulConverter>(patterns->getContext (), useMatmulForSingleBatch);
987
1065
// clang-format on
988
1066
}
0 commit comments