[mlir][tosa] Change MatMul zero-point to inputs #130332
Conversation
Tai78641 commented on Mar 7, 2025
- Change zero-point attributes to inputs
- Fix relevant mlir tests
- Enhance ShardingInterface in MatMul
* Change zero-point attributes to inputs
* Fix relevant mlir tests
* Enhance ShardingInterface in MatMul
Signed-off-by: Udaya Ranga <[email protected]>
Change-Id: Ia58b15cba546a948a6a4d8e8ee26a72cd050de4e
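For illustration, a minimal before/after sketch of the new op form, with names and shapes borrowed from the updated lit tests in this patch (hypothetical otherwise): the zero points move from optional i32 attributes to explicit rank-1 scalar tensor operands.

Before:
  %0 = tosa.matmul %a, %b {a_zp = 1 : i32, b_zp = 2 : i32} : (tensor<1x5x3xi8>, tensor<1x3x6xi8>) -> tensor<1x5x6xi32>

After:
  %a_zp = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8>
  %b_zp = "tosa.const"() <{values = dense<2> : tensor<1xi8>}> : () -> tensor<1xi8>
  %0 = tosa.matmul %a, %b, %a_zp, %b_zp : (tensor<1x5x3xi8>, tensor<1x3x6xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x5x6xi32>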
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-tosa
Author: Tai Ly (Tai78641)
Changes
Patch is 47.21 KiB, truncated to 20.00 KiB below; full version: https://github.com/llvm/llvm-project/pull/130332.diff
15 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
index a9b458acd87f2..d3fd4c3d1d3e1 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
@@ -35,9 +35,11 @@ profileComplianceMap = {
{fp16T, fp16T, fp32T, fp32T},
{fp32T, fp32T, fp32T, fp32T}}}}},
{"tosa.matmul",
- {{{Profile::pro_int}, {{i8T, i8T, i32T}}},
+ {{{Profile::pro_int}, {{i8T, i8T, i8T, i8T, i32T}}},
{{Profile::pro_fp},
- {{fp16T, fp16T, fp16T}, {fp16T, fp16T, fp32T}, {fp32T, fp32T, fp32T}}}}},
+ {{fp16T, fp16T, fp16T, fp16T, fp16T},
+ {fp16T, fp16T, fp16T, fp16T, fp32T},
+ {fp32T, fp32T, fp32T, fp32T, fp32T}}}}},
{"tosa.max_pool2d",
{{{Profile::pro_int}, {{i8T, i8T}}},
{{Profile::pro_fp}, {{fp16T, fp16T}, {fp32T, fp32T}}}}},
@@ -273,10 +275,10 @@ extensionComplianceMap = {
{{Extension::int16}, {{i16T, i8T, i48T, i48T}}},
{{Extension::bf16}, {{bf16T, bf16T, fp32T, fp32T}}}}},
{"tosa.matmul",
- {{{Extension::int16}, {{i16T, i16T, i48T}}},
- {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp16T}}},
- {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp16T}}},
- {{Extension::bf16}, {{bf16T, bf16T, fp32T}}}}},
+ {{{Extension::int16}, {{i16T, i16T, i16T, i16T, i48T}}},
+ {{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T, fp8e4m3T, fp8e4m3T, fp16T}}},
+ {{Extension::fp8e5m2}, {{fp8e5m2T, fp8e5m2T, fp8e5m2T, fp8e5m2T, fp16T}}},
+ {{Extension::bf16}, {{bf16T, bf16T, bf16T, bf16T, fp32T}}}}},
{"tosa.max_pool2d",
{{{Extension::int16}, {{i16T, i16T}}},
{{Extension::fp8e4m3}, {{fp8e4m3T, fp8e4m3T}}},
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 097f78cd487ea..ecddc9fe9a13d 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -311,8 +311,8 @@ def Tosa_MatMulOp : Tosa_InferShapedTypeOp<"matmul"> {
let arguments = (ins
Tosa_Tensor3D:$a,
Tosa_Tensor3D:$b,
- OptionalAttr<I32Attr>:$a_zp,
- OptionalAttr<I32Attr>:$b_zp
+ Tosa_ScalarIntOrFloatTensor:$a_zp,
+ Tosa_ScalarIntOrFloatTensor:$b_zp
);
let results = (outs
@@ -324,6 +324,13 @@ def Tosa_MatMulOp : Tosa_InferShapedTypeOp<"matmul"> {
Extension<[Tosa_EXT_INT16, Tosa_EXT_FP8E4M3, Tosa_EXT_FP8E5M2, Tosa_EXT_BF16]>,
];
+ let extraClassDeclaration = [{
+ FailureOr<int64_t> getAZeroPoint();
+ FailureOr<int64_t> getBZeroPoint();
+ LogicalResult verifyAZeroPoint(int64_t zp);
+ LogicalResult verifyBZeroPoint(int64_t zp);
+ }];
+
let builders = [Tosa_MatMulOpQuantInfoBuilder];
let hasVerifier = 1;
}
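Since Tosa_ScalarIntOrFloatTensor also admits float element types, floating-point matmuls now take explicit zero-point operands as well; a minimal sketch mirroring the updated tests below, where the float zero points are zero-valued constants:

  %a_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
  %b_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
  %0 = tosa.matmul %arg0, %arg1, %a_zp, %b_zp : (tensor<1x5x3xf32>, tensor<1x3x6xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x5x6xf32>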
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 2a2589e19d0ac..13c62b2d3e91c 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -270,8 +270,8 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
return rewriter.notifyMatchFailure(
op, "weight zero point cannot be statically determined");
- int64_t inputZpVal = *maybeIZp;
- int64_t weightZpVal = *maybeWZp;
+ const int64_t inputZpVal = *maybeIZp;
+ const int64_t weightZpVal = *maybeWZp;
if (op.verifyInputZeroPoint(inputZpVal).failed())
return rewriter.notifyMatchFailure(
@@ -466,8 +466,8 @@ class DepthwiseConvConverter
return rewriter.notifyMatchFailure(
op, "weight zero point cannot be statically determined");
- int64_t inputZpVal = *maybeIZp;
- int64_t weightZpVal = *maybeWZp;
+ const int64_t inputZpVal = *maybeIZp;
+ const int64_t weightZpVal = *maybeWZp;
if (op.verifyInputZeroPoint(inputZpVal).failed())
return rewriter.notifyMatchFailure(
@@ -621,15 +621,38 @@ class MatMulConverter : public OpConversionPattern<tosa::MatMulOp> {
.create<linalg::FillOp>(loc, ValueRange{zero},
ValueRange{emptyTensor})
.result();
- if (!op.getAZp() && !op.getBZp()) {
+
+ FailureOr<int64_t> maybeAZp = op.getAZeroPoint();
+ FailureOr<int64_t> maybeBZp = op.getBZeroPoint();
+ if (failed(maybeAZp))
+ return rewriter.notifyMatchFailure(
+ op, "input a zero point cannot be statically determined");
+ if (failed(maybeBZp))
+ return rewriter.notifyMatchFailure(
+ op, "input b zero point cannot be statically determined");
+
+ const int64_t aZpVal = *maybeAZp;
+ const int64_t bZpVal = *maybeBZp;
+
+ if (op.verifyAZeroPoint(aZpVal).failed())
+ return rewriter.notifyMatchFailure(
+ op, "input a zero point must be zero for non-int8 integer types");
+
+ if (op.verifyBZeroPoint(bZpVal).failed())
+ return rewriter.notifyMatchFailure(
+ op, "input b zero point must be zero for non-int8 integer types");
+
+ if (aZpVal == 0 && bZpVal == 0) {
rewriter.replaceOpWithNewOp<linalg::BatchMatmulOp>(
op, TypeRange{op.getType()},
ValueRange{adaptor.getA(), adaptor.getB()}, ValueRange{zeroTensor});
return success();
}
- auto aZp = rewriter.create<arith::ConstantOp>(loc, op.getAZpAttr());
- auto bZp = rewriter.create<arith::ConstantOp>(loc, op.getBZpAttr());
+ auto aZp = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getI32IntegerAttr(aZpVal));
+ auto bZp = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getI32IntegerAttr(bZpVal));
rewriter.replaceOpWithNewOp<linalg::QuantizedBatchMatmulOp>(
op, TypeRange{op.getType()},
ValueRange{adaptor.getA(), adaptor.getB(), aZp, bZp}, zeroTensor);
@@ -834,8 +857,8 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
return rewriter.notifyMatchFailure(
op, "output zero point could not be statically determined");
- int64_t inputZpVal = *maybeIZp;
- int64_t outputZpVal = *maybeOZp;
+ const int64_t inputZpVal = *maybeIZp;
+ const int64_t outputZpVal = *maybeOZp;
// Apply padding as necessary.
llvm::SmallVector<int64_t> pad;
diff --git a/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp
index ffbb707344b8c..6dcb7c845b21f 100644
--- a/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp
@@ -55,6 +55,8 @@ struct MatMulOpSharding
SmallVector<AffineMap> maps;
maps.push_back(AffineMap::getMultiDimMapWithTargets(4, {0, 1, 3}, ctx));
maps.push_back(AffineMap::getMultiDimMapWithTargets(4, {0, 3, 2}, ctx));
+ maps.push_back(AffineMap::get(0, 0, {}, ctx));
+ maps.push_back(AffineMap::get(0, 0, {}, ctx));
maps.push_back(AffineMap::getMultiDimMapWithTargets(4, {0, 1, 2}, ctx));
return maps;
}
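The two empty maps stand in for the new scalar a_zp/b_zp operands, which have no dimensions available for sharding; after propagation they end up with an empty split-axes sharding, roughly as in this sketch taken from the updated sharding-propagation test (%zp is a hypothetical value):

  %s = mesh.sharding @mesh_2d split_axes = [[]] : !mesh.sharding
  %sharded_zp = mesh.shard %zp to %s annotate_for_users : tensor<1xf32>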
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 4711122dc76e2..7a991b3876f69 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -629,23 +629,13 @@ buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
OperationState &result, Type outputType,
Value a, Value b) {
- result.addOperands({a, b});
- auto quantAttr = ::buildMatMulOpQuantizationAttr(builder, a, b);
+ auto zps = createZPsAsConst(builder, a, b);
+ result.addOperands({a, b, zps.first, zps.second});
- if (quantAttr) {
- result.addAttribute("a_zp", builder.getI32IntegerAttr(
- static_cast<int32_t>(quantAttr.getAZp())));
- result.addAttribute("b_zp", builder.getI32IntegerAttr(
- static_cast<int32_t>(quantAttr.getBZp())));
-
- auto inputType = llvm::dyn_cast<ShapedType>(a.getType());
- assert(inputType && "Input must be a shaped tensor type!");
-
- auto inputQType = llvm::dyn_cast<mlir::quant::UniformQuantizedType>(
- inputType.getElementType());
- assert(inputQType && "Tensor must have quantized datatype!");
-
- unsigned inputBits = inputQType.getStorageTypeIntegralWidth();
+ Type finalOutputType{outputType};
+ if (auto quantAttr = buildMatMulOpQuantizationAttr(builder, a, b)) {
+ auto eType = getStorageElementTypeOrSelf(a.getType());
+ auto inputBits = eType.getIntOrFloatBitWidth();
auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType);
assert(outputShapedType && "Output must be a shaped type");
@@ -655,11 +645,10 @@ static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
accElementType = builder.getIntegerType(48);
else
accElementType = builder.getI32Type();
- auto accType = outputShapedType.clone(accElementType);
- result.addTypes(accType);
- } else {
- result.addTypes(outputType);
+
+ finalOutputType = outputShapedType.clone(accElementType);
}
+ result.addTypes(finalOutputType);
}
/// Both the tosa.avg_pool2d and unary ops use the same
@@ -1140,16 +1129,39 @@ LogicalResult MatMulOp::verify() {
return emitOpError("expect quantized operands to have same widths, got ")
<< aQuantWidth << " and " << bQuantWidth;
}
+ } else {
+ // non-quantized element types
+ if (aElementType != bElementType) {
+ return emitOpError("expect same element type for inputs a and b, got ")
+ << aElementType << " and " << bElementType;
+ }
+ }
- return success();
+ // check a_zp and b_zp
+ auto aEType = getStorageElementTypeOrSelf(aType);
+ auto aZpEType = getStorageElementTypeOrSelf(getAZp().getType());
+ if (aEType != aZpEType) {
+ return emitOpError("expect input a and a_zp have the same "
+ "element type, got ")
+ << aEType << " and " << aZpEType;
}
- // non-quantized element types
- if (aElementType != bElementType) {
- return emitOpError("expect same element type for inputs a and b, got ")
- << aElementType << " and " << bElementType;
+ auto bEType = getStorageElementTypeOrSelf(bType);
+ auto bZpEType = getStorageElementTypeOrSelf(getBZp().getType());
+ if (bEType != bZpEType) {
+ return emitOpError("expect input b and b_zp have the same "
+ "element type, got ")
+ << bEType << " and " << bZpEType;
}
+ FailureOr<int64_t> maybeAZp = getAZeroPoint();
+ if (succeeded(maybeAZp) && verifyAZeroPoint(*maybeAZp).failed())
+ return failure();
+
+ FailureOr<int64_t> maybeBZp = getBZeroPoint();
+ if (succeeded(maybeBZp) && verifyBZeroPoint(*maybeBZp).failed())
+ return failure();
+
return success();
}
@@ -1714,6 +1726,8 @@ ZERO_POINT_HELPER(TransposeConv2DOp, Input)
ZERO_POINT_HELPER(TransposeConv2DOp, Weight)
ZERO_POINT_HELPER(AvgPool2dOp, Input)
ZERO_POINT_HELPER(AvgPool2dOp, Output)
+ZERO_POINT_HELPER(MatMulOp, A)
+ZERO_POINT_HELPER(MatMulOp, B)
#undef ZERO_POINT_HELPER
LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index 345616c9563b5..983062ffd7912 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -178,6 +178,15 @@ void ProfileInfoDepot::populateProfileInfo(tosa::RescaleOp op) {
addValue(op.getOutput());
}
+template <>
+void ProfileInfoDepot::populateProfileInfo(tosa::MatMulOp op) {
+ addValue(op.getA());
+ addValue(op.getB());
+ addValue(op.getAZp());
+ addValue(op.getBZp());
+ addValue(op.getOutput());
+}
+
LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
// This helper function only populates the info for the customised operands.
#define POPULATE_PROFILE_INFO_CUSTOM(tosaOp) \
@@ -218,6 +227,7 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
POPULATE_PROFILE_INFO_CUSTOM(Resize)
POPULATE_PROFILE_INFO_CUSTOM(Select)
POPULATE_PROFILE_INFO_CUSTOM(Rescale)
+ POPULATE_PROFILE_INFO_CUSTOM(MatMul)
// Type Invariant Extension, a capability extension that is independent
// of the data type, meaning any compatible type can be used. No type
@@ -235,7 +245,6 @@ LogicalResult ProfileInfoDepot::populatationDispatch(Operation *op) {
POPULATE_PROFILE_INFO_COMMON(Cast)
POPULATE_PROFILE_INFO_COMMON(Const)
POPULATE_PROFILE_INFO_COMMON(ArgMax)
- POPULATE_PROFILE_INFO_COMMON(MatMul)
POPULATE_PROFILE_INFO_COMMON(Sub)
POPULATE_PROFILE_INFO_COMMON(Maximum)
POPULATE_PROFILE_INFO_COMMON(Minimum)
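MatMul moves from the common to the custom populator because its zero-point operands now contribute element types to the compliance check; for example, under the int16 extension the checked type tuple becomes (i16, i16, i16, i16, i48), as in this hypothetical sketch:

  %0 = tosa.matmul %a, %b, %a_zp, %b_zp : (tensor<1x5x3xi16>, tensor<1x3x6xi16>, tensor<1xi16>, tensor<1xi16>) -> tensor<1x5x6xi48>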
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 5bb4a3bddb51b..341f773c79a5e 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -8,7 +8,9 @@ func.func @matmul(%arg0: tensor<1x5x3xf32>, %arg1: tensor<1x3x6xf32>) -> (tensor
// CHECK: [[INIT:%.+]] = tensor.empty()
// CHECK: [[FILLED:%.+]] = linalg.fill ins([[C0]] : f32) outs([[INIT]] : tensor<1x5x6xf32>) -> tensor<1x5x6xf32>
// CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x5x3xf32>, tensor<1x3x6xf32>) outs([[FILLED]] : tensor<1x5x6xf32>) -> tensor<1x5x6xf32>
- %0 = tosa.matmul %arg0, %arg1 : (tensor<1x5x3xf32>, tensor<1x3x6xf32>) -> tensor<1x5x6xf32>
+ %a_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+ %b_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+ %0 = tosa.matmul %arg0, %arg1, %a_zp, %b_zp : (tensor<1x5x3xf32>, tensor<1x3x6xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x5x6xf32>
return %0 : tensor<1x5x6xf32>
}
@@ -23,7 +25,9 @@ func.func @matmul_quantized(%arg0: tensor<1x5x3xi8>, %arg1: tensor<1x3x6xi8>) ->
// CHECK: [[ONE:%.+]] = arith.constant 1
// CHECK: [[TWO:%.+]] = arith.constant 2
// CHECK: linalg.quantized_batch_matmul ins(%arg0, %arg1, [[ONE]], [[TWO]] : tensor<1x5x3xi8>, tensor<1x3x6xi8>, i32, i32) outs([[FILLED]] : tensor<1x5x6xi32>) -> tensor<1x5x6xi32>
- %0 = tosa.matmul %arg0, %arg1 {a_zp = 1 : i32, b_zp = 2 : i32} : (tensor<1x5x3xi8>, tensor<1x3x6xi8>) -> tensor<1x5x6xi32>
+ %a_zp = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %b_zp = "tosa.const"() <{values = dense<2> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %0 = tosa.matmul %arg0, %arg1, %a_zp, %b_zp : (tensor<1x5x3xi8>, tensor<1x3x6xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x5x6xi32>
return %0 : tensor<1x5x6xi32>
}
@@ -37,7 +41,9 @@ func.func @matmul_dyn_batch(%arg0: tensor<?x5x3xf32>, %arg1: tensor<?x3x6xf32>)
// CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM]])
// CHECK: %[[FILLED:.+]] = linalg.fill ins(%[[C0_0]] : f32) outs(%[[INIT]] : tensor<?x5x6xf32>) -> tensor<?x5x6xf32>
// CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<?x5x3xf32>, tensor<?x3x6xf32>) outs(%[[FILLED]] : tensor<?x5x6xf32>) -> tensor<?x5x6xf32>
- %0 = tosa.matmul %arg0, %arg1 : (tensor<?x5x3xf32>, tensor<?x3x6xf32>) -> tensor<?x5x6xf32>
+ %a_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+ %b_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+ %0 = tosa.matmul %arg0, %arg1, %a_zp, %b_zp : (tensor<?x5x3xf32>, tensor<?x3x6xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x5x6xf32>
return %0 : tensor<?x5x6xf32>
}
@@ -51,7 +57,9 @@ func.func @matmul_dyn_independent_dim(%arg0: tensor<1x5x3xf32>, %arg1: tensor<1x
// CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM]])
// CHECK: %[[FILLED:.+]] = linalg.fill ins(%[[C0]] : f32) outs(%[[INIT]] : tensor<1x5x?xf32>) -> tensor<1x5x?xf32>
// CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x5x3xf32>, tensor<1x3x?xf32>) outs(%[[FILLED]] : tensor<1x5x?xf32>) -> tensor<1x5x?xf32>
- %0 = tosa.matmul %arg0, %arg1 : (tensor<1x5x3xf32>, tensor<1x3x?xf32>) -> tensor<1x5x?xf32>
+ %a_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+ %b_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+ %0 = tosa.matmul %arg0, %arg1, %a_zp, %b_zp : (tensor<1x5x3xf32>, tensor<1x3x?xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x5x?xf32>
return %0 : tensor<1x5x?xf32>
}
@@ -63,7 +71,9 @@ func.func @matmul_dyn_independent_dim(%arg0: tensor<1x5x?xf32>, %arg1: tensor<1x
// CHECK: %[[INIT:.+]] = tensor.empty()
// CHECK: %[[FILLED:.+]] = linalg.fill ins(%[[C0]] : f32) outs(%[[INIT]] : tensor<1x5x6xf32>) -> tensor<1x5x6xf32>
// CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x5x?xf32>, tensor<1x?x6xf32>) outs(%[[FILLED]] : tensor<1x5x6xf32>) -> tensor<1x5x6xf32>
- %0 = tosa.matmul %arg0, %arg1 : (tensor<1x5x?xf32>, tensor<1x?x6xf32>) -> tensor<1x5x6xf32>
+ %a_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+ %b_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+ %0 = tosa.matmul %arg0, %arg1, %a_zp, %b_zp : (tensor<1x5x?xf32>, tensor<1x?x6xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x5x6xf32>
return %0 : tensor<1x5x6xf32>
}
@@ -77,7 +87,9 @@ func.func @matmul_dyn_output(%arg0: tensor<1x1x8xf32>, %arg1: tensor<1x8x1xf32>)
// CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM0]]) : tensor<?x1x1xf32>
// CHECK: %[[FILLED:.+]] = linalg.fill ins(%[[CST]] : f32) outs(%[[INIT]] : tensor<?x1x1xf32>) -> tensor<?x1x1xf32>
// CHECK: linalg.batch_matmul ins(%arg0, %arg1 : tensor<1x1x8xf32>, tensor<1x8x1xf32>) outs(%[[FILLED]] : tensor<?x1x1xf32>) -> tensor<?x1x1xf32>
- %0 = tosa.matmul %arg0, %arg1 : (tensor<1x1x8xf32>, tensor<1x8x1xf32>) -> tensor<?x1x1xf32>
+ %a_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+ %b_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+ %0 = tosa.matmul %arg0, %arg1, %a_zp, %b_zp : (tensor<1x1x8xf32>, tensor<1x8x1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x1x1xf32>
return %0 : tensor<?x1x1xf32>
}
diff --git a/mlir/test/Dialect/Mesh/sharding-propagation.mlir b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
index 83136f613b020..14c67e670e921 100644
--- a/mlir/test/Dialect/Mesh/sharding-propagation.mlir
+++ b/mlir/test/Dialect/Mesh/sharding-propagation.mlir
@@ -98,14 +98,16 @@ func.func @arrow_structure(%arg0: tensor<8x16xf32>) -> (tensor<8x16xf32>, tensor
}
// CHECK-LABEL: func.func @matmul_on_def_shard_batch_and_m
-// CHECK-SAME: %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>
-func.func @matmul_on_def_shard_batch_and_m(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>) -> tensor<2x16x32xf32> {
+// CHECK-SAME: %[[ARG0:.*]]: tensor<2x16x8xf32>, %[[ARG1:.*]]: tensor<2x8x32xf32>, %[[ARG2:.*]]: tensor<1xf32>
+func.func @matmul_on_def_shard_batch_and_m(%arg0: tensor<2x16x8xf32>, %arg1: tensor<2x8x32xf32>, %arg2: tensor<1xf32>) -> tensor<2x16x32xf32> {
// CHECK-NEXT: %[[S0:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding
// CHECK-NEXT: %[[V0:.*]] = mesh.shard %[[ARG0]] to %[[S0]] annotate_for_users : tensor<2x16x8xf32>
// CHECK-NEXT: %[[S1:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0]] : !mesh.sharding
// CHECK-NEXT: %[[V1:.*]] = mesh.shard %[[ARG1]] to %[[S1]] annotate_for_users : tensor<2x8x32xf32>
- // CHECK-NEXT: %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]]
- %0 = tosa.matmul %arg0, %arg1 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>) -> tensor<2x16x32xf32>
+ // CHECK-NEXT: %[[S2:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}]] : !mesh.sharding
+ // CHECK-NEXT: %[[ZP:.*]] = mesh.shard %[[ARG2]] to %[[S2]] annotate_for_users : tensor<1xf32>
+ // CHECK-NEXT: %[[V2:.*]] = tosa.matmul %[[V0]], %[[V1]], %[[ZP]], %[[ZP]]
+ %0 = tosa.matmul %arg0, %arg1, %arg2, %arg2 : (tensor<2x16x8xf32>, tensor<2x8x32xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<2x16x32xf32>
// CHECK-NEXT: %[[S3:.*]] = mesh.sharding @mesh_2d split_axes = {{\[\[}}0], [1]] : !mesh.sharding
// CHECK-NEXT: %[[V3:.*]] = mesh.shard %[[V2]] to %[[S3]] :...
[truncated]
This PR is rebased from the original PR: #129785
Approved. Minor rebase on top of the original PR (#129785)
* Change zero-point attributes to inputs
* Fix relevant mlir tests
* Enhance ShardingInterface in MatMul
Signed-off-by: Udaya Ranga <[email protected]>
Co-authored-by: Udaya Ranga <[email protected]>