Skip to content

Commit 81e9f52

Browse files
committed
[mlir][emitc] Arith to EmitC: handle FP<->Integer conversions
1 parent 5e4a443 commit 81e9f52

File tree

8 files changed

+194
-5
lines changed

8 files changed

+194
-5
lines changed

mlir/include/mlir/Conversion/ArithToEmitC/ArithToEmitC.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@ class RewritePatternSet;
1414
class TypeConverter;
1515

1616
void populateArithToEmitCPatterns(TypeConverter &typeConverter,
17-
RewritePatternSet &patterns);
17+
RewritePatternSet &patterns,
18+
bool optionFloatToIntTruncates);
1819
} // namespace mlir
1920

2021
#endif // MLIR_CONVERSION_ARITHTOEMITC_ARITHTOEMITC_H

mlir/include/mlir/Conversion/Passes.td

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,24 @@ def ArithToAMDGPUConversionPass : Pass<"convert-arith-to-amdgpu"> {
139139

140140
def ConvertArithToEmitC : Pass<"convert-arith-to-emitc"> {
141141
let summary = "Convert Arith dialect to EmitC dialect";
142+
let description = [{
143+
This pass converts `arith` dialect operations to `emitc`.
144+
145+
The semantics of floating-point to integer conversions `arith.fptosi`,
146+
`arith.fptoui` require rounding towards zero. Typical C++ implementations
147+
use this behavior for float-to-integer casts, but that is not mandated by
148+
C++ and there are implementation-defined means to change the default behavior.
149+
150+
If casts can be guaranteed to use round-to-zero, use the
151+
`float-to-int-truncates` flag to allow conversion of `arith.fptosi` and
152+
`arith.fptoui` operations.
153+
}];
142154
let dependentDialects = ["emitc::EmitCDialect"];
155+
let options = [
156+
Option<"floatToIntTruncates", "float-to-int-truncates", "bool",
157+
/*default=*/"false",
158+
"Whether the behavior of float-to-int cast in emitc is truncation">,
159+
];
143160
}
144161

145162
//===----------------------------------------------------------------------===//

mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp

Lines changed: 82 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,14 +128,87 @@ class SelectOpConversion : public OpConversionPattern<arith::SelectOp> {
128128
}
129129
};
130130

131+
// Floating-point to integer conversions.
132+
template <typename CastOp>
133+
class FtoICastOpConversion : public OpConversionPattern<CastOp> {
134+
private:
135+
bool floatToIntTruncates;
136+
137+
public:
138+
FtoICastOpConversion(const TypeConverter &typeConverter, MLIRContext *context,
139+
bool optionFloatToIntTruncates)
140+
: OpConversionPattern<CastOp>(typeConverter, context),
141+
floatToIntTruncates(optionFloatToIntTruncates) {}
142+
143+
LogicalResult
144+
matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor,
145+
ConversionPatternRewriter &rewriter) const override {
146+
147+
Type operandType = adaptor.getIn().getType();
148+
if (!emitc::isSupportedFloatType(operandType))
149+
return rewriter.notifyMatchFailure(castOp,
150+
"unsupported cast source type");
151+
152+
if (!floatToIntTruncates)
153+
return rewriter.notifyMatchFailure(
154+
castOp, "conversion currently requires EmitC casts to use truncation "
155+
"as rounding mode");
156+
157+
Type dstType = this->getTypeConverter()->convertType(castOp.getType());
158+
if (!dstType)
159+
return rewriter.notifyMatchFailure(castOp, "type conversion failed");
160+
161+
if (!emitc::isSupportedIntegerType(dstType))
162+
return rewriter.notifyMatchFailure(castOp,
163+
"unsupported cast destination type");
164+
165+
rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType,
166+
adaptor.getOperands());
167+
168+
return success();
169+
}
170+
};
171+
172+
// Integer to floating-point conversions.
173+
template <typename CastOp>
174+
class ItoFCastOpConversion : public OpConversionPattern<CastOp> {
175+
public:
176+
ItoFCastOpConversion(const TypeConverter &typeConverter, MLIRContext *context)
177+
: OpConversionPattern<CastOp>(typeConverter, context) {}
178+
179+
LogicalResult
180+
matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor,
181+
ConversionPatternRewriter &rewriter) const override {
182+
183+
Type operandType = adaptor.getIn().getType();
184+
if (!emitc::isSupportedIntegerType(operandType))
185+
return rewriter.notifyMatchFailure(castOp,
186+
"unsupported cast source type");
187+
188+
Type dstType = this->getTypeConverter()->convertType(castOp.getType());
189+
if (!dstType)
190+
return rewriter.notifyMatchFailure(castOp, "type conversion failed");
191+
192+
if (!emitc::isSupportedFloatType(dstType))
193+
return rewriter.notifyMatchFailure(castOp,
194+
"unsupported cast destination type");
195+
196+
rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType,
197+
adaptor.getOperands());
198+
199+
return success();
200+
}
201+
};
202+
131203
} // namespace
132204

133205
//===----------------------------------------------------------------------===//
134206
// Pattern population
135207
//===----------------------------------------------------------------------===//
136208

