Skip to content

Commit 28afbdd

Browse files
committed
rebase
Created using spr 1.3.4
2 parents 019c682 + c04cdc2 commit 28afbdd

31 files changed

+8999
-10755
lines changed

llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp

Lines changed: 21 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -3199,7 +3199,16 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
31993199
return true;
32003200
}
32013201
break;
3202-
case ISD::FP_ROUND:
3202+
case ISD::FP_ROUND: {
3203+
EVT VT = Node->getValueType(0);
3204+
if (VT.getScalarType() == MVT::bf16) {
3205+
Results.push_back(
3206+
DAG.getNode(ISD::FP_TO_BF16, SDLoc(Node), VT, Node->getOperand(0)));
3207+
break;
3208+
}
3209+
3210+
LLVM_FALLTHROUGH;
3211+
}
32033212
case ISD::BITCAST:
32043213
if ((Tmp1 = EmitStackConvert(Node->getOperand(0), Node->getValueType(0),
32053214
Node->getValueType(0), dl)))
@@ -3226,12 +3235,19 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
32263235
return true;
32273236
}
32283237
break;
3229-
case ISD::FP_EXTEND:
3230-
if ((Tmp1 = EmitStackConvert(Node->getOperand(0),
3231-
Node->getOperand(0).getValueType(),
3232-
Node->getValueType(0), dl)))
3238+
case ISD::FP_EXTEND: {
3239+
SDValue Op = Node->getOperand(0);
3240+
EVT SrcVT = Op.getValueType();
3241+
EVT DstVT = Node->getValueType(0);
3242+
if (SrcVT.getScalarType() == MVT::bf16) {
3243+
Results.push_back(DAG.getNode(ISD::BF16_TO_FP, SDLoc(Node), DstVT, Op));
3244+
break;
3245+
}
3246+
3247+
if ((Tmp1 = EmitStackConvert(Op, SrcVT, DstVT, dl)))
32333248
Results.push_back(Tmp1);
32343249
break;
3250+
}
32353251
case ISD::BF16_TO_FP: {
32363252
// Always expand bf16 to f32 casts; they lower to ext + shift.
32373253
//

llvm/lib/Target/AMDGPU/AMDGPUCallingConv.td

Lines changed: 13 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -22,28 +22,28 @@ def CC_SI_Gfx : CallingConv<[
2222
// 32 is reserved for the stack pointer
2323
// 33 is reserved for the frame pointer
2424
// 34 is reserved for the base pointer
25-
CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<[
25+
CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<[
2626
SGPR4, SGPR5, SGPR6, SGPR7,
2727
SGPR8, SGPR9, SGPR10, SGPR11, SGPR12, SGPR13, SGPR14, SGPR15,
2828
SGPR16, SGPR17, SGPR18, SGPR19, SGPR20, SGPR21, SGPR22, SGPR23,
2929
SGPR24, SGPR25, SGPR26, SGPR27, SGPR28, SGPR29
3030
]>>>,
3131

32-
CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<[
32+
CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<[
3333
VGPR0, VGPR1, VGPR2, VGPR3, VGPR4, VGPR5, VGPR6, VGPR7,
3434
VGPR8, VGPR9, VGPR10, VGPR11, VGPR12, VGPR13, VGPR14, VGPR15,
3535
VGPR16, VGPR17, VGPR18, VGPR19, VGPR20, VGPR21, VGPR22, VGPR23,
3636
VGPR24, VGPR25, VGPR26, VGPR27, VGPR28, VGPR29, VGPR30, VGPR31
3737
]>>>,
3838

39-
CCIfType<[i32, f32, v2i16, v2f16, i16, f16, i1], CCAssignToStack<4, 4>>
39+
CCIfType<[i32, f32, v2i16, v2f16, i16, f16, i1, bf16, v2bf16], CCAssignToStack<4, 4>>
4040
]>;
4141

4242
def RetCC_SI_Gfx : CallingConv<[
4343
CCIfType<[i1], CCPromoteToType<i32>>,
4444
CCIfType<[i1, i16], CCIfExtend<CCPromoteToType<i32>>>,
4545

46-
CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<[
46+
CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<[
4747
VGPR0, VGPR1, VGPR2, VGPR3, VGPR4, VGPR5, VGPR6, VGPR7,
4848
VGPR8, VGPR9, VGPR10, VGPR11, VGPR12, VGPR13, VGPR14, VGPR15,
4949
VGPR16, VGPR17, VGPR18, VGPR19, VGPR20, VGPR21, VGPR22, VGPR23,
@@ -66,7 +66,7 @@ def RetCC_SI_Gfx : CallingConv<[
6666

6767
def CC_SI_SHADER : CallingConv<[
6868

69-
CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<[
69+
CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<[
7070
SGPR0, SGPR1, SGPR2, SGPR3, SGPR4, SGPR5, SGPR6, SGPR7,
7171
SGPR8, SGPR9, SGPR10, SGPR11, SGPR12, SGPR13, SGPR14, SGPR15,
7272
SGPR16, SGPR17, SGPR18, SGPR19, SGPR20, SGPR21, SGPR22, SGPR23,
@@ -76,7 +76,7 @@ def CC_SI_SHADER : CallingConv<[
7676
]>>>,
7777

7878
// 32*4 + 4 is the minimum for a fetch shader consumer with 32 inputs.
79-
CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<[
79+
CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<[
8080
VGPR0, VGPR1, VGPR2, VGPR3, VGPR4, VGPR5, VGPR6, VGPR7,
8181
VGPR8, VGPR9, VGPR10, VGPR11, VGPR12, VGPR13, VGPR14, VGPR15,
8282
VGPR16, VGPR17, VGPR18, VGPR19, VGPR20, VGPR21, VGPR22, VGPR23,
@@ -109,7 +109,7 @@ def RetCC_SI_Shader : CallingConv<[
109109
]>>,
110110

111111
// 32*4 + 4 is the minimum for a fetch shader with 32 outputs.
112-
CCIfType<[f32, f16, v2f16] , CCAssignToReg<[
112+
CCIfType<[f32, f16, v2f16, bf16, v2bf16] , CCAssignToReg<[
113113
VGPR0, VGPR1, VGPR2, VGPR3, VGPR4, VGPR5, VGPR6, VGPR7,
114114
VGPR8, VGPR9, VGPR10, VGPR11, VGPR12, VGPR13, VGPR14, VGPR15,
115115
VGPR16, VGPR17, VGPR18, VGPR19, VGPR20, VGPR21, VGPR22, VGPR23,
@@ -188,23 +188,23 @@ def CC_AMDGPU_Func : CallingConv<[
188188
CCIfType<[i1], CCPromoteToType<i32>>,
189189
CCIfType<[i8, i16], CCIfExtend<CCPromoteToType<i32>>>,
190190

191-
CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<
191+
CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<
192192
!foreach(i, !range(0, 30), !cast<Register>("SGPR"#i)) // SGPR0-29
193193
>>>,
194194

195-
CCIfType<[i32, f32, i16, f16, v2i16, v2f16, i1], CCAssignToReg<[
195+
CCIfType<[i32, f32, i16, f16, v2i16, v2f16, i1, bf16, v2bf16], CCAssignToReg<[
196196
VGPR0, VGPR1, VGPR2, VGPR3, VGPR4, VGPR5, VGPR6, VGPR7,
197197
VGPR8, VGPR9, VGPR10, VGPR11, VGPR12, VGPR13, VGPR14, VGPR15,
198198
VGPR16, VGPR17, VGPR18, VGPR19, VGPR20, VGPR21, VGPR22, VGPR23,
199199
VGPR24, VGPR25, VGPR26, VGPR27, VGPR28, VGPR29, VGPR30, VGPR31]>>,
200-
CCIfType<[i32, f32, v2i16, v2f16, i16, f16, i1], CCAssignToStack<4, 4>>
200+
CCIfType<[i32, f32, v2i16, v2f16, i16, f16, i1, bf16, v2bf16], CCAssignToStack<4, 4>>
201201
]>;
202202

203203
// Calling convention for leaf functions
204204
def RetCC_AMDGPU_Func : CallingConv<[
205205
CCIfType<[i1], CCPromoteToType<i32>>,
206206
CCIfType<[i1, i16], CCIfExtend<CCPromoteToType<i32>>>,
207-
CCIfType<[i32, f32, i16, f16, v2i16, v2f16], CCAssignToReg<[
207+
CCIfType<[i32, f32, i16, f16, v2i16, v2f16, bf16, v2bf16], CCAssignToReg<[
208208
VGPR0, VGPR1, VGPR2, VGPR3, VGPR4, VGPR5, VGPR6, VGPR7,
209209
VGPR8, VGPR9, VGPR10, VGPR11, VGPR12, VGPR13, VGPR14, VGPR15,
210210
VGPR16, VGPR17, VGPR18, VGPR19, VGPR20, VGPR21, VGPR22, VGPR23,
@@ -223,11 +223,11 @@ def CC_AMDGPU : CallingConv<[
223223
]>;
224224

225225
def CC_AMDGPU_CS_CHAIN : CallingConv<[
226-
CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<
226+
CCIfInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<
227227
!foreach(i, !range(105), !cast<Register>("SGPR"#i))
228228
>>>,
229229

230-
CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16] , CCAssignToReg<
230+
CCIfNotInReg<CCIfType<[f32, i32, f16, i16, v2i16, v2f16, bf16, v2bf16] , CCAssignToReg<
231231
!foreach(i, !range(8, 255), !cast<Register>("VGPR"#i))
232232
>>>
233233
]>;

llvm/lib/Target/AMDGPU/AMDGPUISelDAGToDAG.cpp

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -303,6 +303,7 @@ void AMDGPUDAGToDAGISel::PreprocessISelDAG() {
303303

304304
switch (N->getOpcode()) {
305305
case ISD::BUILD_VECTOR:
306+
// TODO: Match load d16 from shl (extload:i16), 16
306307
MadeChange |= matchLoadD16FromBuildVector(N);
307308
break;
308309
default:

llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp

Lines changed: 32 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -3281,7 +3281,15 @@ SDValue AMDGPUTargetLowering::LowerUINT_TO_FP(SDValue Op,
32813281
return DAG.getNode(ISD::UINT_TO_FP, DL, DestVT, Ext);
32823282
}
32833283

3284-
assert(SrcVT == MVT::i64 && "operation should be legal");
3284+
if (DestVT == MVT::bf16) {
3285+
SDLoc SL(Op);
3286+
SDValue ToF32 = DAG.getNode(ISD::UINT_TO_FP, SL, MVT::f32, Src);
3287+
SDValue FPRoundFlag = DAG.getIntPtrConstant(0, SL, /*isTarget=*/true);
3288+
return DAG.getNode(ISD::FP_ROUND, SL, MVT::bf16, ToF32, FPRoundFlag);
3289+
}
3290+
3291+
if (SrcVT != MVT::i64)
3292+
return Op;
32853293

32863294
if (Subtarget->has16BitInsts() && DestVT == MVT::f16) {
32873295
SDLoc DL(Op);
@@ -3319,7 +3327,15 @@ SDValue AMDGPUTargetLowering::LowerSINT_TO_FP(SDValue Op,
33193327
return DAG.getNode(ISD::SINT_TO_FP, DL, DestVT, Ext);
33203328
}
33213329

3322-
assert(SrcVT == MVT::i64 && "operation should be legal");
3330+
if (DestVT == MVT::bf16) {
3331+
SDLoc SL(Op);
3332+
SDValue ToF32 = DAG.getNode(ISD::SINT_TO_FP, SL, MVT::f32, Src);
3333+
SDValue FPRoundFlag = DAG.getIntPtrConstant(0, SL, /*isTarget=*/true);
3334+
return DAG.getNode(ISD::FP_ROUND, SL, MVT::bf16, ToF32, FPRoundFlag);
3335+
}
3336+
3337+
if (SrcVT != MVT::i64)
3338+
return Op;
33233339

33243340
// TODO: Factor out code common with LowerUINT_TO_FP.
33253341

@@ -3517,7 +3533,7 @@ SDValue AMDGPUTargetLowering::LowerFP_TO_FP16(SDValue Op, SelectionDAG &DAG) con
35173533
return DAG.getZExtOrTrunc(V, DL, Op.getValueType());
35183534
}
35193535

3520-
SDValue AMDGPUTargetLowering::LowerFP_TO_INT(SDValue Op,
3536+
SDValue AMDGPUTargetLowering::LowerFP_TO_INT(const SDValue Op,
35213537
SelectionDAG &DAG) const {
35223538
SDValue Src = Op.getOperand(0);
35233539
unsigned OpOpcode = Op.getOpcode();
@@ -3528,6 +3544,12 @@ SDValue AMDGPUTargetLowering::LowerFP_TO_INT(SDValue Op,
35283544
if (SrcVT == MVT::f16 && DestVT == MVT::i16)
35293545
return Op;
35303546

3547+
if (SrcVT == MVT::bf16) {
3548+
SDLoc DL(Op);
3549+
SDValue PromotedSrc = DAG.getNode(ISD::FP_EXTEND, DL, MVT::f32, Src);
3550+
return DAG.getNode(Op.getOpcode(), DL, DestVT, PromotedSrc);
3551+
}
3552+
35313553
// Promote i16 to i32
35323554
if (DestVT == MVT::i16 && (SrcVT == MVT::f32 || SrcVT == MVT::f64)) {
35333555
SDLoc DL(Op);
@@ -3536,6 +3558,9 @@ SDValue AMDGPUTargetLowering::LowerFP_TO_INT(SDValue Op,
35363558
return DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, FpToInt32);
35373559
}
35383560

3561+
if (DestVT != MVT::i64)
3562+
return Op;
3563+
35393564
if (SrcVT == MVT::f16 ||
35403565
(SrcVT == MVT::f32 && Src.getOpcode() == ISD::FP16_TO_FP)) {
35413566
SDLoc DL(Op);
@@ -3546,7 +3571,7 @@ SDValue AMDGPUTargetLowering::LowerFP_TO_INT(SDValue Op,
35463571
return DAG.getNode(Ext, DL, MVT::i64, FpToInt32);
35473572
}
35483573

3549-
if (DestVT == MVT::i64 && (SrcVT == MVT::f32 || SrcVT == MVT::f64))
3574+
if (SrcVT == MVT::f32 || SrcVT == MVT::f64)
35503575
return LowerFP_TO_INT64(Op, DAG, OpOpcode == ISD::FP_TO_SINT);
35513576

35523577
return SDValue();
@@ -4947,7 +4972,9 @@ SDValue AMDGPUTargetLowering::PerformDAGCombine(SDNode *N,
49474972
// vnt1 = build_vector (t1 (bitcast t0:x)), (t1 (bitcast t0:y))
49484973
if (DestVT.isVector()) {
49494974
SDValue Src = N->getOperand(0);
4950-
if (Src.getOpcode() == ISD::BUILD_VECTOR) {
4975+
if (Src.getOpcode() == ISD::BUILD_VECTOR &&
4976+
(DCI.getDAGCombineLevel() < AfterLegalizeDAG ||
4977+
isOperationLegal(ISD::BUILD_VECTOR, DestVT))) {
49514978
EVT SrcVT = Src.getValueType();
49524979
unsigned NElts = DestVT.getVectorNumElements();
49534980

0 commit comments

Comments
 (0)