Skip to content

Commit f2c715b

Browse files
committed
Merge branch 'feature/fused-ops' into bump_to_3494ee95
2 parents 6fc95f2 + e8be3be commit f2c715b

File tree

2 files changed

+103
-7
lines changed

2 files changed

+103
-7
lines changed

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 47 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -155,18 +155,43 @@ struct SelectToClampOptimization : public OpRewritePattern<tosa::SelectOp> {
155155
return rewriter.notifyMatchFailure(
156156
op, "RHS of predicate GreaterEqualOp is not a constant");
157157
}
158+
158159
auto isCompatibleSplat = [](DenseElementsAttr a,
159160
DenseElementsAttr b) -> bool {
160161
if (!a.isSplat() || !b.isSplat()) {
161162
return false;
162163
}
163-
if (llvm::isa<IntegerType>(a.getElementType())) {
164-
return a.getSplatValue<APInt>() == b.getSplatValue<APInt>();
164+
165+
auto aAsIntegerType = dyn_cast<IntegerType>(a.getElementType());
166+
auto bAsIntegerType = dyn_cast<IntegerType>(b.getElementType());
167+
if (aAsIntegerType && bAsIntegerType) {
168+
if (aAsIntegerType.getSignedness() != bAsIntegerType.getSignedness()) {
169+
return false;
170+
}
171+
172+
auto aAsAPInt = a.getSplatValue<APInt>();
173+
auto bAsAPInt = b.getSplatValue<APInt>();
174+
175+
const size_t aBitWidth = aAsAPInt.getBitWidth();
176+
const size_t bBitWidth = bAsAPInt.getBitWidth();
177+
178+
if (aBitWidth >= bBitWidth) {
179+
return aAsAPInt == (bAsIntegerType.isUnsigned()
180+
? bAsAPInt.zext(aBitWidth)
181+
: bAsAPInt.sext(aBitWidth));
182+
}
183+
return (aAsIntegerType.isUnsigned()
184+
? aAsAPInt.zext(bBitWidth)
185+
: aAsAPInt.sext(bBitWidth)) == bAsAPInt;
165186
}
166-
if (llvm::isa<FloatType>(a.getElementType())) {
167-
return a.getSplatValue<APFloat>() == b.getSplatValue<APFloat>();
187+
188+
auto aAsFloatType = dyn_cast<FloatType>(a.getElementType());
189+
auto bAsFloatType = dyn_cast<FloatType>(b.getElementType());
190+
if (!aAsFloatType || aAsFloatType != bAsFloatType) {
191+
return false;
168192
}
169-
return false; // Only int and float types are supported
193+
194+
return a.getSplatValue<APFloat>() == b.getSplatValue<APFloat>();
170195
};
171196

172197
auto onFalse = op.getOnFalse();
@@ -237,10 +262,25 @@ struct SelectToClampOptimization : public OpRewritePattern<tosa::SelectOp> {
237262
clampFloatMax = rewriter.getFloatAttr(inputElementType, splatValue);
238263
}
239264
}
265+
266+
Value input = geq.getInput1();
267+
268+
// In case they do not have same bit width, insert a cast to still be able
269+
// to do this canonicalization
270+
const size_t geqBitWidth =
271+
geq.getInput1().getType().getElementTypeBitWidth();
272+
const size_t selectBitWidth = op.getType().getElementTypeBitWidth();
273+
if (geqBitWidth != selectBitWidth) {
274+
input = rewriter.create<tosa::CastOp>(
275+
op->getLoc(),
276+
geq.getInput1().getType().clone(op.getType().getElementType()),
277+
input);
278+
}
279+
240280
rewriter.replaceOpWithNewOp<tosa::ClampOp>(
241-
op, op.getType(), geq.getInput1(),
242-
rewriter.getI64IntegerAttr(clampIntMin),
281+
op, op.getType(), input, rewriter.getI64IntegerAttr(clampIntMin),
243282
rewriter.getI64IntegerAttr(clampIntMax), clampFloatMin, clampFloatMax);
283+
244284
return success();
245285
}
246286
};

mlir/test/Dialect/Tosa/canonicalize.mlir

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1147,3 +1147,59 @@ func.func @canonicalize_select_lrelu_zero_pattern(%arg0: tensor<13x21x3xf32>) ->
11471147
return %3 : tensor<13x21x3xf32>
11481148
}
11491149

