
Commit 52b1c3a

udaya-ranga authored and Tai78641 committed
[mlir][tosa] Change MatMul zero-point to inputs

* Change zero-point attributes to inputs
* Fix relevant mlir tests
* Enhance ShardingInterface in MatMul

Signed-off-by: Udaya Ranga <[email protected]>
Change-Id: Ia58b15cba546a948a6a4d8e8ee26a72cd050de4e
1 parent b08769c commit 52b1c3a
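In IR terms, the change moves the zero points from optional attributes to explicit scalar-tensor operands. A minimal before/after sketch (hypothetical values, modeled on the test updates below):

  // Before: zero points were optional i32 attributes.
  %0 = tosa.matmul %a, %b {a_zp = 1 : i32, b_zp = 2 : i32} : (tensor<1x5x3xi8>, tensor<1x3x6xi8>) -> tensor<1x5x6xi32>

  // After: zero points are single-element tensor operands whose element type
  // matches the corresponding input.
  %a_zp = "tosa.const"() <{value = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8>
  %b_zp = "tosa.const"() <{value = dense<2> : tensor<1xi8>}> : () -> tensor<1xi8>
  %1 = tosa.matmul %a, %b, %a_zp, %b_zp : (tensor<1x5x3xi8>, tensor<1x3x6xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x5x6xi32>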

File tree: 15 files changed (+242, −86 lines)

mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc

Lines changed: 8 additions & 6 deletions
@@ -35,9 +35,11 @@ profileComplianceMap = {
       {fp16T, fp16T, fp32T, fp32T},
       {fp32T, fp32T, fp32T, fp32T}}}}},
     {"tosa.matmul",
-     {{{Profile::pro_int}, {{i8T, i8T, i32T}}},
+     {{{Profile::pro_int}, {{i8T, i8T, i8T, i8T, i32T}}},
       {{Profile::pro_fp},
-       {{fp16T, fp16T, fp16T}, {fp16T, fp16T, fp32T}, {fp32T, fp32T, fp32T}}}}},
+       {{fp16T, fp16T, fp16T, fp16T, fp16T},
+        {fp16T, fp16T, fp16T, fp16T, fp32T},
+        {fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
     {"tosa.max_pool2d",
      {{{Profile::pro_int}, {{i8T, i8T}}},
       {{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
@@ -273,10 +275,10 @@ extensionComplianceMap = {
      {{Extension::int16}, {{i16T, i8T, i48T, i48T}}},
      {{Extension::bf16}, {{bf16T, bf16T, fp32T, fp32T}}}}},
    {"tosa.matmul",
-    {{{Extension::int16}, {{i16T, i16T, i48T}}},
-     {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp16T}}},
-     {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp16T}}},
-     {{Extension::bf16}, {{bf16T, bf16T, fp32T}}}}},
+    {{{Extension::int16}, {{i16T, i16T, i16T, i16T, i48T}}},
+     {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp8e4m3T, fp16T}}},
+     {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp8e5m2T, fp16T}}},
+     {{Extension::bf16}, {{bf16T, bf16T, bf16T, bf16T, fp32T}}}}},
    {"tosa.max_pool2d",
     {{{Extension::int16}, {{i16T, i16T}}},
      {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
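Each matmul compliance tuple grows from three entries to five, presumably (a, b, a_zp, b_zp, output), so the zero points are type-checked alongside the inputs. As an illustration, the updated int16 entry {i16T, i16T, i16T, i16T, i48T} would admit a form like this sketch (hypothetical shapes):

  %a_zp = "tosa.const"() <{value = dense<0> : tensor<1xi16>}> : () -> tensor<1xi16>
  %b_zp = "tosa.const"() <{value = dense<0> : tensor<1xi16>}> : () -> tensor<1xi16>
  %0 = tosa.matmul %a, %b, %a_zp, %b_zp : (tensor<1x5x3xi16>, tensor<1x3x6xi16>, tensor<1xi16>, tensor<1xi16>) -> tensor<1x5x6xi48>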

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 10 additions & 2 deletions
@@ -309,8 +309,8 @@ def Tosa_MatMulOp : Tosa_InferShapedTypeOp<"matmul"> {
   let arguments = (ins
     Tosa_Tensor3D:$a,
     Tosa_Tensor3D:$b,
-    OptionalAttr<I32Attr>:$a_zp,
-    OptionalAttr<I32Attr>:$b_zp
+    Tosa_ScalarTensor:$a_zp,
+    Tosa_ScalarTensor:$b_zp
   );

   let results = (outs
@@ -322,7 +322,15 @@ def Tosa_MatMulOp : Tosa_InferShapedTypeOp<"matmul"> {
     Extension<[Tosa_EXT_INT16, Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
   ];

+  let extraClassDeclaration = [{
+    FailureOr<int64_t> getAZeroPoint();
+    FailureOr<int64_t> getBZeroPoint();
+    LogicalResult verifyAZeroPoint(int64_t zp);
+    LogicalResult verifyBZeroPoint(int64_t zp);
+  }];
+
   let builders = [Tosa_MatMulOpQuantInfoBuilder];
+  let hasVerifier = 1;
 }

 //===----------------------------------------------------------------------===//

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp

Lines changed: 26 additions & 3 deletions
@@ -621,15 +621,38 @@ class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
             .create<linalg::FillOp>(loc, ValueRange{zero},
                                     ValueRange{emptyTensor})
             .result();
-    if (!op.getAZp() && !op.getBZp()) {
+
+    FailureOr<int64_t> maybeAZp = op.getAZeroPoint();
+    FailureOr<int64_t> maybeBZp = op.getBZeroPoint();
+    if (failed(maybeAZp))
+      return rewriter.notifyMatchFailure(
+          op, "input a zero point cannot be statically determined");
+    if (failed(maybeBZp))
+      return rewriter.notifyMatchFailure(
+          op, "input b zero point cannot be statically determined");
+
+    int64_t aZpVal = *maybeAZp;
+    int64_t bZpVal = *maybeBZp;
+
+    if (op.verifyAZeroPoint(aZpVal).failed())
+      return rewriter.notifyMatchFailure(
+          op, "input a zero point must be zero for non-int8 integer types");
+
+    if (op.verifyBZeroPoint(bZpVal).failed())
+      return rewriter.notifyMatchFailure(
+          op, "input b zero point must be zero for non-int8 integer types");
+
+    if (aZpVal == 0 && bZpVal == 0) {
       rewriter.replaceOpWithNewOp<linalg::BatchMatmulOp>(
           op, TypeRange{op.getType()},
           ValueRange{adaptor.getA(), adaptor.getB()}, ValueRange{zeroTensor});
       return success();
     }

-    auto aZp = rewriter.create<arith::ConstantOp>(loc, op.getAZpAttr());
-    auto bZp = rewriter.create<arith::ConstantOp>(loc, op.getBZpAttr());
+    auto aZp = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getI32IntegerAttr(aZpVal));
+    auto bZp = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getI32IntegerAttr(bZpVal));
     rewriter.replaceOpWithNewOp<linalg::QuantizedBatchMatmulOp>(
         op, TypeRange{op.getType()},
         ValueRange{adaptor.getA(), adaptor.getB(), aZp, bZp}, zeroTensor);
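With zero points as operands, the converter first tries to resolve them to compile-time constants; per the messages above, anything it cannot fold is a match failure rather than a silent default. A sketch of a case the pattern would now reject, assuming getAZeroPoint() fails for non-constant operands:

  // The zero point is a function argument, not a constant, so the lowering
  // reports "zero point cannot be statically determined" and bails out.
  func.func @matmul_nonconst_zp(%a: tensor<1x5x3xi8>, %b: tensor<1x3x6xi8>,
                                %zp: tensor<1xi8>) -> tensor<1x5x6xi32> {
    %0 = tosa.matmul %a, %b, %zp, %zp : (tensor<1x5x3xi8>, tensor<1x3x6xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x5x6xi32>
    return %0 : tensor<1x5x6xi32>
  }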

mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp

Lines changed: 2 additions & 0 deletions
@@ -55,6 +55,8 @@ struct MatMulOpSharding
     SmallVector<AffineMap> maps;
     maps.push_back(AffineMap::getMultiDimMapWithTargets(4, {0, 1, 3}, ctx));
     maps.push_back(AffineMap::getMultiDimMapWithTargets(4, {0, 3, 2}, ctx));
+    maps.push_back(AffineMap::get(0, 0, {}, ctx));
+    maps.push_back(AffineMap::get(0, 0, {}, ctx));
     maps.push_back(AffineMap::getMultiDimMapWithTargets(4, {0, 1, 2}, ctx));
     return maps;
   }
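MatMul's sharding maps are built over four loop dimensions (batch, m, n, k as d0..d3); the two new entries give the scalar zero points empty maps, since they index no sharded dimensions. Written out in MLIR affine-map notation, the five maps are presumably:

  a    : (d0, d1, d2, d3) -> (d0, d1, d3)
  b    : (d0, d1, d2, d3) -> (d0, d3, d2)
  a_zp : () -> ()
  b_zp : () -> ()
  out  : (d0, d1, d2, d3) -> (d0, d1, d2)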

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 55 additions & 20 deletions
@@ -435,6 +435,34 @@ static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) {
   return success();
 }

+static LogicalResult verifyZpMatMul(MatMulOp op) {
+  auto aEType = getStorageElementTypeOrSelf(op.getA().getType());
+  auto aZpEType = getStorageElementTypeOrSelf(op.getAZp().getType());
+  if (aEType != aZpEType) {
+    return op.emitOpError("expect input a and a_zp have the same "
+                          "element type, got ")
+           << aEType << " and " << aZpEType;
+  }
+
+  auto bEType = getStorageElementTypeOrSelf(op.getB().getType());
+  auto bZpEType = getStorageElementTypeOrSelf(op.getBZp().getType());
+  if (bEType != bZpEType) {
+    return op.emitOpError("expect input b and b_zp have the same "
+                          "element type, got ")
+           << bEType << " and " << bZpEType;
+  }
+
+  FailureOr<int64_t> maybeAZp = op.getAZeroPoint();
+  if (succeeded(maybeAZp) && op.verifyAZeroPoint(*maybeAZp).failed())
+    return failure();
+
+  FailureOr<int64_t> maybeBZp = op.getBZeroPoint();
+  if (succeeded(maybeBZp) && op.verifyBZeroPoint(*maybeBZp).failed())
+    return failure();
+
+  return success();
+}
+
 LogicalResult tosa::ArgMaxOp::verify() {
   // Ensure output is of 32-bit integer
   const auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();
@@ -601,23 +629,13 @@ buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
 static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
                                        OperationState &result, Type outputType,
                                        Value a, Value b) {
-  result.addOperands({a, b});
-  auto quantAttr = ::buildMatMulOpQuantizationAttr(builder, a, b);
+  auto zps = createZPsAsConst(builder, a, b);
+  result.addOperands({a, b, zps.first, zps.second});

-  if (quantAttr) {
-    result.addAttribute("a_zp", builder.getI32IntegerAttr(
-                                    static_cast<int32_t>(quantAttr.getAZp())));
-    result.addAttribute("b_zp", builder.getI32IntegerAttr(
-                                    static_cast<int32_t>(quantAttr.getBZp())));
-
-    auto inputType = llvm::dyn_cast<ShapedType>(a.getType());
-    assert(inputType && "Input must be a shaped tensor type!");
-
-    auto inputQType = llvm::dyn_cast<mlir::quant::UniformQuantizedType>(
-        inputType.getElementType());
-    assert(inputQType && "Tensor must have quantized datatype!");
-
-    unsigned inputBits = inputQType.getStorageTypeIntegralWidth();
+  Type finalOutputType{outputType};
+  if (auto quantAttr = buildMatMulOpQuantizationAttr(builder, a, b)) {
+    auto eType = getStorageElementTypeOrSelf(a.getType());
+    auto inputBits = eType.getIntOrFloatBitWidth();

     auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType);
     assert(outputShapedType && "Output must be a shaped type");
@@ -627,11 +645,10 @@ static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
       accElementType = builder.getIntegerType(48);
     else
       accElementType = builder.getI32Type();
-    auto accType = outputShapedType.clone(accElementType);
-    result.addTypes(accType);
-  } else {
-    result.addTypes(outputType);
+
+    finalOutputType = outputShapedType.clone(accElementType);
   }
+  result.addTypes(finalOutputType);
 }

 /// Both the tosa.avg_pool2d and unary ops use the same
@@ -1025,6 +1042,22 @@ LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
   return success();
 }

+LogicalResult MatMulOp::verify() {
+  auto aType = llvm::dyn_cast<ShapedType>(getA().getType());
+  auto bType = llvm::dyn_cast<ShapedType>(getB().getType());
+
+  // Must be shaped tensor types
+  if (!aType)
+    return emitOpError("expect a shaped tensor for input a, got ")
+           << getA().getType();
+
+  if (!bType)
+    return emitOpError("expect a shaped tensor for input b, got ")
+           << getB().getType();
+
+  return verifyZpMatMul(*this);
+}
+
 LogicalResult tosa::PadOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     PadOp::Adaptor adaptor,
@@ -1560,6 +1593,8 @@ ZERO_POINT_HELPER(TransposeConv2DOp, Input)
 ZERO_POINT_HELPER(TransposeConv2DOp, Weight)
 ZERO_POINT_HELPER(AvgPool2dOp, Input)
 ZERO_POINT_HELPER(AvgPool2dOp, Output)
+ZERO_POINT_HELPER(MatMulOp, A)
+ZERO_POINT_HELPER(MatMulOp, B)
 #undef ZERO_POINT_HELPER

 LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
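The new verifier checks that each zero point's element type matches its input before validating the value itself. A hypothetical snippet it would reject with "expect input a and a_zp have the same element type":

  // a is i8 but a_zp is i16, so MatMulOp::verify() fails.
  %a_zp = "tosa.const"() <{value = dense<0> : tensor<1xi16>}> : () -> tensor<1xi16>
  %b_zp = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
  %0 = tosa.matmul %a, %b, %a_zp, %b_zp : (tensor<1x5x3xi8>, tensor<1x3x6xi8>, tensor<1xi16>, tensor<1xi8>) -> tensor<1x5x6xi32>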

mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp

Lines changed: 10 additions & 1 deletion
@@ -178,6 +178,15 @@ void ProfileInfoDepot::populateProfileInfo(tosa::RescaleOp op) {
   addValue(op.getOutput());
 }

+template <>
+void ProfileInfoDepot::populateProfileInfo(tosa::MatMulOp op) {
+  addValue(op.getA());
+  addValue(op.getB());
+  addValue(op.getAZp());
+  addValue(op.getBZp());
+  addValue(op.getOutput());
+}
+
 LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
   // This helper function only populates the info for the customised operands.
 #define POPULATE_PROFILE_INFO_CUSTOM(tosaOp)                                   \
@@ -218,6 +227,7 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
   POPULATE_PROFILE_INFO_CUSTOM(Resize)
   POPULATE_PROFILE_INFO_CUSTOM(Select)
   POPULATE_PROFILE_INFO_CUSTOM(Rescale)
+  POPULATE_PROFILE_INFO_CUSTOM(MatMul)

   // Type Invariant Extension, a capability extension that is independent
   // of the data type, meaning any compatible type can be used. No type
@@ -235,7 +245,6 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
   POPULATE_PROFILE_INFO_COMMON(Cast)
   POPULATE_PROFILE_INFO_COMMON(Const)
   POPULATE_PROFILE_INFO_COMMON(ArgMax)
-  POPULATE_PROFILE_INFO_COMMON(MatMul)
   POPULATE_PROFILE_INFO_COMMON(Sub)
   POPULATE_PROFILE_INFO_COMMON(Maximum)
   POPULATE_PROFILE_INFO_COMMON(Minimum)

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir

Lines changed: 18 additions & 6 deletions
@@ -8,7 +8,9 @@ func.func @matmul(%arg0: tensor<1x5x3xf32>, %arg1: tensor<1x3x6xf32>) -> (tensor
   // CHECK: [[INIT:%.+]] = tensor.empty()
   // CHECK: [[FILLED:%.+]] = linalg.fill ins([[C0]] : f32) outs([[INIT]] : tensor<1x5x6xf32>) -> tensor<1x5x6xf32>
   // CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x5x3xf32>, tensor<1x3x6xf32>) outs([[FILLED]] : tensor<1x5x6xf32>) -> tensor<1x5x6xf32>
-  %0 = tosa.matmul %arg0, %arg1 : (tensor<1x5x3xf32>, tensor<1x3x6xf32>) -> tensor<1x5x6xf32>
+  %a_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %b_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %0 = tosa.matmul %arg0, %arg1, %a_zp, %b_zp : (tensor<1x5x3xf32>, tensor<1x3x6xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x5x6xf32>
   return %0 : tensor<1x5x6xf32>
 }

@@ -23,7 +25,9 @@ func.func @matmul_quantized(%arg0: tensor<1x5x3xi8>, %arg1: tensor<1x3x6xi8>) ->
   // CHECK: [[ONE:%.+]] = arith.constant 1
   // CHECK: [[TWO:%.+]] = arith.constant 2
   // CHECK: linalg.quantized_batch_matmul ins(%arg0, %arg1, [[ONE]], [[TWO]] : tensor<1x5x3xi8>, tensor<1x3x6xi8>, i32, i32) outs([[FILLED]] : tensor<1x5x6xi32>) -> tensor<1x5x6xi32>
-  %0 = tosa.matmul %arg0, %arg1 {a_zp = 1 : i32, b_zp = 2 : i32} : (tensor<1x5x3xi8>, tensor<1x3x6xi8>) -> tensor<1x5x6xi32>
+  %a_zp = "tosa.const"() <{value = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %b_zp = "tosa.const"() <{value = dense<2> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %0 = tosa.matmul %arg0, %arg1, %a_zp, %b_zp : (tensor<1x5x3xi8>, tensor<1x3x6xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x5x6xi32>
   return %0 : tensor<1x5x6xi32>
 }

@@ -37,7 +41,9 @@ func.func @matmul_dyn_batch(%arg0: tensor<?x5x3xf32>, %arg1: tensor<?x3x6xf32>)
   // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM]])
   // CHECK: %[[FILLED:.+]] = linalg.fill ins(%[[C0_0]] : f32) outs(%[[INIT]] : tensor<?x5x6xf32>) -> tensor<?x5x6xf32>
   // CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<?x5x3xf32>, tensor<?x3x6xf32>) outs(%[[FILLED]] : tensor<?x5x6xf32>) -> tensor<?x5x6xf32>
-  %0 = tosa.matmul %arg0, %arg1 : (tensor<?x5x3xf32>, tensor<?x3x6xf32>) -> tensor<?x5x6xf32>
+  %a_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %b_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %0 = tosa.matmul %arg0, %arg1, %a_zp, %b_zp : (tensor<?x5x3xf32>, tensor<?x3x6xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x5x6xf32>
   return %0 : tensor<?x5x6xf32>
 }

@@ -51,7 +57,9 @@ func.func @matmul_dyn_independent_dim(%arg0: tensor<1x5x3xf32>, %arg1: tensor<1x
   // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM]])
   // CHECK: %[[FILLED:.+]] = linalg.fill ins(%[[C0]] : f32) outs(%[[INIT]] : tensor<1x5x?xf32>) -> tensor<1x5x?xf32>
   // CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x5x3xf32>, tensor<1x3x?xf32>) outs(%[[FILLED]] : tensor<1x5x?xf32>) -> tensor<1x5x?xf32>
-  %0 = tosa.matmul %arg0, %arg1 : (tensor<1x5x3xf32>, tensor<1x3x?xf32>) -> tensor<1x5x?xf32>
+  %a_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %b_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %0 = tosa.matmul %arg0, %arg1, %a_zp, %b_zp : (tensor<1x5x3xf32>, tensor<1x3x?xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x5x?xf32>
   return %0 : tensor<1x5x?xf32>
 }

@@ -63,7 +71,9 @@ func.func @matmul_dyn_independent_dim(%arg0: tensor<1x5x?xf32>, %arg1: tensor<1x
   // CHECK: %[[INIT:.+]] = tensor.empty()
   // CHECK: %[[FILLED:.+]] = linalg.fill ins(%[[C0]] : f32) outs(%[[INIT]] : tensor<1x5x6xf32>) -> tensor<1x5x6xf32>
   // CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x5x?xf32>, tensor<1x?x6xf32>) outs(%[[FILLED]] : tensor<1x5x6xf32>) -> tensor<1x5x6xf32>
-  %0 = tosa.matmul %arg0, %arg1 : (tensor<1x5x?xf32>, tensor<1x?x6xf32>) -> tensor<1x5x6xf32>
+  %a_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %b_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %0 = tosa.matmul %arg0, %arg1, %a_zp, %b_zp : (tensor<1x5x?xf32>, tensor<1x?x6xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x5x6xf32>
   return %0 : tensor<1x5x6xf32>
 }

@@ -77,7 +87,9 @@ func.func @matmul_dyn_output(%arg0: tensor<1x1x8xf32>, %arg1: tensor<1x8x1xf32>)
   // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM0]]) : tensor<?x1x1xf32>
   // CHECK: %[[FILLED:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[INIT]] : tensor<?x1x1xf32>) -> tensor<?x1x1xf32>
   // CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x1x8xf32>, tensor<1x8x1xf32>) outs(%[[FILLED]] : tensor<?x1x1xf32>) -> tensor<?x1x1xf32>
-  %0 = tosa.matmul %arg0, %arg1 : (tensor<1x1x8xf32>, tensor<1x8x1xf32>) -> tensor<?x1x1xf32>
+  %a_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %b_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %0 = tosa.matmul %arg0, %arg1, %a_zp, %b_zp : (tensor<1x1x8xf32>, tensor<1x8x1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x1x1xf32>
   return %0 : tensor<?x1x1xf32>
 }
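The tests above cover the f32 and quantized i8 paths; the updated profile table also lists an fp16-inputs, fp32-accumulator entry, {fp16T, fp16T, fp16T, fp16T, fp32T}. A hypothetical well-formed example of that shape, inferred from the table rather than taken from the commit's tests:

  %a_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
  %b_zp = "tosa.const"() <{value = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
  %0 = tosa.matmul %a, %b, %a_zp, %b_zp : (tensor<1x5x3xf16>, tensor<1x3x6xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x5x6xf32>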
