Skip to content

Commit eed17dc

Browse files
authored
[mlir][tosa] Fix lowering of tosa.matmul with dynamic outputs (#72724)
The existing lowering of tosa.matmul will construct illegal tensor.empty operations when the output type is more dynamic than the input types, e.g.: `%0 = tosa.matmul %arg0, %arg1 : (tensor<1x1x8xf32>, tensor<1x8x1xf32>) -> tensor<?x1x1xf32>`. When constructing the tensor.empty operation, consult the output type rather than the input types to decide whether a dimension is dynamic.
1 parent 9b20af1 commit eed17dc

File tree

2 files changed

+17
-6
lines changed

2 files changed

+17
-6
lines changed

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -540,21 +540,18 @@ class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
540540
auto outputTy = cast<ShapedType>(op.getType());
541541
auto outputElementTy = outputTy.getElementType();
542542

543-
auto firstOperandTy = cast<ShapedType>(op->getOperand(0).getType());
544-
auto secondOperandTy = cast<ShapedType>(op->getOperand(1).getType());
545-
546543
SmallVector<Value> dynDims;
547544
dynDims.resize(cast<ShapedType>(op->getResult(0).getType()).getRank());
548545

549-
if (!firstOperandTy.hasRank() || firstOperandTy.isDynamicDim(0)) {
546+
if (!outputTy.hasRank() || outputTy.isDynamicDim(0)) {
550547
dynDims[0] = rewriter.create<tensor::DimOp>(loc, op->getOperand(0), 0);
551548
}
552549

553-
if (!firstOperandTy.hasRank() || firstOperandTy.isDynamicDim(1)) {
550+
if (!outputTy.hasRank() || outputTy.isDynamicDim(1)) {
554551
dynDims[1] = rewriter.create<tensor::DimOp>(loc, op->getOperand(0), 1);
555552
}
556553

557-
if (!secondOperandTy.hasRank() || secondOperandTy.isDynamicDim(2)) {
554+
if (!outputTy.hasRank() || outputTy.isDynamicDim(2)) {
558555
dynDims[2] = rewriter.create<tensor::DimOp>(loc, op->getOperand(1), 2);
559556
}
560557

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,20 @@ func.func @matmul_dyn_independent_dim(%arg0: tensor<1x5x?xf32>, %arg1: tensor<1x
6868

6969
// -----
7070

71+
// CHECK-LABEL: @matmul_dyn_output
72+
func.func @matmul_dyn_output(%arg0: tensor<1x1x8xf32>, %arg1: tensor<1x8x1xf32>) -> tensor<?x1x1xf32> {
73+
// CHECK: %[[C0:.+]] = arith.constant 0 : index
74+
// CHECK: %[[DIM0:.+]] = tensor.dim %arg0, %[[C0]] : tensor<1x1x8xf32>
75+
// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32
76+
// CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM0]]) : tensor<?x1x1xf32>
77+
// CHECK: %[[FILLED:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[INIT]] : tensor<?x1x1xf32>) -> tensor<?x1x1xf32>
78+
// CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x1x8xf32>, tensor<1x8x1xf32>) outs(%[[FILLED]] : tensor<?x1x1xf32>) -> tensor<?x1x1xf32>
79+
%0 = tosa.matmul %arg0, %arg1 : (tensor<1x1x8xf32>, tensor<1x8x1xf32>) -> tensor<?x1x1xf32>
80+
return %0 : tensor<?x1x1xf32>
81+
}
82+
83+
// -----
84+
7185
// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d1)>
7286
// CHECK: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
7387

0 commit comments

Comments (0)