Skip to content

Add tosa.cast folding for unsigned integers #483

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1877,6 +1877,15 @@ def Tosa_CastOp: Tosa_Op<"cast", [Pure,
| signed 16 to float | int16 | float |
| float 32 to float 64 | float32 | float64 |
| float 64 to float 32 | float64 | float32 |

AMD extensions:
| signed to unsigned | signed | unsigned|
| unsigned to signed | unsigned| signed |
| unsigned to float | unsigned| float |
- unsigned to signed integer and signed to unsigned integer:
wrap on overflow
- unsigned to float:
uses llvm's float to int conversion with TOSA rounding mode
}];

let arguments = (ins
Expand Down
72 changes: 43 additions & 29 deletions mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,16 +110,34 @@ DenseElementsAttr applyElementWise(
// We already know the amount of values we will insert, reserve space for
// all of them to avoid dynamic resizing
transformedValues.reserve(toTransform.getNumElements());
for (auto val : toTransform.getValues<SrcValType>()) {
auto transformedVal = toApply(val, targetType);
transformedValues.push_back(transformedVal);
if constexpr (std::is_same_v<SrcValType, APSInt>) {
for (auto val : toTransform.getValues<APInt>()) {
auto transformedVal =
toApply(APSInt(val, toTransform.getElementType().isUnsignedInteger()),
targetType);
transformedValues.push_back(transformedVal);
}
} else {
for (auto val : toTransform.getValues<SrcValType>()) {
auto transformedVal = toApply(val, targetType);
transformedValues.push_back(transformedVal);
}
}

// Make sure that the output tensor has the expected output type
auto inShape = toTransform.getType();
auto outTy = inShape.cloneWith({}, targetType);

return DenseElementsAttr::get(outTy, transformedValues);
if constexpr (std::is_same_v<TargetValType, APSInt>) {
SmallVector<APInt> transformedValuesAPInt;
transformedValuesAPInt.reserve(transformedValues.size());
for (APSInt val : transformedValues) {
transformedValuesAPInt.emplace_back(val);
}
return DenseElementsAttr::get(outTy, transformedValuesAPInt);
} else {
return DenseElementsAttr::get(outTy, transformedValues);
}
}

template DenseElementsAttr applyElementWise<APFloat, APFloat, FloatType>(
Expand Down Expand Up @@ -881,10 +899,10 @@ struct TosaFoldConstantCast : public TosaFoldConstantBase<CastOp> {

using TosaFoldConstantBase::TosaFoldConstantBase;

static APFloat convertIntToFloat(const APInt &toConvert,
static APFloat convertIntToFloat(const APSInt &toConvert,
FloatType targetType) {
APFloat res(targetType.getFloatSemantics());
res.convertFromAPInt(toConvert, true /* isSigned */, tosaRoundingMode);
res.convertFromAPInt(toConvert, toConvert.isSigned(), tosaRoundingMode);
return res;
}

Expand Down Expand Up @@ -928,15 +946,14 @@ struct TosaFoldConstantCast : public TosaFoldConstantBase<CastOp> {
return converted;
}

static APInt convertIntToInt(const APInt &toConvert, IntegerType targetType) {
static APSInt convertIntToInt(const APSInt &toConvert,
IntegerType targetType) {
// Make sure to properly translate booleans
if (targetType.getWidth() == 1) {
return toConvert.isZero() ? APInt::getZero(1) : APInt::getAllOnes(1);
}
if (targetType.isUnsigned()) {
return toConvert.zextOrTrunc(targetType.getIntOrFloatBitWidth());
return APSInt(toConvert.isZero() ? APInt::getZero(1)
: APInt::getAllOnes(1));
}
return toConvert.sextOrTrunc(targetType.getIntOrFloatBitWidth());
return toConvert.extOrTrunc(targetType.getIntOrFloatBitWidth());
}

static void warnAboutNaNToIntCast(DenseElementsAttr elements, CastOp location,
Expand Down Expand Up @@ -981,11 +998,11 @@ struct TosaFoldConstantCast : public TosaFoldConstantBase<CastOp> {
warnAboutNaNToIntCast(elements, tosaCast, rewriter);

// Only fold splat tensors and those used only once to avoid duplicating
// them.
// them and increasing memory consumption.
if (!inputTensor.hasOneUse() && !isa<SplatElementsAttr>(elements)) {
return rewriter.notifyMatchFailure(tosaCast,
"Currently, casts will only be folded "
"if its input only has a single user");
return rewriter.notifyMatchFailure(
tosaCast, "Currently, casts will only be folded "
"if its input only has a single user or is a splat value.");
}

// Report a match failure for unexpected types
Expand All @@ -994,28 +1011,25 @@ struct TosaFoldConstantCast : public TosaFoldConstantBase<CastOp> {
tosaCast, "Only casts from/to int/float are supported.");
}

auto isUnsigned = [](Type toCheck) {
return isa<IntegerType>(toCheck) &&
cast<IntegerType>(toCheck).isUnsigned();
};
auto typesToCheck = {toType, fromType};
if (llvm::any_of(typesToCheck, isUnsigned)) {
// TOSA spec does not allow casts from/to unsigned, but we partially do, to
// enable the folding of lowered qdq nodes
if (isa<FloatType>(fromType) && isa<IntegerType>(toType) &&
cast<IntegerType>(toType).isUnsigned()) {
// TOSA casts currently don't support unsigned integers.
// To support them by here, one could use APSInt instead of APInts,
// however, this causes trouble with `getValues` which does not support
// APSInts currently.
// Casting float to unsigned int would need a decision about how to handle
// negative floats
return rewriter.notifyMatchFailure(
tosaCast, "Cast folding from/to unsigned integers is not supported.");
tosaCast,
"Cast folding from float to unsigned integers is not supported.");
}

DenseElementsAttr res;
if (auto intOutTy = dyn_cast<IntegerType>(toType)) {
if (isa<FloatType>(fromType)) {
res = applyElementWise<APFloat, APInt, IntegerType>(
elements, &convertFloatToInt, intOutTy);
} else {
assert(isa<IntegerType>(fromType));
res = applyElementWise<APInt, APInt, IntegerType>(
res = applyElementWise<APSInt, APSInt, IntegerType>(
elements, &convertIntToInt, intOutTy);
}
} else {
Expand All @@ -1026,7 +1040,7 @@ struct TosaFoldConstantCast : public TosaFoldConstantBase<CastOp> {
elements, &convertFloatToFloat, floatOutTy);
} else {
assert(isa<IntegerType>(fromType));
res = applyElementWise<APInt, APFloat, FloatType>(
res = applyElementWise<APSInt, APFloat, FloatType>(
elements, &convertIntToFloat, floatOutTy);
}
}
Expand Down
92 changes: 92 additions & 0 deletions mlir/test/Dialect/Tosa/constant-cast-opt.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,20 @@ func.func @cast_fold_f32_to_i8() -> tensor<5xi8> {
return %1 : tensor<5xi8>
}

// CHECK-LABEL: @cast_fold_f32_to_ui8
// COM: Do not fold casts from floats to uint
func.func @cast_fold_f32_to_ui8() -> tensor<5xui8> {
// CHECK: tosa.const
// CHECK-NOT: tensor<5xui8>
// CHECK: tosa.cast
%0 = "tosa.const"() {value =
dense<[12.0, 0.0, 5.0, 32770.11, -32770.11]> :
tensor<5xf32>
} : () -> tensor<5xf32>
%1 = "tosa.cast"(%0) : (tensor<5xf32>) -> tensor<5xui8>
return %1 : tensor<5xui8>
}

// CHECK-LABEL: @cast_fold_float_to_int_infinity_zero_nan
func.func @cast_fold_float_to_int_infinity_zero_nan() -> tensor<5xi16> {
// Check if infinity and zero are translated properly. Don't expect any
Expand Down Expand Up @@ -116,6 +130,71 @@ func.func @cast_fold_i32_to_i8() -> tensor<5xi8> {
return %1 : tensor<5xi8>
}

// CHECK-LABEL: @cast_fold_i8_to_ui8
func.func @cast_fold_i8_to_ui8() -> tensor<3xui8> {
// CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}4, 0, 251{{.*}}tensor<3xui8>
// CHECK-NOT: tosa.cast
// CHECK: return [[RES]]
%0 = "tosa.const"() {value =
dense<[4, 0, -5]> :
tensor<3xi8>
} : () -> tensor<3xi8>
%1 = "tosa.cast"(%0) : (tensor<3xi8>) -> tensor<3xui8>
return %1 : tensor<3xui8>
}

// CHECK-LABEL: @cast_fold_ui8_to_i8
func.func @cast_fold_ui8_to_i8() -> tensor<3xi8> {
// CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}4, 0, -6{{.*}}tensor<3xi8>
// CHECK-NOT: tosa.cast
// CHECK: return [[RES]]
%0 = "tosa.const"() {value =
dense<[4, 0, 250]> :
tensor<3xui8>
} : () -> tensor<3xui8>
%1 = "tosa.cast"(%0) : (tensor<3xui8>) -> tensor<3xi8>
return %1 : tensor<3xi8>
}

// CHECK-LABEL: @cast_fold_ui8_to_i16
func.func @cast_fold_ui8_to_i16() -> tensor<3xi16> {
// CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}4, 0, 250{{.*}}tensor<3xi16>
// CHECK-NOT: tosa.cast
// CHECK: return [[RES]]
%0 = "tosa.const"() {value =
dense<[4, 0, 250]> :
tensor<3xui8>
} : () -> tensor<3xui8>
%1 = "tosa.cast"(%0) : (tensor<3xui8>) -> tensor<3xi16>
return %1 : tensor<3xi16>
}

// CHECK-LABEL: @cast_fold_ui8_to_i1
func.func @cast_fold_ui8_to_i1() -> tensor<3xi1> {
// CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}true, false, true{{.*}}tensor<3xi1>
// CHECK-NOT: tosa.cast
// CHECK: return [[RES]]
%0 = "tosa.const"() {value =
dense<[4, 0, 250]> :
tensor<3xui8>
} : () -> tensor<3xui8>
%1 = "tosa.cast"(%0) : (tensor<3xui8>) -> tensor<3xi1>
return %1 : tensor<3xi1>
}

// CHECK-LABEL: @cast_fold_ui8_to_ui1
func.func @cast_fold_ui8_to_ui1() -> tensor<3xui1> {
// CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}true, false, true{{.*}}tensor<3xui1>
// CHECK-NOT: tosa.cast
// CHECK: return [[RES]]
%0 = "tosa.const"() {value =
dense<[4, 0, 250]> :
tensor<3xui8>
} : () -> tensor<3xui8>
%1 = "tosa.cast"(%0) : (tensor<3xui8>) -> tensor<3xui1>
return %1 : tensor<3xui1>
}


// CHECK-LABEL: @cast_fold_i16_to_i1
func.func @cast_fold_i16_to_i1() -> tensor<3xi1> {
Expand Down Expand Up @@ -172,6 +251,19 @@ func.func @cast_fold_i32_to_f16() -> tensor<4xf16> {
return %1 : tensor<4xf16>
}

// CHECK-LABEL: @cast_fold_ui8_to_f32
func.func @cast_fold_ui8_to_f32() -> tensor<4xf32> {
// CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}0.000000e+00, 1.000000e+00, 4.000000e+00, 2.550000e+02{{.*}}tensor<4xf32>
// CHECK-NOT: tosa.cast
// CHECK: return [[RES]]
%0 = "tosa.const"() {value =
dense<[0, 1, 4, 255]> :
tensor<4xui8>
} : () -> tensor<4xui8>
%1 = "tosa.cast"(%0) : (tensor<4xui8>) -> tensor<4xf32>
return %1 : tensor<4xf32>
}

// -----
// Casts from float to float

Expand Down