Skip to content

Commit 3f11b11

Browse files
authored
[FXML-4281] Remove flag for rounding mode of casting ops
2 parents 746aa23 + 10282d0 commit 3f11b11

File tree

8 files changed

+106
-126
lines changed

8 files changed

+106
-126
lines changed

mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,7 @@ class RewritePatternSet;
1414
class TypeConverter;
1515

1616
void populateArithToEmitCPatterns(TypeConverter &typeConverter,
17-
RewritePatternSet &patterns,
18-
bool optionFloatToIntTruncates);
17+
RewritePatternSet &patterns);
1918
} // namespace mlir
2019

2120
#endif // MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITC_H

mlir/include/mlir/Conversion/Passes.td

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -139,24 +139,7 @@ def ArithToAMDGPUConversionPass : Pass<"convert-arith-to-amdgpu"> {
139139

140140
def ConvertArithToEmitC : Pass<"convert-arith-to-emitc"> {
141141
let summary = "Convert Arith dialect to EmitC dialect";
142-
let description = [{
143-
This pass converts `arith` dialect operations to `emitc`.
144-
145-
The semantics of floating-point to integer conversions `arith.fptosi`,
146-
`arith.fptoui` require rounding towards zero. Typical C++ implementations
147-
use this behavior for float-to-integer casts, but that is not mandated by
148-
C++ and there are implementation-defined means to change the default behavior.
149-
150-
If casts can be guaranteed to use round-to-zero, use the
151-
`float-to-int-truncates` flag to allow conversion of `arith.fptosi` and
152-
`arith.fptoui` operations.
153-
}];
154142
let dependentDialects = ["emitc::EmitCDialect"];
155-
let options = [
156-
Option<"floatToIntTruncates", "float-to-int-truncates", "bool",
157-
/*default=*/"false",
158-
"Whether the behavior of float-to-int cast in emitc is truncation">,
159-
];
160143
}
161144

162145
//===----------------------------------------------------------------------===//

mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -366,14 +366,9 @@ class SelectOpConversion : public OpConversionPattern<arith::SelectOp> {
366366
// Floating-point to integer conversions.
367367
template <typename CastOp>
368368
class FtoICastOpConversion : public OpConversionPattern<CastOp> {
369-
private:
370-
bool floatToIntTruncates;
371-
372369
public:
373-
FtoICastOpConversion(const TypeConverter &typeConverter, MLIRContext *context,
374-
bool optionFloatToIntTruncates)
375-
: OpConversionPattern<CastOp>(typeConverter, context),
376-
floatToIntTruncates(optionFloatToIntTruncates) {}
370+
FtoICastOpConversion(const TypeConverter &typeConverter, MLIRContext *context)
371+
: OpConversionPattern<CastOp>(typeConverter, context) {}
377372

378373
LogicalResult
379374
matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor,
@@ -384,16 +379,13 @@ class FtoICastOpConversion : public OpConversionPattern<CastOp> {
384379
return rewriter.notifyMatchFailure(castOp,
385380
"unsupported cast source type");
386381

387-
if (!floatToIntTruncates)
388-
return rewriter.notifyMatchFailure(
389-
castOp, "conversion currently requires EmitC casts to use truncation "
390-
"as rounding mode");
391-
392382
Type dstType = this->getTypeConverter()->convertType(castOp.getType());
393383
if (!dstType)
394384
return rewriter.notifyMatchFailure(castOp, "type conversion failed");
395385

396-
if (!emitc::isSupportedIntegerType(dstType))
386+
// Float-to-i1 casts are not supported: any value with 0 < value < 1 must be
387+
// truncated to 0, whereas a boolean conversion would return true.
388+
if (!emitc::isSupportedIntegerType(dstType) || dstType.isInteger(1))
397389
return rewriter.notifyMatchFailure(castOp,
398390
"unsupported cast destination type");
399391

@@ -468,8 +460,7 @@ class ItoFCastOpConversion : public OpConversionPattern<CastOp> {
468460
//===----------------------------------------------------------------------===//
469461

470462
void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
471-
RewritePatternSet &patterns,
472-
bool optionFloatToIntTruncates) {
463+
RewritePatternSet &patterns) {
473464
MLIRContext *ctx = patterns.getContext();
474465

475466
// clang-format off
@@ -488,11 +479,9 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
488479
CmpIOpConversion,
489480
SelectOpConversion,
490481
ItoFCastOpConversion<arith::SIToFPOp>,
491-
ItoFCastOpConversion<arith::UIToFPOp>
492-
>(typeConverter, ctx)
493-
.add<
482+
ItoFCastOpConversion<arith::UIToFPOp>,
494483
FtoICastOpConversion<arith::FPToSIOp>,
495484
FtoICastOpConversion<arith::FPToUIOp>
496-
>(typeConverter, ctx, optionFloatToIntTruncates);
485+
>(typeConverter, ctx);
497486
// clang-format on
498487
}

mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ void ConvertArithToEmitC::runOnOperation() {
4545
TypeConverter typeConverter;
4646
typeConverter.addConversion([](Type type) { return type; });
4747

48-
populateArithToEmitCPatterns(typeConverter, patterns, floatToIntTruncates);
48+
populateArithToEmitCPatterns(typeConverter, patterns);
4949

5050
if (failed(
5151
applyPartialConversion(getOperation(), target, std::move(patterns))))

mlir/test/Conversion/ArithToEmitC/arith-to-emitc-cast-truncate.mlir

Lines changed: 0 additions & 21 deletions
This file was deleted.

mlir/test/Conversion/ArithToEmitC/arith-to-emitc-cast-unsupported.mlir

Lines changed: 0 additions & 48 deletions
This file was deleted.

mlir/test/Conversion/ArithToEmitC/arith-to-emitc-unsupported.mlir

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,66 @@ func.func @arith_cmpf_vector(%arg0: vector<5xf32>, %arg1: vector<5xf32>) -> vect
1616

1717
// -----
1818

19-
func.func @arith_cast_f32(%arg0: f32) -> i32 {
19+
func.func @arith_cast_tensor(%arg0: tensor<5xf32>) -> tensor<5xi32> {
2020
// expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
21-
%t = arith.fptosi %arg0 : f32 to i32
21+
%t = arith.fptosi %arg0 : tensor<5xf32> to tensor<5xi32>
22+
return %t: tensor<5xi32>
23+
}
24+
25+
// -----
26+
27+
func.func @arith_cast_vector(%arg0: vector<5xf32>) -> vector<5xi32> {
28+
// expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
29+
%t = arith.fptosi %arg0 : vector<5xf32> to vector<5xi32>
30+
return %t: vector<5xi32>
31+
}
32+
33+
// -----
34+
35+
func.func @arith_cast_bf16(%arg0: bf16) -> i32 {
36+
// expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
37+
%t = arith.fptosi %arg0 : bf16 to i32
2238
return %t: i32
2339
}
40+
41+
// -----
42+
43+
func.func @arith_cast_f16(%arg0: f16) -> i32 {
44+
// expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
45+
%t = arith.fptosi %arg0 : f16 to i32
46+
return %t: i32
47+
}
48+
49+
50+
// -----
51+
52+
func.func @arith_cast_to_bf16(%arg0: i32) -> bf16 {
53+
// expected-error @+1 {{failed to legalize operation 'arith.sitofp'}}
54+
%t = arith.sitofp %arg0 : i32 to bf16
55+
return %t: bf16
56+
}
57+
58+
// -----
59+
60+
func.func @arith_cast_to_f16(%arg0: i32) -> f16 {
61+
// expected-error @+1 {{failed to legalize operation 'arith.sitofp'}}
62+
%t = arith.sitofp %arg0 : i32 to f16
63+
return %t: f16
64+
}
65+
66+
// -----
67+
68+
func.func @arith_cast_fptosi_i1(%arg0: f32) -> i1 {
69+
// expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
70+
%t = arith.fptosi %arg0 : f32 to i1
71+
return %t: i1
72+
}
73+
74+
// -----
75+
76+
func.func @arith_cast_fptoui_i1(%arg0: f32) -> i1 {
77+
// expected-error @+1 {{failed to legalize operation 'arith.fptoui'}}
78+
%t = arith.fptoui %arg0 : f32 to i1
79+
return %t: i1
80+
}
81+

mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir

Lines changed: 36 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -309,22 +309,6 @@ func.func @arith_cmpf_true(%arg0: f32, %arg1: f32) -> i1 {
309309

310310
// -----
311311

312-
func.func @arith_int_to_float_cast_ops(%arg0: i8, %arg1: i64) {
313-
// CHECK: emitc.cast %arg0 : i8 to f32
314-
%0 = arith.sitofp %arg0 : i8 to f32
315-
316-
// CHECK: emitc.cast %arg1 : i64 to f32
317-
%1 = arith.sitofp %arg1 : i64 to f32
318-
319-
// CHECK: %[[CAST_UNS:.*]] = emitc.cast %arg0 : i8 to ui8
320-
// CHECK: emitc.cast %[[CAST_UNS]] : ui8 to f32
321-
%2 = arith.uitofp %arg0 : i8 to f32
322-
323-
return
324-
}
325-
326-
// -----
327-
328312
func.func @arith_cmpi_eq(%arg0: i32, %arg1: i32) -> i1 {
329313
// CHECK-LABEL: arith_cmpi_eq
330314
// CHECK-SAME: ([[Arg0:[^ ]*]]: i32, [[Arg1:[^ ]*]]: i32)
@@ -370,3 +354,39 @@ func.func @arith_cmpi_predicates(%arg0: i32, %arg1: i32) {
370354

371355
return
372356
}
357+
358+
// -----
359+
360+
func.func @arith_float_to_int_cast_ops(%arg0: f32, %arg1: f64) {
361+
// CHECK: emitc.cast %arg0 : f32 to i32
362+
%0 = arith.fptosi %arg0 : f32 to i32
363+
364+
// CHECK: emitc.cast %arg1 : f64 to i32
365+
%1 = arith.fptosi %arg1 : f64 to i32
366+
367+
// CHECK: emitc.cast %arg0 : f32 to i16
368+
%2 = arith.fptosi %arg0 : f32 to i16
369+
370+
// CHECK: emitc.cast %arg1 : f64 to i16
371+
%3 = arith.fptosi %arg1 : f64 to i16
372+
373+
// CHECK: %[[CAST0:.*]] = emitc.cast %arg0 : f32 to ui32
374+
// CHECK: emitc.cast %[[CAST0]] : ui32 to i32
375+
%4 = arith.fptoui %arg0 : f32 to i32
376+
377+
return
378+
}
379+
380+
func.func @arith_int_to_float_cast_ops(%arg0: i8, %arg1: i64) {
381+
// CHECK: emitc.cast %arg0 : i8 to f32
382+
%0 = arith.sitofp %arg0 : i8 to f32
383+
384+
// CHECK: emitc.cast %arg1 : i64 to f32
385+
%1 = arith.sitofp %arg1 : i64 to f32
386+
387+
// CHECK: %[[CAST_UNS:.*]] = emitc.cast %arg0 : i8 to ui8
388+
// CHECK: emitc.cast %[[CAST_UNS]] : ui8 to f32
389+
%2 = arith.uitofp %arg0 : i8 to f32
390+
391+
return
392+
}

0 commit comments

Comments
 (0)