-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][tosa] Replace UniformQuantizedType by the more generic QuantizedType in Conv verifiers #126275
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-tosa Author: Tai Ly (Tai78641) Changes: Replace UniformQuantizedType by the more generic QuantizedType in Conv verifiers. Also fixes buildTransConvOpWithQuantInfo to insert the input/weight zero-point (zp) operands. Change-Id: Ie1961af931864f801914a62976bc988881ee075e Full diff: https://github.com/llvm/llvm-project/pull/126275.diff 1 file affected:
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 031c279ff09e275..6143a9d23a00394 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -240,16 +240,13 @@ static LogicalResult verifyConvOp(T op) {
bool biasIsFloat = llvm::isa<FloatType>(biasEType);
bool resultIsFloat = llvm::isa<FloatType>(resultEType);
- if (auto quantType =
- llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType))
+ if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
inputEType = quantType.getStorageType();
- if (auto quantType =
- llvm::dyn_cast<mlir::quant::UniformQuantizedType>(biasEType))
+ if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
biasEType = quantType.getStorageType();
- if (auto quantType =
- llvm::dyn_cast<mlir::quant::UniformQuantizedType>(resultEType))
+ if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
resultEType = quantType.getStorageType();
if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
@@ -346,8 +343,7 @@ static LogicalResult verifyConvOpModes(T op) {
auto inputEType =
llvm::cast<ShapedType>(op.getInput().getType()).getElementType();
- if (auto quantType =
- llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType))
+ if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
inputEType = quantType.getStorageType();
auto accType = op.getAccType();
@@ -369,7 +365,23 @@ static LogicalResult verifyConvOpModes(T op) {
if (inputEType.isF32() && !accType.isF32())
return op.emitOpError("accumulator type for f32 tensor is not f32");
- return success();
+ auto resultEType =
+ llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
+
+ if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
+ resultEType = quantType.getStorageType();
+
+ // check allowed input/result element types combinations
+ if ((inputEType.isInteger(8) && resultEType.isInteger(32)) ||
+ (inputEType.isInteger(16) && resultEType.isInteger(48)) ||
+ (isa<Float8E5M2Type>(inputEType) && resultEType.isF16()) ||
+ (isa<Float8E4M3FNType>(inputEType) && resultEType.isF16()) ||
+ (inputEType.isF16() && resultEType.isF16()) ||
+ (inputEType.isBF16() && resultEType.isBF16()) ||
+ (inputEType.isF32() && resultEType.isF32()))
+ return success();
+
+ return op.emitOpError("input/output element types are incompatible.");
}
// verify that inType and outType have same element types
@@ -519,7 +531,8 @@ static void buildTransConvOpWithQuantInfo(
OpBuilder &builder, OperationState &result, Type outputType, Value input,
Value weight, Value bias, DenseI64ArrayAttr outpad,
DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape, TypeAttr accType) {
- result.addOperands({input, weight, bias});
+ auto zps = createZPsAsConst(builder, input, weight);
+ result.addOperands({input, weight, bias, zps.first, zps.second});
result.addAttribute("out_pad", outpad);
result.addAttribute("stride", stride);
result.addAttribute("out_shape", outputShape);
@@ -2478,18 +2491,15 @@ LogicalResult mlir::tosa::getZeroPoint(ElementsAttr zpAttr, int64_t &zp) {
return failure();
}
-// Create a rank-0 const tensor for zero point of the source tensor.
+// Create a rank-1 const tensor for zero point of the source tensor.
std::optional<Value> mlir::tosa::createZeroPointTensor(OpBuilder &builder,
Location loc,
Type srcElemType,
int64_t zp) {
- if (auto quantType =
- llvm::dyn_cast<mlir::quant::UniformQuantizedType>(srcElemType))
- srcElemType = quantType.getStorageType();
-
- auto zpType = mlir::RankedTensorType::get({1}, srcElemType);
+ srcElemType = getElementTypeOrSelf(srcElemType);
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcElemType))
srcElemType = quantType.getStorageType();
+ auto zpType = mlir::RankedTensorType::get({1}, srcElemType);
if (llvm::isa<FloatType>(srcElemType)) {
auto zpAttr = DenseElementsAttr::get(
zpType, builder.getFloatAttr(srcElemType, static_cast<double>(zp)));
|
@llvm/pr-subscribers-mlir Author: Tai Ly (Tai78641) Changes: Replace UniformQuantizedType by the more generic QuantizedType in Conv verifiers. Also fixes buildTransConvOpWithQuantInfo to insert the input/weight zero-point (zp) operands. Change-Id: Ie1961af931864f801914a62976bc988881ee075e Full diff: https://github.com/llvm/llvm-project/pull/126275.diff 1 file affected:
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 031c279ff09e275..6143a9d23a00394 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -240,16 +240,13 @@ static LogicalResult verifyConvOp(T op) {
bool biasIsFloat = llvm::isa<FloatType>(biasEType);
bool resultIsFloat = llvm::isa<FloatType>(resultEType);
- if (auto quantType =
- llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType))
+ if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
inputEType = quantType.getStorageType();
- if (auto quantType =
- llvm::dyn_cast<mlir::quant::UniformQuantizedType>(biasEType))
+ if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
biasEType = quantType.getStorageType();
- if (auto quantType =
- llvm::dyn_cast<mlir::quant::UniformQuantizedType>(resultEType))
+ if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
resultEType = quantType.getStorageType();
if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
@@ -346,8 +343,7 @@ static LogicalResult verifyConvOpModes(T op) {
auto inputEType =
llvm::cast<ShapedType>(op.getInput().getType()).getElementType();
- if (auto quantType =
- llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType))
+ if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
inputEType = quantType.getStorageType();
auto accType = op.getAccType();
@@ -369,7 +365,23 @@ static LogicalResult verifyConvOpModes(T op) {
if (inputEType.isF32() && !accType.isF32())
return op.emitOpError("accumulator type for f32 tensor is not f32");
- return success();
+ auto resultEType =
+ llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
+
+ if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
+ resultEType = quantType.getStorageType();
+
+ // check allowed input/result element types combinations
+ if ((inputEType.isInteger(8) && resultEType.isInteger(32)) ||
+ (inputEType.isInteger(16) && resultEType.isInteger(48)) ||
+ (isa<Float8E5M2Type>(inputEType) && resultEType.isF16()) ||
+ (isa<Float8E4M3FNType>(inputEType) && resultEType.isF16()) ||
+ (inputEType.isF16() && resultEType.isF16()) ||
+ (inputEType.isBF16() && resultEType.isBF16()) ||
+ (inputEType.isF32() && resultEType.isF32()))
+ return success();
+
+ return op.emitOpError("input/output element types are incompatible.");
}
// verify that inType and outType have same element types
@@ -519,7 +531,8 @@ static void buildTransConvOpWithQuantInfo(
OpBuilder &builder, OperationState &result, Type outputType, Value input,
Value weight, Value bias, DenseI64ArrayAttr outpad,
DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape, TypeAttr accType) {
- result.addOperands({input, weight, bias});
+ auto zps = createZPsAsConst(builder, input, weight);
+ result.addOperands({input, weight, bias, zps.first, zps.second});
result.addAttribute("out_pad", outpad);
result.addAttribute("stride", stride);
result.addAttribute("out_shape", outputShape);
@@ -2478,18 +2491,15 @@ LogicalResult mlir::tosa::getZeroPoint(ElementsAttr zpAttr, int64_t &zp) {
return failure();
}
-// Create a rank-0 const tensor for zero point of the source tensor.
+// Create a rank-1 const tensor for zero point of the source tensor.
std::optional<Value> mlir::tosa::createZeroPointTensor(OpBuilder &builder,
Location loc,
Type srcElemType,
int64_t zp) {
- if (auto quantType =
- llvm::dyn_cast<mlir::quant::UniformQuantizedType>(srcElemType))
- srcElemType = quantType.getStorageType();
-
- auto zpType = mlir::RankedTensorType::get({1}, srcElemType);
+ srcElemType = getElementTypeOrSelf(srcElemType);
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcElemType))
srcElemType = quantType.getStorageType();
+ auto zpType = mlir::RankedTensorType::get({1}, srcElemType);
if (llvm::isa<FloatType>(srcElemType)) {
auto zpAttr = DenseElementsAttr::get(
zpType, builder.getFloatAttr(srcElemType, static_cast<double>(zp)));
|
b67944b
to
c62371b
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need to add some tests? Are there any already in place?
Replace UniformQuantizedType by the more generic QuantizedType in Conv verifiers. Change-Id: Ie1961af931864f801914a62976bc988881ee075e Signed-off-by: Tai Ly <[email protected]>
c62371b
to
dd71961
Compare
added tests |
Replace UniformQuantizedType by the more generic QuantizedType in Conv verifiers. Change-Id: Ie1961af931864f801914a62976bc988881ee075e Signed-off-by: Tai Ly <[email protected]> Co-authored-by: Thibaut Goetghebuer-Planchon <[email protected]>
Replace UniformQuantizedType by the more generic QuantizedType in Conv verifiers. Change-Id: Ie1961af931864f801914a62976bc988881ee075e Signed-off-by: Tai Ly <[email protected]> Co-authored-by: Thibaut Goetghebuer-Planchon <[email protected]>
Replace UniformQuantizedType by the more generic QuantizedType in Conv verifiers. Change-Id: Ie1961af931864f801914a62976bc988881ee075e Signed-off-by: Tai Ly <[email protected]> Co-authored-by: Thibaut Goetghebuer-Planchon <[email protected]>
Replace UniformQuantizedType by the more generic QuantizedType in Conv verifiers
Change-Id: Ie1961af931864f801914a62976bc988881ee075e