Skip to content

Commit e7e14b2

Browse files
committed
[mlir][emitc] Arith to EmitC: handle FP<->Integer conversions
1 parent 4d7f3d9 commit e7e14b2

File tree

3 files changed

+177
-1
lines changed

3 files changed

+177
-1
lines changed

mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp

Lines changed: 93 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,94 @@ class SelectOpConversion : public OpConversionPattern<arith::SelectOp> {
201201
}
202202
};
203203

204+
// Floating-point to integer conversions.
205+
template <typename CastOp>
206+
class FtoICastOpConversion : public OpConversionPattern<CastOp> {
207+
public:
208+
FtoICastOpConversion(const TypeConverter &typeConverter, MLIRContext *context)
209+
: OpConversionPattern<CastOp>(typeConverter, context) {}
210+
211+
LogicalResult
212+
matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor,
213+
ConversionPatternRewriter &rewriter) const override {
214+
215+
Type operandType = adaptor.getIn().getType();
216+
if (!emitc::isSupportedFloatType(operandType))
217+
return rewriter.notifyMatchFailure(castOp,
218+
"unsupported cast source type");
219+
220+
Type dstType = this->getTypeConverter()->convertType(castOp.getType());
221+
if (!dstType)
222+
return rewriter.notifyMatchFailure(castOp, "type conversion failed");
223+
224+
if (!emitc::isSupportedIntegerType(dstType))
225+
return rewriter.notifyMatchFailure(castOp,
226+
"unsupported cast destination type");
227+
228+
// Convert to unsigned if it's the "ui" variant
229+
// Signless is interpreted as signed, so no need to cast for "si"
230+
Type actualResultType = dstType;
231+
if (isa<arith::FPToUIOp>(castOp)) {
232+
actualResultType =
233+
rewriter.getIntegerType(operandType.getIntOrFloatBitWidth(),
234+
/*isSigned=*/false);
235+
}
236+
237+
Value result = rewriter.create<emitc::CastOp>(
238+
castOp.getLoc(), actualResultType, adaptor.getOperands());
239+
240+
if (isa<arith::FPToUIOp>(castOp)) {
241+
result = rewriter.create<emitc::CastOp>(castOp.getLoc(), dstType, result);
242+
}
243+
rewriter.replaceOp(castOp, result);
244+
245+
return success();
246+
}
247+
};
248+
249+
// Integer to floating-point conversions.
250+
template <typename CastOp>
251+
class ItoFCastOpConversion : public OpConversionPattern<CastOp> {
252+
public:
253+
ItoFCastOpConversion(const TypeConverter &typeConverter, MLIRContext *context)
254+
: OpConversionPattern<CastOp>(typeConverter, context) {}
255+
256+
LogicalResult
257+
matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor,
258+
ConversionPatternRewriter &rewriter) const override {
259+
// Vectors in particular are not supported
260+
Type operandType = adaptor.getIn().getType();
261+
if (!emitc::isSupportedIntegerType(operandType))
262+
return rewriter.notifyMatchFailure(castOp,
263+
"unsupported cast source type");
264+
265+
Type dstType = this->getTypeConverter()->convertType(castOp.getType());
266+
if (!dstType)
267+
return rewriter.notifyMatchFailure(castOp, "type conversion failed");
268+
269+
if (!emitc::isSupportedFloatType(dstType))
270+
return rewriter.notifyMatchFailure(castOp,
271+
"unsupported cast destination type");
272+
273+
// Convert to unsigned if it's the "ui" variant
274+
// Signless is interpreted as signed, so no need to cast for "si"
275+
Type actualOperandType = operandType;
276+
if (isa<arith::UIToFPOp>(castOp)) {
277+
actualOperandType =
278+
rewriter.getIntegerType(operandType.getIntOrFloatBitWidth(),
279+
/*isSigned=*/false);
280+
}
281+
Value fpCastOperand = adaptor.getIn();
282+
if (actualOperandType != operandType) {
283+
fpCastOperand = rewriter.template create<emitc::CastOp>(
284+
castOp.getLoc(), actualOperandType, fpCastOperand);
285+
}
286+
rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType, fpCastOperand);
287+
288+
return success();
289+
}
290+
};
291+
204292
} // namespace
205293