1150+
// -----
1151+
1152+
// CHECK-LABEL: @canonicalize_select_to_clamp_i64_and_i8_pat1
1153+
func.func @canonicalize_select_to_clamp_i64_and_i8_pat1(%arg0: tensor<13x21x3xi64>, %arg1: tensor<13x21x3xi8>) -> tensor<13x21x3xi8> {
1154+
// CHECK: %[[VAL_1:.*]] = tosa.cast %arg{{.*}} : (tensor<13x21x3xi64>) -> tensor<13x21x3xi8>
1155+
// CHECK: %[[VAL_2:.*]] = tosa.clamp %[[VAL_1]] {max_fp = 0x7F800000 : f32, max_int = 9223372036854775807 : i64, min_fp = 0xFF800000 : f32, min_int = 42 : i64} : (tensor<13x21x3xi8>) -> tensor<13x21x3xi8>
1156+
// CHECK: return %[[VAL_2]] : tensor<13x21x3xi8>
1157+
%0 = "tosa.const"() <{value = dense<42> : tensor<13x21x3xi64>}>: () -> tensor<13x21x3xi64>
1158+
%1 = "tosa.const"() <{value = dense<42> : tensor<13x21x3xi8>}>: () -> tensor<13x21x3xi8>
1159+
%2 = tosa.greater_equal %arg0, %0: (tensor<13x21x3xi64>, tensor<13x21x3xi64>) -> tensor<13x21x3xi1>
1160+
%3 = tosa.select %2, %arg1, %1: ( tensor<13x21x3xi1>, tensor<13x21x3xi8>, tensor<13x21x3xi8>) -> tensor<13x21x3xi8>
1161+
return %3 : tensor<13x21x3xi8>
1162+
}
1163+
1164+
// -----
1165+
1166+
// CHECK-LABEL: @canonicalize_select_to_clamp_i64_and_i8_pat2
1167+
func.func @canonicalize_select_to_clamp_i64_and_i8_pat2(%arg0: tensor<13x21x3xi64>, %arg1: tensor<13x21x3xi8>) -> tensor<13x21x3xi8> {
1168+
// CHECK: %[[VAL_1:.*]] = tosa.cast %arg{{.*}} : (tensor<13x21x3xi64>) -> tensor<13x21x3xi8>
1169+
// CHECK: %[[VAL_2:.*]] = tosa.clamp %[[VAL_1]] {max_fp = 0x7F800000 : f32, max_int = -42 : i64, min_fp = 0xFF800000 : f32, min_int = -9223372036854775808 : i64} : (tensor<13x21x3xi8>) -> tensor<13x21x3xi8>
1170+
// CHECK: return %[[VAL_2]] : tensor<13x21x3xi8>
1171+
%0 = "tosa.const"() <{value = dense<-42> : tensor<13x21x3xi64>}>: () -> tensor<13x21x3xi64>
1172+
%1 = "tosa.const"() <{value = dense<-42> : tensor<13x21x3xi8>}>: () -> tensor<13x21x3xi8>
1173+
%2 = tosa.greater_equal %arg0, %0: (tensor<13x21x3xi64>, tensor<13x21x3xi64>) -> tensor<13x21x3xi1>
1174+
%3 = tosa.select %2, %1, %arg1 : ( tensor<13x21x3xi1>, tensor<13x21x3xi8>, tensor<13x21x3xi8>) -> tensor<13x21x3xi8>
1175+
return %3 : tensor<13x21x3xi8>
1176+
}
1177+
1178+
// -----
1179+
1180+
// CHECK-LABEL: @canonicalize_select_to_clamp_i8_and_i64_pat1
1181+
func.func @canonicalize_select_to_clamp_i8_and_i64_pat1(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x21x3xi64>) -> tensor<13x21x3xi64> {
1182+
// CHECK: %[[VAL_1:.*]] = tosa.cast %arg{{.*}} : (tensor<13x21x3xi8>) -> tensor<13x21x3xi64>
1183+
// CHECK: %[[VAL_2:.*]] = tosa.clamp %[[VAL_1]] {max_fp = 0x7F800000 : f32, max_int = 9223372036854775807 : i64, min_fp = 0xFF800000 : f32, min_int = 42 : i64} : (tensor<13x21x3xi64>) -> tensor<13x21x3xi64>
1184+
// CHECK: return %[[VAL_2]] : tensor<13x21x3xi64>
1185+
%0 = "tosa.const"() <{value = dense<42> : tensor<13x21x3xi8>}>: () -> tensor<13x21x3xi8>
1186+
%1 = "tosa.const"() <{value = dense<42> : tensor<13x21x3xi64>}>: () -> tensor<13x21x3xi64>
1187+
%2 = tosa.greater_equal %arg0, %0: (tensor<13x21x3xi8>, tensor<13x21x3xi8>) -> tensor<13x21x3xi1>
1188+
%3 = tosa.select %2, %arg1, %1: ( tensor<13x21x3xi1>, tensor<13x21x3xi64>, tensor<13x21x3xi64>) -> tensor<13x21x3xi64>
1189+
return %3 : tensor<13x21x3xi64>
1190+
}
1191+
1192+
// -----
1193+
1194+
// CHECK-LABEL: @canonicalize_select_to_clamp_i8_and_i64_pat2
1195+
func.func @canonicalize_select_to_clamp_i8_and_i64_pat2(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x21x3xi64>) -> tensor<13x21x3xi64> {
1196+
// CHECK: %[[VAL_1:.*]] = tosa.cast %arg{{.*}} : (tensor<13x21x3xi8>) -> tensor<13x21x3xi64>
1197+
// CHECK: %[[VAL_2:.*]] = tosa.clamp %[[VAL_1]] {max_fp = 0x7F800000 : f32, max_int = -42 : i64, min_fp = 0xFF800000 : f32, min_int = -9223372036854775808 : i64} : (tensor<13x21x3xi64>) -> tensor<13x21x3xi64>
1198+
// CHECK: return %[[VAL_2]] : tensor<13x21x3xi64>
1199+
%0 = "tosa.const"() <{value = dense<-42> : tensor<13x21x3xi8>}>: () -> tensor<13x21x3xi8>
1200+
%1 = "tosa.const"() <{value = dense<-42> : tensor<13x21x3xi64>}>: () -> tensor<13x21x3xi64>
1201+
%2 = tosa.greater_equal %arg0, %0: (tensor<13x21x3xi8>, tensor<13x21x3xi8>) -> tensor<13x21x3xi1>
1202+
%3 = tosa.select %2, %1, %arg1: ( tensor<13x21x3xi1>, tensor<13x21x3xi64>, tensor<13x21x3xi64>) -> tensor<13x21x3xi64>
1203+
return %3 : tensor<13x21x3xi64>
1204+
}
1205+

0 commit comments

Comments
 (0)