137209
void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
138-
RewritePatternSet &patterns) {
210+
RewritePatternSet &patterns,
211+
bool optionFloatToIntTruncates) {
139212
MLIRContext *ctx = patterns.getContext();
140213

141214
// clang-format off
@@ -148,7 +221,13 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
148221
IntegerOpConversion<arith::AddIOp, emitc::AddOp>,
149222
IntegerOpConversion<arith::MulIOp, emitc::MulOp>,
150223
IntegerOpConversion<arith::SubIOp, emitc::SubOp>,
151-
SelectOpConversion
152-
>(typeConverter, ctx);
224+
SelectOpConversion,
225+
ItoFCastOpConversion<arith::SIToFPOp>,
226+
ItoFCastOpConversion<arith::UIToFPOp>
227+
>(typeConverter, ctx)
228+
.add<
229+
FtoICastOpConversion<arith::FPToSIOp>,
230+
FtoICastOpConversion<arith::FPToUIOp>
231+
>(typeConverter, ctx, optionFloatToIntTruncates);
153232
// clang-format on
154233
}

mlir/lib/Conversion/ArithToEmitC/ArithToEmitCPass.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ using namespace mlir;
2929
namespace {
3030
struct ConvertArithToEmitC
3131
: public impl::ConvertArithToEmitCBase<ConvertArithToEmitC> {
32+
using Base::Base;
33+
3234
void runOnOperation() override;
3335
};
3436
} // namespace
@@ -44,7 +46,7 @@ void ConvertArithToEmitC::runOnOperation() {
4446
TypeConverter typeConverter;
4547
typeConverter.addConversion([](Type type) { return type; });
4648

47-
populateArithToEmitCPatterns(typeConverter, patterns);
49+
populateArithToEmitCPatterns(typeConverter, patterns, floatToIntTruncates);
4850

4951
if (failed(
5052
applyPartialConversion(getOperation(), target, std::move(patterns))))
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
// RUN: mlir-opt -split-input-file --pass-pipeline="builtin.module(convert-arith-to-emitc{float-to-int-truncates})" %s | FileCheck %s
2+
3+
func.func @arith_float_to_int_cast_ops(%arg0: f32, %arg1: f64) {
4+
// CHECK: emitc.cast %arg0 : f32 to i32
5+
%0 = arith.fptosi %arg0 : f32 to i32
6+
7+
// CHECK: emitc.cast %arg1 : f64 to i32
8+
%1 = arith.fptosi %arg1 : f64 to i32
9+
10+
// CHECK: emitc.cast %arg0 : f32 to i16
11+
%2 = arith.fptosi %arg0 : f32 to i16
12+
13+
// CHECK: emitc.cast %arg1 : f64 to i16
14+
%3 = arith.fptosi %arg1 : f64 to i16
15+
16+
// CHECK: emitc.cast %arg0 : f32 to i32
17+
%4 = arith.fptoui %arg0 : f32 to i32
18+
19+
return
20+
}
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
// RUN: mlir-opt -split-input-file --pass-pipeline="builtin.module(convert-arith-to-emitc{float-to-int-truncates})" -verify-diagnostics %s
2+
3+
func.func @arith_cast_tensor(%arg0: tensor<5xf32>) -> tensor<5xi32> {
4+
// expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
5+
%t = arith.fptosi %arg0 : tensor<5xf32> to tensor<5xi32>
6+
return %t: tensor<5xi32>
7+
}
8+
9+
// -----
10+
11+
func.func @arith_cast_vector(%arg0: vector<5xf32>) -> vector<5xi32> {
12+
// expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
13+
%t = arith.fptosi %arg0 : vector<5xf32> to vector<5xi32>
14+
return %t: vector<5xi32>
15+
}
16+
17+
// -----
18+
19+
func.func @arith_cast_bf16(%arg0: bf16) -> i32 {
20+
// expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
21+
%t = arith.fptosi %arg0 : bf16 to i32
22+
return %t: i32
23+
}
24+
25+
// -----
26+
27+
func.func @arith_cast_f16(%arg0: f16) -> i32 {
28+
// expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
29+
%t = arith.fptosi %arg0 : f16 to i32
30+
return %t: i32
31+
}
32+
33+
34+
// -----
35+
36+
func.func @arith_cast_to_bf16(%arg0: i32) -> bf16 {
37+
// expected-error @+1 {{failed to legalize operation 'arith.sitofp'}}
38+
%t = arith.sitofp %arg0 : i32 to bf16
39+
return %t: bf16
40+
}
41+
42+
// -----
43+
44+
func.func @arith_cast_to_f16(%arg0: i32) -> f16 {
45+
// expected-error @+1 {{failed to legalize operation 'arith.sitofp'}}
46+
%t = arith.sitofp %arg0 : i32 to f16
47+
return %t: f16
48+
}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
// RUN: mlir-opt -split-input-file -convert-arith-to-emitc -verify-diagnostics %s
2+
3+
func.func @arith_cast_f32(%arg0: f32) -> i32 {
4+
// expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
5+
%t = arith.fptosi %arg0 : f32 to i32
6+
return %t: i32
7+
}

mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,3 +93,18 @@ func.func @arith_select(%arg0: i1, %arg1: tensor<8xi32>, %arg2: tensor<8xi32>) -
9393
%0 = arith.select %arg0, %arg1, %arg2 : i1, tensor<8xi32>
9494
return
9595
}
96+
97+
// -----
98+
99+
func.func @arith_int_to_float_cast_ops(%arg0: i8, %arg1: i64) {
100+
// CHECK: emitc.cast %arg0 : i8 to f32
101+
%0 = arith.sitofp %arg0 : i8 to f32
102+
103+
// CHECK: emitc.cast %arg1 : i64 to f32
104+
%1 = arith.sitofp %arg1 : i64 to f32
105+
106+
// CHECK: emitc.cast %arg0 : i8 to f32
107+
%2 = arith.uitofp %arg0 : i8 to f32
108+
109+
return
110+
}

0 commit comments

Comments
 (0)