206294
//===----------------------------------------------------------------------===//
@@ -222,7 +310,11 @@ void mlir::populateArithToEmitCPatterns(TypeConverter &typeConverter,
222310
IntegerOpConversion<arith::MulIOp, emitc::MulOp>,
223311
IntegerOpConversion<arith::SubIOp, emitc::SubOp>,
224312
CmpIOpConversion,
225-
SelectOpConversion
313+
SelectOpConversion,
314+
ItoFCastOpConversion<arith::SIToFPOp>,
315+
ItoFCastOpConversion<arith::UIToFPOp>,
316+
FtoICastOpConversion<arith::FPToSIOp>,
317+
FtoICastOpConversion<arith::FPToUIOp>
226318
>(typeConverter, ctx);
227319
// clang-format on
228320
}
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
// RUN: mlir-opt -split-input-file -convert-arith-to-emitc -verify-diagnostics %s
2+
3+
func.func @arith_cast_tensor(%arg0: tensor<5xf32>) -> tensor<5xi32> {
4+
// expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
5+
%t = arith.fptosi %arg0 : tensor<5xf32> to tensor<5xi32>
6+
return %t: tensor<5xi32>
7+
}
8+
9+
// -----
10+
11+
func.func @arith_cast_vector(%arg0: vector<5xf32>) -> vector<5xi32> {
12+
// expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
13+
%t = arith.fptosi %arg0 : vector<5xf32> to vector<5xi32>
14+
return %t: vector<5xi32>
15+
}
16+
17+
// -----
18+
19+
func.func @arith_cast_bf16(%arg0: bf16) -> i32 {
20+
// expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
21+
%t = arith.fptosi %arg0 : bf16 to i32
22+
return %t: i32
23+
}
24+
25+
// -----
26+
27+
func.func @arith_cast_f16(%arg0: f16) -> i32 {
28+
// expected-error @+1 {{failed to legalize operation 'arith.fptosi'}}
29+
%t = arith.fptosi %arg0 : f16 to i32
30+
return %t: i32
31+
}
32+
33+
34+
// -----
35+
36+
func.func @arith_cast_to_bf16(%arg0: i32) -> bf16 {
37+
// expected-error @+1 {{failed to legalize operation 'arith.sitofp'}}
38+
%t = arith.sitofp %arg0 : i32 to bf16
39+
return %t: bf16
40+
}
41+
42+
// -----
43+
44+
func.func @arith_cast_to_f16(%arg0: i32) -> f16 {
45+
// expected-error @+1 {{failed to legalize operation 'arith.sitofp'}}
46+
%t = arith.sitofp %arg0 : i32 to f16
47+
return %t: f16
48+
}

mlir/test/Conversion/ArithToEmitC/arith-to-emitc.mlir

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,3 +141,39 @@ func.func @arith_cmpi_predicates(%arg0: i32, %arg1: i32) {
141141

142142
return
143143
}
144+
145+
// -----
146+
147+
func.func @arith_float_to_int_cast_ops(%arg0: f32, %arg1: f64) {
148+
// CHECK: emitc.cast %arg0 : f32 to i32
149+
%0 = arith.fptosi %arg0 : f32 to i32
150+
151+
// CHECK: emitc.cast %arg1 : f64 to i32
152+
%1 = arith.fptosi %arg1 : f64 to i32
153+
154+
// CHECK: emitc.cast %arg0 : f32 to i16
155+
%2 = arith.fptosi %arg0 : f32 to i16
156+
157+
// CHECK: emitc.cast %arg1 : f64 to i16
158+
%3 = arith.fptosi %arg1 : f64 to i16
159+
160+
// CHECK: %[[CAST0:.*]] = emitc.cast %arg0 : f32 to ui32
161+
// CHECK: emitc.cast %[[CAST0]] : ui32 to i32
162+
%4 = arith.fptoui %arg0 : f32 to i32
163+
164+
return
165+
}
166+
167+
func.func @arith_int_to_float_cast_ops(%arg0: i8, %arg1: i64) {
168+
// CHECK: emitc.cast %arg0 : i8 to f32
169+
%0 = arith.sitofp %arg0 : i8 to f32
170+
171+
// CHECK: emitc.cast %arg1 : i64 to f32
172+
%1 = arith.sitofp %arg1 : i64 to f32
173+
174+
// CHECK: %[[CAST_UNS:.*]] = emitc.cast %arg0 : i8 to ui8
175+
// CHECK: emitc.cast %[[CAST_UNS]] : ui8 to f32
176+
%2 = arith.uitofp %arg0 : i8 to f32
177+
178+
return
179+
}

0 commit comments

Comments
 (0)