Skip to content

Commit 1656bbb

Browse files
authored
Merge pull request #483 from Xilinx/jrickert.cast_folding
Add tosa.cast folding for unsigned integers
2 parents 350e0bd + 212603a commit 1656bbb

File tree

3 files changed

+144
-29
lines changed

3 files changed

+144
-29
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1877,6 +1877,15 @@ def Tosa_CastOp: Tosa_Op<"cast", [Pure,
18771877
| signed 16 to float | int16 | float |
18781878
| float 32 to float 64 | float32 | float64 |
18791879
| float 64 to float 32 | float64 | float32 |
1880+
1881+
AMD extensions:
1882+
| signed to unsigned | signed | unsigned|
1883+
| unsigned to signed | unsigned| signed |
1884+
| unsigned to float | unsigned| float |
1885+
- unsigned to signed integer and signed to unsigned integer:
1886+
wrap on overflow
1887+
- unsigned to float:
1888+
uses llvm's float to int conversion with TOSA rounding mode
18801889
}];
18811890

18821891
let arguments = (ins

mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp

Lines changed: 43 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -110,16 +110,34 @@ DenseElementsAttr applyElementWise(
110110
// We already know the amount of values we will insert, reserve space for
111111
// all of them to avoid dynamic resizing
112112
transformedValues.reserve(toTransform.getNumElements());
113-
for (auto val : toTransform.getValues<SrcValType>()) {
114-
auto transformedVal = toApply(val, targetType);
115-
transformedValues.push_back(transformedVal);
113+
if constexpr (std::is_same_v<SrcValType, APSInt>) {
114+
for (auto val : toTransform.getValues<APInt>()) {
115+
auto transformedVal =
116+
toApply(APSInt(val, toTransform.getElementType().isUnsignedInteger()),
117+
targetType);
118+
transformedValues.push_back(transformedVal);
119+
}
120+
} else {
121+
for (auto val : toTransform.getValues<SrcValType>()) {
122+
auto transformedVal = toApply(val, targetType);
123+
transformedValues.push_back(transformedVal);
124+
}
116125
}
117126

118127
// Make sure that the output tensor has the expected output type
119128
auto inShape = toTransform.getType();
120129
auto outTy = inShape.cloneWith({}, targetType);
121130

122-
return DenseElementsAttr::get(outTy, transformedValues);
131+
if constexpr (std::is_same_v<TargetValType, APSInt>) {
132+
SmallVector<APInt> transformedValuesAPInt;
133+
transformedValuesAPInt.reserve(transformedValues.size());
134+
for (APSInt val : transformedValues) {
135+
transformedValuesAPInt.emplace_back(val);
136+
}
137+
return DenseElementsAttr::get(outTy, transformedValuesAPInt);
138+
} else {
139+
return DenseElementsAttr::get(outTy, transformedValues);
140+
}
123141
}
124142

125143
template DenseElementsAttr applyElementWise<APFloat, APFloat, FloatType>(
@@ -881,10 +899,10 @@ struct TosaFoldConstantCast : public TosaFoldConstantBase<CastOp> {
881899

882900
using TosaFoldConstantBase::TosaFoldConstantBase;
883901

884-
static APFloat convertIntToFloat(const APInt &toConvert,
902+
static APFloat convertIntToFloat(const APSInt &toConvert,
885903
FloatType targetType) {
886904
APFloat res(targetType.getFloatSemantics());
887-
res.convertFromAPInt(toConvert, true /* isSigned */, tosaRoundingMode);
905+
res.convertFromAPInt(toConvert, toConvert.isSigned(), tosaRoundingMode);
888906
return res;
889907
}
890908

@@ -928,15 +946,14 @@ struct TosaFoldConstantCast : public TosaFoldConstantBase<CastOp> {
928946
return converted;
929947
}
930948

931-
static APInt convertIntToInt(const APInt &toConvert, IntegerType targetType) {
949+
static APSInt convertIntToInt(const APSInt &toConvert,
950+
IntegerType targetType) {
932951
// Make sure to properly translate booleans
933952
if (targetType.getWidth() == 1) {
934-
return toConvert.isZero() ? APInt::getZero(1) : APInt::getAllOnes(1);
935-
}
936-
if (targetType.isUnsigned()) {
937-
return toConvert.zextOrTrunc(targetType.getIntOrFloatBitWidth());
953+
return APSInt(toConvert.isZero() ? APInt::getZero(1)
954+
: APInt::getAllOnes(1));
938955
}
939-
return toConvert.sextOrTrunc(targetType.getIntOrFloatBitWidth());
956+
return toConvert.extOrTrunc(targetType.getIntOrFloatBitWidth());
940957
}
941958

942959
static void warnAboutNaNToIntCast(DenseElementsAttr elements, CastOp location,
@@ -981,11 +998,11 @@ struct TosaFoldConstantCast : public TosaFoldConstantBase<CastOp> {
981998
warnAboutNaNToIntCast(elements, tosaCast, rewriter);
982999

9831000
// Only fold splat tensors and those used only once to avoid duplicating
984-
// them.
1001+
// them and increasing memory consumption.
9851002
if (!inputTensor.hasOneUse() && !isa<SplatElementsAttr>(elements)) {
986-
return rewriter.notifyMatchFailure(tosaCast,
987-
"Currently, casts will only be folded "
988-
"if its input only has a single user");
1003+
return rewriter.notifyMatchFailure(
1004+
tosaCast, "Currently, casts will only be folded "
1005+
"if its input only has a single user or is a splat value.");
9891006
}
9901007

9911008
// Report a match failure for unexpected types
@@ -994,28 +1011,25 @@ struct TosaFoldConstantCast : public TosaFoldConstantBase<CastOp> {
9941011
tosaCast, "Only casts from/to int/float are supported.");
9951012
}
9961013

997-
auto isUnsigned = [](Type toCheck) {
998-
return isa<IntegerType>(toCheck) &&
999-
cast<IntegerType>(toCheck).isUnsigned();
1000-
};
1001-
auto typesToCheck = {toType, fromType};
1002-
if (llvm::any_of(typesToCheck, isUnsigned)) {
1014+
// TOSA spec does not allow casts from/to unsigned, but we partially do, to
1015+
// enable the folding of lowered qdq nodes
1016+
if (isa<FloatType>(fromType) && isa<IntegerType>(toType) &&
1017+
cast<IntegerType>(toType).isUnsigned()) {
10031018
// TOSA casts currently don't support unsigned integers.
1004-
// To support them by here, one could use APSInt instead of APInts,
1005-
// however, this causes trouble with `getValues` which does not support
1006-
// APSInts currently.
1019+
// Casting float to unsigned int would need a decision about how to handle
1020+
// negative floats
10071021
return rewriter.notifyMatchFailure(
1008-
tosaCast, "Cast folding from/to unsigned integers is not supported.");
1022+
tosaCast,
1023+
"Cast folding from float to unsigned integers is not supported.");
10091024
}
1010-
10111025
DenseElementsAttr res;
10121026
if (auto intOutTy = dyn_cast<IntegerType>(toType)) {
10131027
if (isa<FloatType>(fromType)) {
10141028
res = applyElementWise<APFloat, APInt, IntegerType>(
10151029
elements, &convertFloatToInt, intOutTy);
10161030
} else {
10171031
assert(isa<IntegerType>(fromType));
1018-
res = applyElementWise<APInt, APInt, IntegerType>(
1032+
res = applyElementWise<APSInt, APSInt, IntegerType>(
10191033
elements, &convertIntToInt, intOutTy);
10201034
}
10211035
} else {
@@ -1026,7 +1040,7 @@ struct TosaFoldConstantCast : public TosaFoldConstantBase<CastOp> {
10261040
elements, &convertFloatToFloat, floatOutTy);
10271041
} else {
10281042
assert(isa<IntegerType>(fromType));
1029-
res = applyElementWise<APInt, APFloat, FloatType>(
1043+
res = applyElementWise<APSInt, APFloat, FloatType>(
10301044
elements, &convertIntToFloat, floatOutTy);
10311045
}
10321046
}

mlir/test/Dialect/Tosa/constant-cast-opt.mlir

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,20 @@ func.func @cast_fold_f32_to_i8() -> tensor<5xi8> {
7171
return %1 : tensor<5xi8>
7272
}
7373

74+
// CHECK-LABEL: @cast_fold_f32_to_ui8
75+
// COM: Do not fold casts from floats to uint
76+
func.func @cast_fold_f32_to_ui8() -> tensor<5xui8> {
77+
// CHECK: tosa.const
78+
// CHECK-NOT: tensor<5xui8>
79+
// CHECK: tosa.cast
80+
%0 = "tosa.const"() {value =
81+
dense<[12.0, 0.0, 5.0, 32770.11, -32770.11]> :
82+
tensor<5xf32>
83+
} : () -> tensor<5xf32>
84+
%1 = "tosa.cast"(%0) : (tensor<5xf32>) -> tensor<5xui8>
85+
return %1 : tensor<5xui8>
86+
}
87+
7488
// CHECK-LABEL: @cast_fold_float_to_int_infinity_zero_nan
7589
func.func @cast_fold_float_to_int_infinity_zero_nan() -> tensor<5xi16> {
7690
// Check if infinity and zero are translated properly. Don't expect any
@@ -116,6 +130,71 @@ func.func @cast_fold_i32_to_i8() -> tensor<5xi8> {
116130
return %1 : tensor<5xi8>
117131
}
118132

133+
// CHECK-LABEL: @cast_fold_i8_to_ui8
134+
func.func @cast_fold_i8_to_ui8() -> tensor<3xui8> {
135+
// CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}4, 0, 251{{.*}}tensor<3xui8>
136+
// CHECK-NOT: tosa.cast
137+
// CHECK: return [[RES]]
138+
%0 = "tosa.const"() {value =
139+
dense<[4, 0, -5]> :
140+
tensor<3xi8>
141+
} : () -> tensor<3xi8>
142+
%1 = "tosa.cast"(%0) : (tensor<3xi8>) -> tensor<3xui8>
143+
return %1 : tensor<3xui8>
144+
}
145+
146+
// CHECK-LABEL: @cast_fold_ui8_to_i8
147+
func.func @cast_fold_ui8_to_i8() -> tensor<3xi8> {
148+
// CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}4, 0, -6{{.*}}tensor<3xi8>
149+
// CHECK-NOT: tosa.cast
150+
// CHECK: return [[RES]]
151+
%0 = "tosa.const"() {value =
152+
dense<[4, 0, 250]> :
153+
tensor<3xui8>
154+
} : () -> tensor<3xui8>
155+
%1 = "tosa.cast"(%0) : (tensor<3xui8>) -> tensor<3xi8>
156+
return %1 : tensor<3xi8>
157+
}
158+
159+
// CHECK-LABEL: @cast_fold_ui8_to_i16
160+
func.func @cast_fold_ui8_to_i16() -> tensor<3xi16> {
161+
// CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}4, 0, 250{{.*}}tensor<3xi16>
162+
// CHECK-NOT: tosa.cast
163+
// CHECK: return [[RES]]
164+
%0 = "tosa.const"() {value =
165+
dense<[4, 0, 250]> :
166+
tensor<3xui8>
167+
} : () -> tensor<3xui8>
168+
%1 = "tosa.cast"(%0) : (tensor<3xui8>) -> tensor<3xi16>
169+
return %1 : tensor<3xi16>
170+
}
171+
172+
// CHECK-LABEL: @cast_fold_ui8_to_i1
173+
func.func @cast_fold_ui8_to_i1() -> tensor<3xi1> {
174+
// CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}true, false, true{{.*}}tensor<3xi1>
175+
// CHECK-NOT: tosa.cast
176+
// CHECK: return [[RES]]
177+
%0 = "tosa.const"() {value =
178+
dense<[4, 0, 250]> :
179+
tensor<3xui8>
180+
} : () -> tensor<3xui8>
181+
%1 = "tosa.cast"(%0) : (tensor<3xui8>) -> tensor<3xi1>
182+
return %1 : tensor<3xi1>
183+
}
184+
185+
// CHECK-LABEL: @cast_fold_ui8_to_ui1
186+
func.func @cast_fold_ui8_to_ui1() -> tensor<3xui1> {
187+
// CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}true, false, true{{.*}}tensor<3xui1>
188+
// CHECK-NOT: tosa.cast
189+
// CHECK: return [[RES]]
190+
%0 = "tosa.const"() {value =
191+
dense<[4, 0, 250]> :
192+
tensor<3xui8>
193+
} : () -> tensor<3xui8>
194+
%1 = "tosa.cast"(%0) : (tensor<3xui8>) -> tensor<3xui1>
195+
return %1 : tensor<3xui1>
196+
}
197+
119198

120199
// CHECK-LABEL: @cast_fold_i16_to_i1
121200
func.func @cast_fold_i16_to_i1() -> tensor<3xi1> {
@@ -172,6 +251,19 @@ func.func @cast_fold_i32_to_f16() -> tensor<4xf16> {
172251
return %1 : tensor<4xf16>
173252
}
174253

254+
// CHECK-LABEL: @cast_fold_ui8_to_f32
255+
func.func @cast_fold_ui8_to_f32() -> tensor<4xf32> {
256+
// CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}0.000000e+00, 1.000000e+00, 4.000000e+00, 2.550000e+02{{.*}}tensor<4xf32>
257+
// CHECK-NOT: tosa.cast
258+
// CHECK: return [[RES]]
259+
%0 = "tosa.const"() {value =
260+
dense<[0, 1, 4, 255]> :
261+
tensor<4xui8>
262+
} : () -> tensor<4xui8>
263+
%1 = "tosa.cast"(%0) : (tensor<4xui8>) -> tensor<4xf32>
264+
return %1 : tensor<4xf32>
265+
}
266+
175267
// -----
176268
// Casts from float to float
177269

0 commit comments

Comments
 (0)