@@ -110,16 +110,34 @@ DenseElementsAttr applyElementWise(
110
110
// We already know the amount of values we will insert, reserve space for
111
111
// all of them to avoid dynamic resizing
112
112
transformedValues.reserve (toTransform.getNumElements ());
113
- for (auto val : toTransform.getValues <SrcValType>()) {
114
- auto transformedVal = toApply (val, targetType);
115
- transformedValues.push_back (transformedVal);
113
+ if constexpr (std::is_same_v<SrcValType, APSInt>) {
114
+ for (auto val : toTransform.getValues <APInt>()) {
115
+ auto transformedVal =
116
+ toApply (APSInt (val, toTransform.getElementType ().isUnsignedInteger ()),
117
+ targetType);
118
+ transformedValues.push_back (transformedVal);
119
+ }
120
+ } else {
121
+ for (auto val : toTransform.getValues <SrcValType>()) {
122
+ auto transformedVal = toApply (val, targetType);
123
+ transformedValues.push_back (transformedVal);
124
+ }
116
125
}
117
126
118
127
// Make sure that the output tensor has the expected output type
119
128
auto inShape = toTransform.getType ();
120
129
auto outTy = inShape.cloneWith ({}, targetType);
121
130
122
- return DenseElementsAttr::get (outTy, transformedValues);
131
+ if constexpr (std::is_same_v<TargetValType, APSInt>) {
132
+ SmallVector<APInt> transformedValuesAPInt;
133
+ transformedValuesAPInt.reserve (transformedValues.size ());
134
+ for (APSInt val : transformedValues) {
135
+ transformedValuesAPInt.emplace_back (val);
136
+ }
137
+ return DenseElementsAttr::get (outTy, transformedValuesAPInt);
138
+ } else {
139
+ return DenseElementsAttr::get (outTy, transformedValues);
140
+ }
123
141
}
124
142
125
143
template DenseElementsAttr applyElementWise<APFloat, APFloat, FloatType>(
@@ -881,10 +899,10 @@ struct TosaFoldConstantCast : public TosaFoldConstantBase<CastOp> {
881
899
882
900
using TosaFoldConstantBase::TosaFoldConstantBase;
883
901
884
- static APFloat convertIntToFloat (const APInt &toConvert,
902
+ static APFloat convertIntToFloat (const APSInt &toConvert,
885
903
FloatType targetType) {
886
904
APFloat res (targetType.getFloatSemantics ());
887
- res.convertFromAPInt (toConvert, true /* isSigned */ , tosaRoundingMode);
905
+ res.convertFromAPInt (toConvert, toConvert. isSigned () , tosaRoundingMode);
888
906
return res;
889
907
}
890
908
@@ -928,15 +946,14 @@ struct TosaFoldConstantCast : public TosaFoldConstantBase<CastOp> {
928
946
return converted;
929
947
}
930
948
931
- static APInt convertIntToInt (const APInt &toConvert, IntegerType targetType) {
949
+ static APSInt convertIntToInt (const APSInt &toConvert,
950
+ IntegerType targetType) {
932
951
// Make sure to properly translate booleans
933
952
if (targetType.getWidth () == 1 ) {
934
- return toConvert.isZero () ? APInt::getZero (1 ) : APInt::getAllOnes (1 );
935
- }
936
- if (targetType.isUnsigned ()) {
937
- return toConvert.zextOrTrunc (targetType.getIntOrFloatBitWidth ());
953
+ return APSInt (toConvert.isZero () ? APInt::getZero (1 )
954
+ : APInt::getAllOnes (1 ));
938
955
}
939
- return toConvert.sextOrTrunc (targetType.getIntOrFloatBitWidth ());
956
+ return toConvert.extOrTrunc (targetType.getIntOrFloatBitWidth ());
940
957
}
941
958
942
959
static void warnAboutNaNToIntCast (DenseElementsAttr elements, CastOp location,
@@ -981,11 +998,11 @@ struct TosaFoldConstantCast : public TosaFoldConstantBase<CastOp> {
981
998
warnAboutNaNToIntCast (elements, tosaCast, rewriter);
982
999
983
1000
// Only fold splat tensors and those used only once to avoid duplicating
984
- // them.
1001
+ // them and increasing memory consumption .
985
1002
if (!inputTensor.hasOneUse () && !isa<SplatElementsAttr>(elements)) {
986
- return rewriter.notifyMatchFailure (tosaCast,
987
- " Currently, casts will only be folded "
988
- " if its input only has a single user" );
1003
+ return rewriter.notifyMatchFailure (
1004
+ tosaCast, " Currently, casts will only be folded "
1005
+ " if its input only has a single user or is a splat value. " );
989
1006
}
990
1007
991
1008
// Report a match failure for unexpected types
@@ -994,28 +1011,25 @@ struct TosaFoldConstantCast : public TosaFoldConstantBase<CastOp> {
994
1011
tosaCast, " Only casts from/to int/float are supported." );
995
1012
}
996
1013
997
- auto isUnsigned = [](Type toCheck) {
998
- return isa<IntegerType>(toCheck) &&
999
- cast<IntegerType>(toCheck).isUnsigned ();
1000
- };
1001
- auto typesToCheck = {toType, fromType};
1002
- if (llvm::any_of (typesToCheck, isUnsigned)) {
1014
+ // TOSA spec does not allow casts from/to unsigned, but we partially do, to
1015
+ // enable the folding of lowered qdq nodes
1016
+ if (isa<FloatType>(fromType) && isa<IntegerType>(toType) &&
1017
+ cast<IntegerType>(toType).isUnsigned ()) {
1003
1018
// TOSA casts currently don't support unsigned integers.
1004
- // To support them by here, one could use APSInt instead of APInts,
1005
- // however, this causes trouble with `getValues` which does not support
1006
- // APSInts currently.
1019
+ // Casting float to unsigned int would need a decision about how to handle
1020
+ // negative floats
1007
1021
return rewriter.notifyMatchFailure (
1008
- tosaCast, " Cast folding from/to unsigned integers is not supported." );
1022
+ tosaCast,
1023
+ " Cast folding from float to unsigned integers is not supported." );
1009
1024
}
1010
-
1011
1025
DenseElementsAttr res;
1012
1026
if (auto intOutTy = dyn_cast<IntegerType>(toType)) {
1013
1027
if (isa<FloatType>(fromType)) {
1014
1028
res = applyElementWise<APFloat, APInt, IntegerType>(
1015
1029
elements, &convertFloatToInt, intOutTy);
1016
1030
} else {
1017
1031
assert (isa<IntegerType>(fromType));
1018
- res = applyElementWise<APInt, APInt , IntegerType>(
1032
+ res = applyElementWise<APSInt, APSInt , IntegerType>(
1019
1033
elements, &convertIntToInt, intOutTy);
1020
1034
}
1021
1035
} else {
@@ -1026,7 +1040,7 @@ struct TosaFoldConstantCast : public TosaFoldConstantBase<CastOp> {
1026
1040
elements, &convertFloatToFloat, floatOutTy);
1027
1041
} else {
1028
1042
assert (isa<IntegerType>(fromType));
1029
- res = applyElementWise<APInt , APFloat, FloatType>(
1043
+ res = applyElementWise<APSInt , APFloat, FloatType>(
1030
1044
elements, &convertIntToFloat, floatOutTy);
1031
1045
}
1032
1046
}
0 commit comments