Skip to content

Commit 9458bae

Browse files
committed
[NVPTX] Custom lower integer<->bf16 conversions for sm_80 (#74827)
sm_80 only has f32->bf16 conversions; the remaining integer conversions arrived with sm_90. Use a two-step conversion for sm_80. There doesn't seem to be a way to express this promotion directly within the legalization framework, so fall back on Custom lowering.
1 parent f58f089 commit 9458bae

File tree

3 files changed

+154
-0
lines changed

3 files changed

+154
-0
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -766,6 +766,17 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
766766
AddPromotedToType(Op, MVT::bf16, MVT::f32);
767767
}
768768

769+
// sm_80 only has conversions between f32 and bf16. Custom lower all other
770+
// bf16 conversions.
771+
if (STI.hasBF16Math() &&
772+
(STI.getSmVersion() < 90 || STI.getPTXVersion() < 78)) {
773+
for (MVT VT : {MVT::i1, MVT::i16, MVT::i32, MVT::i64}) {
774+
setOperationAction(
775+
{ISD::SINT_TO_FP, ISD::UINT_TO_FP, ISD::FP_TO_SINT, ISD::FP_TO_UINT},
776+
VT, Custom);
777+
}
778+
}
779+
769780
setOperationAction(ISD::FROUND, MVT::f16, Promote);
770781
setOperationAction(ISD::FROUND, MVT::v2f16, Expand);
771782
setOperationAction(ISD::FROUND, MVT::v2bf16, Expand);
@@ -2580,6 +2591,37 @@ SDValue NVPTXTargetLowering::LowerFROUND64(SDValue Op,
25802591
return DAG.getNode(ISD::SELECT, SL, VT, IsLarge, A, RoundedA);
25812592
}
25822593

