Skip to content

Commit 3c6f05b

Browse files
authored
AMDGPU: Make v2f32 -> v2f16 legal when target supports v_cvt_pk_f16_f32 (llvm#139956) (llvm#2186)
If targets support v_cvt_pk_f16_f32 instruction, v2f32 -> v2f16 should be legal. However, SelectionDAG does not allow us to specify the source type in the legalization rules. To workaround this, we make FP_ROUND Custom for v2f16 then set up v2f32 -> v2f16 to be legal during custom lowering. Fixes: SWDEV-532608 -- expected v_cvt_pk_f16_f32 was not generated. Cherry-pick f01f082 from upstreaam to mainline.
1 parent 1ce493a commit 3c6f05b

File tree

2 files changed

+26
-2
lines changed

2 files changed

+26
-2
lines changed

llvm/lib/Target/AMDGPU/SIISelLowering.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -907,6 +907,9 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
907907
setOperationAction(ISD::BUILD_VECTOR, MVT::v2bf16, Legal);
908908
}
909909

910+
if (Subtarget->hasCvtPkF16F32Inst())
911+
setOperationAction(ISD::FP_ROUND, MVT::v2f16, Custom);
912+
910913
setTargetDAGCombine({ISD::ADD,
911914
ISD::UADDO_CARRY,
912915
ISD::SUB,
@@ -6696,11 +6699,18 @@ SDValue SITargetLowering::getFPExtOrFPRound(SelectionDAG &DAG,
66966699
}
66976700

66986701
SDValue SITargetLowering::lowerFP_ROUND(SDValue Op, SelectionDAG &DAG) const {
6702+
SDValue Src = Op.getOperand(0);
6703+
EVT SrcVT = Src.getValueType();
6704+
EVT DstVT = Op.getValueType();
6705+
6706+
if (DstVT == MVT::v2f16) {
6707+
assert(Subtarget->hasCvtPkF16F32Inst() && "support v_cvt_pk_f16_f32");
6708+
return SrcVT == MVT::v2f32 ? Op : SDValue();
6709+
}
6710+
66996711
assert(Op.getValueType() == MVT::f16 &&
67006712
"Do not know how to custom lower FP_ROUND for non-f16 type");
67016713

6702-
SDValue Src = Op.getOperand(0);
6703-
EVT SrcVT = Src.getValueType();
67046714
if (SrcVT != MVT::f64)
67056715
return Op;
67066716

llvm/test/CodeGen/AMDGPU/fptrunc.v2f16.no.fast.math.ll

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,20 @@ define <2 x half> @v_test_cvt_v2f32_v2f16(<2 x float> %src) {
1212
ret <2 x half> %res
1313
}
1414

15+
define half @fptrunc_v2f32_v2f16_then_extract(<2 x float> %src) {
16+
; GFX950-LABEL: fptrunc_v2f32_v2f16_then_extract:
17+
; GFX950: ; %bb.0:
18+
; GFX950-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
19+
; GFX950-NEXT: v_cvt_pk_f16_f32 v0, v0, v1
20+
; GFX950-NEXT: v_add_f16_sdwa v0, v0, v0 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD
21+
; GFX950-NEXT: s_setpc_b64 s[30:31]
22+
%vec_half = fptrunc <2 x float> %src to <2 x half>
23+
%first = extractelement <2 x half> %vec_half, i64 1
24+
%second = extractelement <2 x half> %vec_half, i64 0
25+
%res = fadd half %first, %second
26+
ret half %res
27+
}
28+
1529
define <2 x half> @v_test_cvt_v2f64_v2f16(<2 x double> %src) {
1630
; GFX950-SDAG-LABEL: v_test_cvt_v2f64_v2f16:
1731
; GFX950-SDAG: ; %bb.0:

0 commit comments

Comments
 (0)