
Commit a5c6469

udaya-ranga authored and Tai78641 committed
[mlir][tosa] Change MatMul zero-point to inputs
* Change zero-point attributes to inputs
* Fix relevant mlir tests
* Enhance ShardingInterface in MatMul

Signed-off-by: Udaya Ranga <[email protected]>
Change-Id: Ia58b15cba546a948a6a4d8e8ee26a72cd050de4e
1 parent 0ee8f69 commit a5c6469

15 files changed: +232 additions, -98 deletions

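In short, the change moves the MatMul zero points from attributes on the op to explicit scalar-tensor operands. A rough before/after for the quantized i8 case, adapted from the updated lit test further down (value names %a and %b are illustrative):

  // Before: zero points carried as optional attributes.
  %0 = tosa.matmul %a, %b {a_zp = 1 : i32, b_zp = 2 : i32}
         : (tensor<1x5x3xi8>, tensor<1x3x6xi8>) -> tensor<1x5x6xi32>

  // After: zero points passed as single-element tensor operands.
  %a_zp = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8>
  %b_zp = "tosa.const"() <{values = dense<2> : tensor<1xi8>}> : () -> tensor<1xi8>
  %1 = tosa.matmul %a, %b, %a_zp, %b_zp
         : (tensor<1x5x3xi8>, tensor<1x3x6xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x5x6xi32>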

mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc

Lines changed: 8 additions & 6 deletions
@@ -35,9 +35,11 @@ profileComplianceMap = {
        {fp16T, fp16T, fp32T, fp32T},
        {fp32T, fp32T, fp32T, fp32T}}}}},
     {"tosa.matmul",
-     {{{Profile::pro_int}, {{i8T, i8T, i32T}}},
+     {{{Profile::pro_int}, {{i8T, i8T, i8T, i8T, i32T}}},
       {{Profile::pro_fp},
-       {{fp16T, fp16T, fp16T}, {fp16T, fp16T, fp32T}, {fp32T, fp32T, fp32T}}}}},
+       {{fp16T, fp16T, fp16T, fp16T, fp16T},
+        {fp16T, fp16T, fp16T, fp16T, fp32T},
+        {fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
     {"tosa.max_pool2d",
      {{{Profile::pro_int}, {{i8T, i8T}}},
       {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
@@ -273,10 +275,10 @@ extensionComplianceMap = {
      {{Extension::int16}, {{i16T, i8T, i48T, i48T}}},
      {{Extension::bf16}, {{bf16T, bf16T, fp32T, fp32T}}}}},
     {"tosa.matmul",
-     {{{Extension::int16}, {{i16T, i16T, i48T}}},
-      {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp16T}}},
-      {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp16T}}},
-      {{Extension::bf16}, {{bf16T, bf16T, fp32T}}}}},
+     {{{Extension::int16}, {{i16T, i16T, i16T, i16T, i48T}}},
+      {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp8e4m3T, fp16T}}},
+      {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp8e5m2T, fp16T}}},
+      {{Extension::bf16}, {{bf16T, bf16T, bf16T, bf16T, fp32T}}}}},
     {"tosa.max_pool2d",
      {{{Extension::int16}, {{i16T, i16T}}},
       {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
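With the zero points promoted to operands, the compliance tuples for tosa.matmul now list five types per entry: input a, input b, a_zp, b_zp, and the accumulator output. A minimal sketch of an int16-extension matmul matching the {i16T, i16T, i16T, i16T, i48T} row; %a and %b stand for suitably typed 3-D tensors and the shapes are illustrative:

  %a_zp = "tosa.const"() <{values = dense<0> : tensor<1xi16>}> : () -> tensor<1xi16>
  %b_zp = "tosa.const"() <{values = dense<0> : tensor<1xi16>}> : () -> tensor<1xi16>
  // i16 x i16 inputs with i16 zero points accumulating into i48.
  %0 = tosa.matmul %a, %b, %a_zp, %b_zp
         : (tensor<1x5x3xi16>, tensor<1x3x6xi16>, tensor<1xi16>, tensor<1xi16>) -> tensor<1x5x6xi48>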

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 9 additions & 2 deletions
@@ -311,8 +311,8 @@ def Tosa_MatMulOp : Tosa_InferShapedTypeOp<"matmul"> {
   let arguments = (ins
     Tosa_Tensor3D:$a,
     Tosa_Tensor3D:$b,
-    OptionalAttr<I32Attr>:$a_zp,
-    OptionalAttr<I32Attr>:$b_zp
+    Tosa_ScalarIntOrFloatTensor:$a_zp,
+    Tosa_ScalarIntOrFloatTensor:$b_zp
   );

   let results = (outs
@@ -324,6 +324,13 @@ def Tosa_MatMulOp : Tosa_InferShapedTypeOp<"matmul"> {
     Extension<[Tosa_EXT_INT16, Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
   ];

+  let extraClassDeclaration = [{
+    FailureOr<int64_t> getAZeroPoint();
+    FailureOr<int64_t> getBZeroPoint();
+    LogicalResult verifyAZeroPoint(int64_t zp);
+    LogicalResult verifyBZeroPoint(int64_t zp);
+  }];
+
   let builders = [Tosa_MatMulOpQuantInfoBuilder];
   let hasVerifier = 1;
 }

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp

Lines changed: 32 additions & 9 deletions
@@ -270,8 +270,8 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
       return rewriter.notifyMatchFailure(
           op, "weight zero point cannot be statically determined");

-    int64_t inputZpVal = *maybeIZp;
-    int64_t weightZpVal = *maybeWZp;
+    const int64_t inputZpVal = *maybeIZp;
+    const int64_t weightZpVal = *maybeWZp;

     if (op.verifyInputZeroPoint(inputZpVal).failed())
       return rewriter.notifyMatchFailure(
@@ -466,8 +466,8 @@ class DepthwiseConvConverter
       return rewriter.notifyMatchFailure(
           op, "weight zero point cannot be statically determined");

-    int64_t inputZpVal = *maybeIZp;
-    int64_t weightZpVal = *maybeWZp;
+    const int64_t inputZpVal = *maybeIZp;
+    const int64_t weightZpVal = *maybeWZp;

     if (op.verifyInputZeroPoint(inputZpVal).failed())
       return rewriter.notifyMatchFailure(
@@ -621,15 +621,38 @@ class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
             .create<linalg::FillOp>(loc, ValueRange{zero},
                                     ValueRange{emptyTensor})
             .result();
-    if (!op.getAZp() && !op.getBZp()) {
+
+    FailureOr<int64_t> maybeAZp = op.getAZeroPoint();
+    FailureOr<int64_t> maybeBZp = op.getBZeroPoint();
+    if (failed(maybeAZp))
+      return rewriter.notifyMatchFailure(
+          op, "input a zero point cannot be statically determined");
+    if (failed(maybeBZp))
+      return rewriter.notifyMatchFailure(
+          op, "input b zero point cannot be statically determined");
+
+    const int64_t aZpVal = *maybeAZp;
+    const int64_t bZpVal = *maybeBZp;
+
+    if (op.verifyAZeroPoint(aZpVal).failed())
+      return rewriter.notifyMatchFailure(
+          op, "input a zero point must be zero for non-int8 integer types");
+
+    if (op.verifyBZeroPoint(bZpVal).failed())
+      return rewriter.notifyMatchFailure(
+          op, "input b zero point must be zero for non-int8 integer types");
+
+    if (aZpVal == 0 && bZpVal == 0) {
       rewriter.replaceOpWithNewOp<linalg::BatchMatmulOp>(
           op, TypeRange{op.getType()},
           ValueRange{adaptor.getA(), adaptor.getB()}, ValueRange{zeroTensor});
       return success();
     }

-    auto aZp = rewriter.create<arith::ConstantOp>(loc, op.getAZpAttr());
-    auto bZp = rewriter.create<arith::ConstantOp>(loc, op.getBZpAttr());
+    auto aZp = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getI32IntegerAttr(aZpVal));
+    auto bZp = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getI32IntegerAttr(bZpVal));
     rewriter.replaceOpWithNewOp<linalg::QuantizedBatchMatmulOp>(
         op, TypeRange{op.getType()},
         ValueRange{adaptor.getA(), adaptor.getB(), aZp, bZp}, zeroTensor);
@@ -834,8 +857,8 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
       return rewriter.notifyMatchFailure(
           op, "output zero point could not be statically determined");

-    int64_t inputZpVal = *maybeIZp;
-    int64_t outputZpVal = *maybeOZp;
+    const int64_t inputZpVal = *maybeIZp;
+    const int64_t outputZpVal = *maybeOZp;

     // Apply padding as necessary.
     llvm::SmallVector<int64_t> pad;

mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp

Lines changed: 2 additions & 0 deletions
@@ -55,6 +55,8 @@ struct MatMulOpSharding
     SmallVector<AffineMap> maps;
     maps.push_back(AffineMap::getMultiDimMapWithTargets(4, {0, 1, 3}, ctx));
     maps.push_back(AffineMap::getMultiDimMapWithTargets(4, {0, 3, 2}, ctx));
+    maps.push_back(AffineMap::get(0, 0, {}, ctx));
+    maps.push_back(AffineMap::get(0, 0, {}, ctx));
     maps.push_back(AffineMap::getMultiDimMapWithTargets(4, {0, 1, 2}, ctx));
     return maps;
   }
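The sharding interface returns one indexing map per operand plus one for the result, so two empty maps are appended for the new a_zp and b_zp operands. Written out in MLIR affine-map notation (the assignment of maps to operands follows the push order above and is shown here for illustration only), the list now reads roughly:

  affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>   // operand a
  affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>   // operand b
  affine_map<() -> ()>                           // a_zp (scalar, not sharded)
  affine_map<() -> ()>                           // b_zp (scalar, not sharded)
  affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>   // result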

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 39 additions & 25 deletions
@@ -636,23 +636,13 @@ buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
 static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
                                        OperationState &result, Type outputType,
                                        Value a, Value b) {
-  result.addOperands({a, b});
-  auto quantAttr = ::buildMatMulOpQuantizationAttr(builder, a, b);
+  auto zps = createZPsAsConst(builder, a, b);
+  result.addOperands({a, b, zps.first, zps.second});

-  if (quantAttr) {
-    result.addAttribute("a_zp", builder.getI32IntegerAttr(
-                                    static_cast<int32_t>(quantAttr.getAZp())));
-    result.addAttribute("b_zp", builder.getI32IntegerAttr(
-                                    static_cast<int32_t>(quantAttr.getBZp())));
-
-    auto inputType = llvm::dyn_cast<ShapedType>(a.getType());
-    assert(inputType && "Input must be a shaped tensor type!");
-
-    auto inputQType = llvm::dyn_cast<mlir::quant::UniformQuantizedType>(
-        inputType.getElementType());
-    assert(inputQType && "Tensor must have quantized datatype!");
-
-    unsigned inputBits = inputQType.getStorageTypeIntegralWidth();
+  Type finalOutputType{outputType};
+  if (auto quantAttr = buildMatMulOpQuantizationAttr(builder, a, b)) {
+    auto eType = getStorageElementTypeOrSelf(a.getType());
+    auto inputBits = eType.getIntOrFloatBitWidth();

     auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType);
     assert(outputShapedType && "Output must be a shaped type");
@@ -662,11 +652,10 @@ static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
       accElementType = builder.getIntegerType(48);
     else
       accElementType = builder.getI32Type();
-    auto accType = outputShapedType.clone(accElementType);
-    result.addTypes(accType);
-  } else {
-    result.addTypes(outputType);
+
+    finalOutputType = outputShapedType.clone(accElementType);
   }
+  result.addTypes(finalOutputType);
 }

 /// Both the tosa.avg_pool2d and unary ops use the same
@@ -1147,16 +1136,39 @@ LogicalResult MatMulOp::verify() {
       return emitOpError("expect quantized operands to have same widths, got ")
              << aQuantWidth << " and " << bQuantWidth;
     }
+  } else {
+    // non-quantized element types
+    if (aElementType != bElementType) {
+      return emitOpError("expect same element type for inputs a and b, got ")
+             << aElementType << " and " << bElementType;
+    }
+  }

-    return success();
+  // check a_zp and b_zp
+  auto aEType = getStorageElementTypeOrSelf(aType);
+  auto aZpEType = getStorageElementTypeOrSelf(getAZp().getType());
+  if (aEType != aZpEType) {
+    return emitOpError("expect input a and a_zp have the same "
+                       "element type, got ")
+           << aEType << " and " << aZpEType;
   }

-  // non-quantized element types
-  if (aElementType != bElementType) {
-    return emitOpError("expect same element type for inputs a and b, got ")
-           << aElementType << " and " << bElementType;
+  auto bEType = getStorageElementTypeOrSelf(bType);
+  auto bZpEType = getStorageElementTypeOrSelf(getBZp().getType());
+  if (bEType != bZpEType) {
+    return emitOpError("expect input b and b_zp have the same "
+                       "element type, got ")
+           << bEType << " and " << bZpEType;
   }

+  FailureOr<int64_t> maybeAZp = getAZeroPoint();
+  if (succeeded(maybeAZp) && verifyAZeroPoint(*maybeAZp).failed())
+    return failure();
+
+  FailureOr<int64_t> maybeBZp = getBZeroPoint();
+  if (succeeded(maybeBZp) && verifyBZeroPoint(*maybeBZp).failed())
+    return failure();
+
   return success();
 }

@@ -1721,6 +1733,8 @@ ZERO_POINT_HELPER(TransposeConv2DOp, Input)
 ZERO_POINT_HELPER(TransposeConv2DOp, Weight)
 ZERO_POINT_HELPER(AvgPool2dOp, Input)
 ZERO_POINT_HELPER(AvgPool2dOp, Output)
+ZERO_POINT_HELPER(MatMulOp, A)
+ZERO_POINT_HELPER(MatMulOp, B)
 #undef ZERO_POINT_HELPER

 LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
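The new checks in MatMulOp::verify() reject a zero point whose element type does not match its input. A minimal sketch of IR the verifier would now flag, assuming i8 inputs fed a float a_zp; the diagnostic wording is taken from the verifier above and its exact formatting is approximate:

  %a_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
  %b_zp = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
  // rejected with roughly: expect input a and a_zp have the same element type, got 'i8' and 'f32'
  %0 = tosa.matmul %a, %b, %a_zp, %b_zp
         : (tensor<1x5x3xi8>, tensor<1x3x6xi8>, tensor<1xf32>, tensor<1xi8>) -> tensor<1x5x6xi32>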

mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp

Lines changed: 10 additions & 1 deletion
@@ -178,6 +178,15 @@ void ProfileInfoDepot::populateProfileInfo(tosa::RescaleOp op) {
   addValue(op.getOutput());
 }

+template <>
+void ProfileInfoDepot::populateProfileInfo(tosa::MatMulOp op) {
+  addValue(op.getA());
+  addValue(op.getB());
+  addValue(op.getAZp());
+  addValue(op.getBZp());
+  addValue(op.getOutput());
+}
+
 LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
   // This helper function only populates the info for the customised operands.
 #define POPULATE_PROFILE_INFO_CUSTOM(tosaOp) \
@@ -218,6 +227,7 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
   POPULATE_PROFILE_INFO_CUSTOM(Resize)
   POPULATE_PROFILE_INFO_CUSTOM(Select)
   POPULATE_PROFILE_INFO_CUSTOM(Rescale)
+  POPULATE_PROFILE_INFO_CUSTOM(MatMul)

   // Type Invariant Extension, a capability extension that is independent
   // of the data type, meaning any compatible type can be used. No type
@@ -235,7 +245,6 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
   POPULATE_PROFILE_INFO_COMMON(Cast)
   POPULATE_PROFILE_INFO_COMMON(Const)
   POPULATE_PROFILE_INFO_COMMON(ArgMax)
-  POPULATE_PROFILE_INFO_COMMON(MatMul)
   POPULATE_PROFILE_INFO_COMMON(Sub)
   POPULATE_PROFILE_INFO_COMMON(Maximum)
   POPULATE_PROFILE_INFO_COMMON(Minimum)

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir

Lines changed: 18 additions & 6 deletions
@@ -8,7 +8,9 @@ func.func @matmul(%arg0: tensor<1x5x3xf32>, %arg1: tensor<1x3x6xf32>) -> (tensor
 // CHECK: [[INIT:%.+]] = tensor.empty()
 // CHECK: [[FILLED:%.+]] = linalg.fill ins([[C0]] : f32) outs([[INIT]] : tensor<1x5x6xf32>) -> tensor<1x5x6xf32>
 // CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x5x3xf32>, tensor<1x3x6xf32>) outs([[FILLED]] : tensor<1x5x6xf32>) -> tensor<1x5x6xf32>
-  %0 = tosa.matmul %arg0, %arg1 : (tensor<1x5x3xf32>, tensor<1x3x6xf32>) -> tensor<1x5x6xf32>
+  %a_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %b_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %0 = tosa.matmul %arg0, %arg1, %a_zp, %b_zp : (tensor<1x5x3xf32>, tensor<1x3x6xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x5x6xf32>
   return %0 : tensor<1x5x6xf32>
 }

@@ -23,7 +25,9 @@ func.func @matmul_quantized(%arg0: tensor<1x5x3xi8>, %arg1: tensor<1x3x6xi8>) ->
 // CHECK: [[ONE:%.+]] = arith.constant 1
 // CHECK: [[TWO:%.+]] = arith.constant 2
 // CHECK: linalg.quantized_batch_matmul ins(%arg0, %arg1, [[ONE]], [[TWO]] : tensor<1x5x3xi8>, tensor<1x3x6xi8>, i32, i32) outs([[FILLED]] : tensor<1x5x6xi32>) -> tensor<1x5x6xi32>
-  %0 = tosa.matmul %arg0, %arg1 {a_zp = 1 : i32, b_zp = 2 : i32} : (tensor<1x5x3xi8>, tensor<1x3x6xi8>) -> tensor<1x5x6xi32>
+  %a_zp = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %b_zp = "tosa.const"() <{values = dense<2> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %0 = tosa.matmul %arg0, %arg1, %a_zp, %b_zp : (tensor<1x5x3xi8>, tensor<1x3x6xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x5x6xi32>
   return %0 : tensor<1x5x6xi32>
 }

@@ -37,7 +41,9 @@ func.func @matmul_dyn_batch(%arg0: tensor<?x5x3xf32>, %arg1: tensor<?x3x6xf32>)
 // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM]])
 // CHECK: %[[FILLED:.+]] = linalg.fill ins(%[[C0_0]] : f32) outs(%[[INIT]] : tensor<?x5x6xf32>) -> tensor<?x5x6xf32>
 // CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<?x5x3xf32>, tensor<?x3x6xf32>) outs(%[[FILLED]] : tensor<?x5x6xf32>) -> tensor<?x5x6xf32>
-  %0 = tosa.matmul %arg0, %arg1 : (tensor<?x5x3xf32>, tensor<?x3x6xf32>) -> tensor<?x5x6xf32>
+  %a_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %b_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %0 = tosa.matmul %arg0, %arg1, %a_zp, %b_zp : (tensor<?x5x3xf32>, tensor<?x3x6xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x5x6xf32>
   return %0 : tensor<?x5x6xf32>
 }

@@ -51,7 +57,9 @@ func.func @matmul_dyn_independent_dim(%arg0: tensor<1x5x3xf32>, %arg1: tensor<1x
 // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM]])
 // CHECK: %[[FILLED:.+]] = linalg.fill ins(%[[C0]] : f32) outs(%[[INIT]] : tensor<1x5x?xf32>) -> tensor<1x5x?xf32>
 // CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x5x3xf32>, tensor<1x3x?xf32>) outs(%[[FILLED]] : tensor<1x5x?xf32>) -> tensor<1x5x?xf32>
-  %0 = tosa.matmul %arg0, %arg1 : (tensor<1x5x3xf32>, tensor<1x3x?xf32>) -> tensor<1x5x?xf32>
+  %a_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %b_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %0 = tosa.matmul %arg0, %arg1, %a_zp, %b_zp : (tensor<1x5x3xf32>, tensor<1x3x?xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x5x?xf32>
   return %0 : tensor<1x5x?xf32>
 }

@@ -63,7 +71,9 @@ func.func @matmul_dyn_independent_dim(%arg0: tensor<1x5x?xf32>, %arg1: tensor<1x
 // CHECK: %[[INIT:.+]] = tensor.empty()
 // CHECK: %[[FILLED:.+]] = linalg.fill ins(%[[C0]] : f32) outs(%[[INIT]] : tensor<1x5x6xf32>) -> tensor<1x5x6xf32>
 // CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x5x?xf32>, tensor<1x?x6xf32>) outs(%[[FILLED]] : tensor<1x5x6xf32>) -> tensor<1x5x6xf32>
-  %0 = tosa.matmul %arg0, %arg1 : (tensor<1x5x?xf32>, tensor<1x?x6xf32>) -> tensor<1x5x6xf32>
+  %a_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %b_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %0 = tosa.matmul %arg0, %arg1, %a_zp, %b_zp : (tensor<1x5x?xf32>, tensor<1x?x6xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x5x6xf32>
   return %0 : tensor<1x5x6xf32>
 }

@@ -77,7 +87,9 @@ func.func @matmul_dyn_output(%arg0: tensor<1x1x8xf32>, %arg1: tensor<1x8x1xf32>)
 // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM0]]) : tensor<?x1x1xf32>
 // CHECK: %[[FILLED:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[INIT]] : tensor<?x1x1xf32>) -> tensor<?x1x1xf32>
 // CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x1x8xf32>, tensor<1x8x1xf32>) outs(%[[FILLED]] : tensor<?x1x1xf32>) -> tensor<?x1x1xf32>
-  %0 = tosa.matmul %arg0, %arg1 : (tensor<1x1x8xf32>, tensor<1x8x1xf32>) -> tensor<?x1x1xf32>
+  %a_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %b_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %0 = tosa.matmul %arg0, %arg1, %a_zp, %b_zp : (tensor<1x1x8xf32>, tensor<1x8x1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x1x1xf32>
   return %0 : tensor<?x1x1xf32>
 }