2594+
// Custom lowering for SINT_TO_FP / UINT_TO_FP on subtargets without direct
// integer -> bf16 conversions (the assert below documents the expected
// subtarget range). A bf16 result is produced in two steps: the original
// conversion opcode is re-issued with an f32 result, and that f32 value is
// then rounded to bf16 with FP_ROUND. Any other result type is returned
// unchanged.
SDValue NVPTXTargetLowering::LowerINT_TO_FP(SDValue Op,
                                            SelectionDAG &DAG) const {
  assert(STI.getSmVersion() < 90 || STI.getPTXVersion() < 78);

  // Everything that does not produce bf16 is considered legal.
  if (Op.getValueType() != MVT::bf16)
    return Op;

  SDLoc DL(Op);
  // Step 1: integer -> f32, keeping the signedness of the original opcode.
  SDValue AsF32 = DAG.getNode(Op.getOpcode(), DL, MVT::f32, Op.getOperand(0));
  // Step 2: f32 -> bf16. The second FP_ROUND operand is passed as 0.
  SDValue RoundFlag = DAG.getIntPtrConstant(0, DL);
  return DAG.getNode(ISD::FP_ROUND, DL, MVT::bf16, AsF32, RoundFlag);
}
2609+
2610+
// Custom lowering for FP_TO_SINT / FP_TO_UINT on subtargets without direct
// bf16 -> integer conversions (the assert below documents the expected
// subtarget range). A bf16 source is first widened to f32 with FP_EXTEND and
// the original conversion opcode is then applied to the f32 value. Any other
// source type is returned unchanged.
SDValue NVPTXTargetLowering::LowerFP_TO_INT(SDValue Op,
                                            SelectionDAG &DAG) const {
  assert(STI.getSmVersion() < 90 || STI.getPTXVersion() < 78);

  SDValue Src = Op.getOperand(0);
  // Everything that does not convert from bf16 is considered legal.
  if (Src.getValueType() != MVT::bf16)
    return Op;

  SDLoc DL(Op);
  // Step 1: bf16 -> f32.
  SDValue Widened = DAG.getNode(ISD::FP_EXTEND, DL, MVT::f32, Src);
  // Step 2: f32 -> integer, using the original opcode and result type.
  return DAG.getNode(Op.getOpcode(), DL, Op.getValueType(), Widened);
}
2624+
25832625
static SDValue LowerVectorArith(SDValue Op, SelectionDAG &DAG) {
25842626
SDLoc DL(Op);
25852627
if (Op.getValueType() != MVT::v2i16)
@@ -2636,6 +2678,12 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
26362678
return LowerSelect(Op, DAG);
26372679
case ISD::FROUND:
26382680
return LowerFROUND(Op, DAG);
2681+
case ISD::SINT_TO_FP:
2682+
case ISD::UINT_TO_FP:
2683+
return LowerINT_TO_FP(Op, DAG);
2684+
case ISD::FP_TO_SINT:
2685+
case ISD::FP_TO_UINT:
2686+
return LowerFP_TO_INT(Op, DAG);
26392687
case ISD::VAARG:
26402688
return LowerVAARG(Op, DAG);
26412689
case ISD::VASTART:

llvm/lib/Target/NVPTX/NVPTXISelLowering.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -607,6 +607,9 @@ class NVPTXTargetLowering : public TargetLowering {
607607
SDValue LowerFROUND32(SDValue Op, SelectionDAG &DAG) const;
608608
SDValue LowerFROUND64(SDValue Op, SelectionDAG &DAG) const;
609609

610+
SDValue LowerINT_TO_FP(SDValue Op, SelectionDAG &DAG) const;
611+
SDValue LowerFP_TO_INT(SDValue Op, SelectionDAG &DAG) const;
612+
610613
SDValue LowerLOAD(SDValue Op, SelectionDAG &DAG) const;
611614
SDValue LowerLOADi1(SDValue Op, SelectionDAG &DAG) const;
612615

llvm/test/CodeGen/NVPTX/bf16-instructions.ll

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,3 +227,106 @@ define <8 x float> @test_extload_bf16x8(ptr addrspace(3) noundef %arg) #0 {
227227
%res = fpext <8 x bfloat> %load to <8 x float>
228228
ret <8 x float> %res
229229
}
230+
231+
; CHECK-LABEL: test_fptosi_i16(
232+
; CHECK: ld.param.b16 [[A:%rs[0-9]+]], [test_fptosi_i16_param_0];
233+
; SM80: cvt.f32.bf16 [[B:%f[0-9]+]], [[A]];
234+
; SM80: cvt.rzi.s16.f32 [[C:%rs[0-9]+]], [[B]];
235+
; SM80: cvt.u32.u16 [[R:%r[0-9]+]], [[C]];
236+
; SM90: cvt.rzi.s16.bf16 [[B:%rs[0-9]+]], [[A]];
237+
; SM90: cvt.u32.u16 [[R:%r[0-9]+]], [[B]];
238+
; CHECK: st.param.b32 [func_retval0+0], [[R]];
239+
; CHECK: ret;
240+
define i16 @test_fptosi_i16(bfloat %a) {
241+
%r = fptosi bfloat %a to i16
242+
ret i16 %r
243+
}
244+
245+
; CHECK-LABEL: test_fptoui_i16(
246+
; CHECK: ld.param.b16 [[A:%rs[0-9]+]], [test_fptoui_i16_param_0];
247+
; SM80: cvt.f32.bf16 [[B:%f[0-9]+]], [[A]];
248+
; SM80: cvt.rzi.u16.f32 [[C:%rs[0-9]+]], [[B]];
249+
; SM80: cvt.u32.u16 [[R:%r[0-9]+]], [[C]];
250+
; SM90: cvt.rzi.u16.bf16 [[B:%rs[0-9]+]], [[A]];
251+
; SM90: cvt.u32.u16 [[R:%r[0-9]+]], [[B]];
252+
; CHECK: st.param.b32 [func_retval0+0], [[R]];
253+
; CHECK: ret;
254+
define i16 @test_fptoui_i16(bfloat %a) {
255+
%r = fptoui bfloat %a to i16
256+
ret i16 %r
257+
}
258+
259+
; CHECK-LABEL: test_sitofp_i16(
260+
; CHECK: ld.param.u16 [[A:%rs[0-9]+]], [test_sitofp_i16_param_0];
261+
; SM80: cvt.rn.f32.s16 [[B:%f[0-9]+]], [[A]];
262+
; SM80: cvt.rn.bf16.f32 [[R:%rs[0-9]+]], [[B]];
263+
; SM90: cvt.rn.bf16.s16 [[R:%rs[0-9]+]], [[A]];
264+
; CHECK: st.param.b16 [func_retval0+0], [[R]];
265+
; CHECK: ret;
266+
define bfloat @test_sitofp_i16(i16 %a) {
267+
%r = sitofp i16 %a to bfloat
268+
ret bfloat %r
269+
}
270+
271+
; CHECK-LABEL: test_uitofp_i8(
272+
; CHECK: ld.param.u8 [[A:%rs[0-9]+]], [test_uitofp_i8_param_0];
273+
; SM80: cvt.rn.f32.u16 [[B:%f[0-9]+]], [[A]];
274+
; SM80: cvt.rn.bf16.f32 [[R:%rs[0-9]+]], [[B]];
275+
; SM90: cvt.rn.bf16.u16 [[R:%rs[0-9]+]], [[A]];
276+
; CHECK: st.param.b16 [func_retval0+0], [[R]];
277+
; CHECK: ret;
278+
define bfloat @test_uitofp_i8(i8 %a) {
279+
%r = uitofp i8 %a to bfloat
280+
ret bfloat %r
281+
}
282+
283+
; CHECK-LABEL: test_uitofp_i1(
284+
; CHECK: ld.param.u8 [[A:%rs[0-9]+]], [test_uitofp_i1_param_0];
285+
; CHECK: and.b16 [[B:%rs[0-9]+]], [[A]], 1;
286+
; CHECK: setp.eq.b16 [[C:%p[0-9]+]], [[B]], 1;
287+
; CHECK: selp.u32 [[D:%r[0-9]+]], 1, 0, [[C]];
288+
; SM80: cvt.rn.f32.u32 [[E:%f[0-9]+]], [[D]];
289+
; SM80: cvt.rn.bf16.f32 [[R:%rs[0-9]+]], [[E]];
290+
; SM90: cvt.rn.bf16.u32 [[R:%rs[0-9]+]], [[D]];
291+
; CHECK: st.param.b16 [func_retval0+0], [[R]];
292+
; CHECK: ret;
293+
define bfloat @test_uitofp_i1(i1 %a) {
294+
%r = uitofp i1 %a to bfloat
295+
ret bfloat %r
296+
}
297+
298+
; CHECK-LABEL: test_uitofp_i16(
299+
; CHECK: ld.param.u16 [[A:%rs[0-9]+]], [test_uitofp_i16_param_0];
300+
; SM80: cvt.rn.f32.u16 [[B:%f[0-9]+]], [[A]];
301+
; SM80: cvt.rn.bf16.f32 [[R:%rs[0-9]+]], [[B]];
302+
; SM90: cvt.rn.bf16.u16 [[R:%rs[0-9]+]], [[A]];
303+
; CHECK: st.param.b16 [func_retval0+0], [[R]];
304+
; CHECK: ret;
305+
define bfloat @test_uitofp_i16(i16 %a) {
306+
%r = uitofp i16 %a to bfloat
307+
ret bfloat %r
308+
}
309+
310+
; CHECK-LABEL: test_uitofp_i32(
311+
; CHECK: ld.param.u32 [[A:%r[0-9]+]], [test_uitofp_i32_param_0];
312+
; SM80: cvt.rn.f32.u32 [[B:%f[0-9]+]], [[A]];
313+
; SM80: cvt.rn.bf16.f32 [[R:%rs[0-9]+]], [[B]];
314+
; SM90: cvt.rn.bf16.u32 [[R:%rs[0-9]+]], [[A]];
315+
; CHECK: st.param.b16 [func_retval0+0], [[R]];
316+
; CHECK: ret;
317+
define bfloat @test_uitofp_i32(i32 %a) {
318+
%r = uitofp i32 %a to bfloat
319+
ret bfloat %r
320+
}
321+
322+
; CHECK-LABEL: test_uitofp_i64(
323+
; CHECK: ld.param.u64 [[A:%rd[0-9]+]], [test_uitofp_i64_param_0];
324+
; SM80: cvt.rn.f32.u64 [[B:%f[0-9]+]], [[A]];
325+
; SM80: cvt.rn.bf16.f32 [[R:%rs[0-9]+]], [[B]];
326+
; SM90: cvt.rn.bf16.u64 [[R:%rs[0-9]+]], [[A]];
327+
; CHECK: st.param.b16 [func_retval0+0], [[R]];
328+
; CHECK: ret;
329+
define bfloat @test_uitofp_i64(i64 %a) {
330+
%r = uitofp i64 %a to bfloat
331+
ret bfloat %r
332+
}

0 commit comments

Comments
 (0)