Skip to content

Commit dac1f7b

Browse files
authored
[NVPTX] fixup incorrect rounding mode for int to float conversion (#106600)
`uitofp` and `sitofp` instructions use the default rounding mode which is defined as round-to-nearest.
1 parent a1441ca commit dac1f7b

File tree

2 files changed

+24
-24
lines changed

2 files changed

+24
-24
lines changed

llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -290,15 +290,16 @@ static Instruction *simplifyNvvmIntrinsic(IntrinsicInst *II, InstCombiner &IC) {
290290
case Intrinsic::nvvm_d2ull_rz:
291291
case Intrinsic::nvvm_f2ull_rz:
292292
return {Instruction::FPToUI};
293-
case Intrinsic::nvvm_i2d_rz:
294-
case Intrinsic::nvvm_i2f_rz:
295-
case Intrinsic::nvvm_ll2d_rz:
296-
case Intrinsic::nvvm_ll2f_rz:
293+
// Integer to floating-point uses RN rounding, not RZ
294+
case Intrinsic::nvvm_i2d_rn:
295+
case Intrinsic::nvvm_i2f_rn:
296+
case Intrinsic::nvvm_ll2d_rn:
297+
case Intrinsic::nvvm_ll2f_rn:
297298
return {Instruction::SIToFP};
298-
case Intrinsic::nvvm_ui2d_rz:
299-
case Intrinsic::nvvm_ui2f_rz:
300-
case Intrinsic::nvvm_ull2d_rz:
301-
case Intrinsic::nvvm_ull2f_rz:
299+
case Intrinsic::nvvm_ui2d_rn:
300+
case Intrinsic::nvvm_ui2f_rn:
301+
case Intrinsic::nvvm_ull2d_rn:
302+
case Intrinsic::nvvm_ull2f_rn:
302303
return {Instruction::UIToFP};
303304

304305
// NVVM intrinsics that map to LLVM binary ops.

llvm/test/Transforms/InstCombine/NVPTX/nvvm-intrins.ll

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -238,49 +238,49 @@ define i64 @test_f2ull(float %a) #0 {
238238
; CHECK-LABEL: @test_i2d
239239
define double @test_i2d(i32 %a) #0 {
240240
; CHECK: sitofp i32 %a to double
241-
%ret = call double @llvm.nvvm.i2d.rz(i32 %a)
241+
%ret = call double @llvm.nvvm.i2d.rn(i32 %a)
242242
ret double %ret
243243
}
244244
; CHECK-LABEL: @test_i2f
245245
define float @test_i2f(i32 %a) #0 {
246246
; CHECK: sitofp i32 %a to float
247-
%ret = call float @llvm.nvvm.i2f.rz(i32 %a)
247+
%ret = call float @llvm.nvvm.i2f.rn(i32 %a)
248248
ret float %ret
249249
}
250250
; CHECK-LABEL: @test_ll2d
251251
define double @test_ll2d(i64 %a) #0 {
252252
; CHECK: sitofp i64 %a to double
253-
%ret = call double @llvm.nvvm.ll2d.rz(i64 %a)
253+
%ret = call double @llvm.nvvm.ll2d.rn(i64 %a)
254254
ret double %ret
255255
}
256256
; CHECK-LABEL: @test_ll2f
257257
define float @test_ll2f(i64 %a) #0 {
258258
; CHECK: sitofp i64 %a to float
259-
%ret = call float @llvm.nvvm.ll2f.rz(i64 %a)
259+
%ret = call float @llvm.nvvm.ll2f.rn(i64 %a)
260260
ret float %ret
261261
}
262262
; CHECK-LABEL: @test_ui2d
263263
define double @test_ui2d(i32 %a) #0 {
264264
; CHECK: uitofp i32 %a to double
265-
%ret = call double @llvm.nvvm.ui2d.rz(i32 %a)
265+
%ret = call double @llvm.nvvm.ui2d.rn(i32 %a)
266266
ret double %ret
267267
}
268268
; CHECK-LABEL: @test_ui2f
269269
define float @test_ui2f(i32 %a) #0 {
270270
; CHECK: uitofp i32 %a to float
271-
%ret = call float @llvm.nvvm.ui2f.rz(i32 %a)
271+
%ret = call float @llvm.nvvm.ui2f.rn(i32 %a)
272272
ret float %ret
273273
}
274274
; CHECK-LABEL: @test_ull2d
275275
define double @test_ull2d(i64 %a) #0 {
276276
; CHECK: uitofp i64 %a to double
277-
%ret = call double @llvm.nvvm.ull2d.rz(i64 %a)
277+
%ret = call double @llvm.nvvm.ull2d.rn(i64 %a)
278278
ret double %ret
279279
}
280280
; CHECK-LABEL: @test_ull2f
281281
define float @test_ull2f(i64 %a) #0 {
282282
; CHECK: uitofp i64 %a to float
283-
%ret = call float @llvm.nvvm.ull2f.rz(i64 %a)
283+
%ret = call float @llvm.nvvm.ull2f.rn(i64 %a)
284284
ret float %ret
285285
}
286286

@@ -428,10 +428,10 @@ declare float @llvm.nvvm.fmax.ftz.f(float, float)
428428
declare double @llvm.nvvm.fmin.d(double, double)
429429
declare float @llvm.nvvm.fmin.f(float, float)
430430
declare float @llvm.nvvm.fmin.ftz.f(float, float)
431-
declare double @llvm.nvvm.i2d.rz(i32)
432-
declare float @llvm.nvvm.i2f.rz(i32)
433-
declare double @llvm.nvvm.ll2d.rz(i64)
434-
declare float @llvm.nvvm.ll2f.rz(i64)
431+
declare double @llvm.nvvm.i2d.rn(i32)
432+
declare float @llvm.nvvm.i2f.rn(i32)
433+
declare double @llvm.nvvm.ll2d.rn(i64)
434+
declare float @llvm.nvvm.ll2f.rn(i64)
435435
declare double @llvm.nvvm.lohi.i2d(i32, i32)
436436
declare double @llvm.nvvm.mul.rn.d(double, double)
437437
declare float @llvm.nvvm.mul.rn.f(float, float)
@@ -450,8 +450,7 @@ declare float @llvm.nvvm.sqrt.rn.ftz.f(float)
450450
declare double @llvm.nvvm.trunc.d(double)
451451
declare float @llvm.nvvm.trunc.f(float)
452452
declare float @llvm.nvvm.trunc.ftz.f(float)
453-
declare double @llvm.nvvm.ui2d.rz(i32)
453+
declare double @llvm.nvvm.ui2d.rn(i32)
454454
declare float @llvm.nvvm.ui2f.rn(i32)
455-
declare float @llvm.nvvm.ui2f.rz(i32)
456-
declare double @llvm.nvvm.ull2d.rz(i64)
457-
declare float @llvm.nvvm.ull2f.rz(i64)
455+
declare double @llvm.nvvm.ull2d.rn(i64)
456+
declare float @llvm.nvvm.ull2f.rn(i64)

0 commit comments

Comments
 (0)