Skip to content

Commit 41402c6

Browse files
authored
[RISCV][GISel] Use CCValAssign::getCustomReg for converting f16/f32<->GPR. (#105700)
This gives us much better control of the generated code for GISel. I've tried to closely match the current GISel code, but it looks like we previously had two layers of G_ANYEXT in some cases. SelectionDAG now checks needsCustom() instead of detecting the special cases in the Bitcast handler. Unfortunately, the IRTranslator still generates copies between register classes of different sizes for bitcasts. Because of this, we can't handle i16<->f16 bitcasts without crashing. I'm not sure whether I should teach RISCVInstrInfo::copyPhysReg to allow copies between FPR16 and GPR, or whether I should convert the copies to instructions in GISel.
1 parent 3e79847 commit 41402c6

File tree

3 files changed

+92
-46
lines changed

3 files changed

+92
-46
lines changed

llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp

Lines changed: 42 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -109,15 +109,6 @@ struct RISCVOutgoingValueHandler : public CallLowering::OutgoingValueHandler {
109109

110110
void assignValueToReg(Register ValVReg, Register PhysReg,
111111
const CCValAssign &VA) override {
112-
// If we're passing a smaller fp value into a larger integer register,
113-
// anyextend before copying.
114-
if ((VA.getLocVT() == MVT::i64 && VA.getValVT() == MVT::f32) ||
115-
((VA.getLocVT() == MVT::i32 || VA.getLocVT() == MVT::i64) &&
116-
VA.getValVT() == MVT::f16)) {
117-
LLT DstTy = LLT::scalar(VA.getLocVT().getSizeInBits());
118-
ValVReg = MIRBuilder.buildAnyExt(DstTy, ValVReg).getReg(0);
119-
}
120-
121112
Register ExtReg = extendRegister(ValVReg, VA);
122113
MIRBuilder.buildCopy(PhysReg, ExtReg);
123114
MIB.addUse(PhysReg, RegState::Implicit);
@@ -126,16 +117,35 @@ struct RISCVOutgoingValueHandler : public CallLowering::OutgoingValueHandler {
126117
unsigned assignCustomValue(CallLowering::ArgInfo &Arg,
127118
ArrayRef<CCValAssign> VAs,
128119
std::function<void()> *Thunk) override {
120+
const CCValAssign &VA = VAs[0];
121+
if ((VA.getLocVT() == MVT::i64 && VA.getValVT() == MVT::f32) ||
122+
(VA.getLocVT().isInteger() && VA.getValVT() == MVT::f16)) {
123+
Register PhysReg = VA.getLocReg();
124+
125+
auto assignFunc = [=]() {
126+
auto Trunc = MIRBuilder.buildAnyExt(LLT(VA.getLocVT()), Arg.Regs[0]);
127+
MIRBuilder.buildCopy(PhysReg, Trunc);
128+
MIB.addUse(PhysReg, RegState::Implicit);
129+
};
130+
131+
if (Thunk) {
132+
*Thunk = assignFunc;
133+
return 1;
134+
}
135+
136+
assignFunc();
137+
return 1;
138+
}
139+
129140
assert(VAs.size() >= 2 && "Expected at least 2 VAs.");
130-
const CCValAssign &VALo = VAs[0];
131141
const CCValAssign &VAHi = VAs[1];
132142

133143
assert(VAHi.needsCustom() && "Value doesn't need custom handling");
134-
assert(VALo.getValNo() == VAHi.getValNo() &&
144+
assert(VA.getValNo() == VAHi.getValNo() &&
135145
"Values belong to different arguments");
136146

137-
assert(VALo.getLocVT() == MVT::i32 && VAHi.getLocVT() == MVT::i32 &&
138-
VALo.getValVT() == MVT::f64 && VAHi.getValVT() == MVT::f64 &&
147+
assert(VA.getLocVT() == MVT::i32 && VAHi.getLocVT() == MVT::i32 &&
148+
VA.getValVT() == MVT::f64 && VAHi.getValVT() == MVT::f64 &&
139149
"unexpected custom value");
140150

141151
Register NewRegs[] = {MRI.createGenericVirtualRegister(LLT::scalar(32)),
@@ -154,7 +164,7 @@ struct RISCVOutgoingValueHandler : public CallLowering::OutgoingValueHandler {
154164
}
155165

156166
auto assignFunc = [=]() {
157-
assignValueToReg(NewRegs[0], VALo.getLocReg(), VALo);
167+
assignValueToReg(NewRegs[0], VA.getLocReg(), VA);
158168
if (VAHi.isRegLoc())
159169
assignValueToReg(NewRegs[1], VAHi.getLocReg(), VAHi);
160170
};
@@ -258,16 +268,29 @@ struct RISCVIncomingValueHandler : public CallLowering::IncomingValueHandler {
258268
unsigned assignCustomValue(CallLowering::ArgInfo &Arg,
259269
ArrayRef<CCValAssign> VAs,
260270
std::function<void()> *Thunk) override {
271+
const CCValAssign &VA = VAs[0];
272+
if ((VA.getLocVT() == MVT::i64 && VA.getValVT() == MVT::f32) ||
273+
(VA.getLocVT().isInteger() && VA.getValVT() == MVT::f16)) {
274+
Register PhysReg = VA.getLocReg();
275+
276+
markPhysRegUsed(PhysReg);
277+
278+
LLT LocTy(VA.getLocVT());
279+
auto Copy = MIRBuilder.buildCopy(LocTy, PhysReg);
280+
281+
MIRBuilder.buildTrunc(Arg.Regs[0], Copy.getReg(0));
282+
return 1;
283+
}
284+
261285
assert(VAs.size() >= 2 && "Expected at least 2 VAs.");
262-
const CCValAssign &VALo = VAs[0];
263286
const CCValAssign &VAHi = VAs[1];
264287

265288
assert(VAHi.needsCustom() && "Value doesn't need custom handling");
266-
assert(VALo.getValNo() == VAHi.getValNo() &&
289+
assert(VA.getValNo() == VAHi.getValNo() &&
267290
"Values belong to different arguments");
268291

269-
assert(VALo.getLocVT() == MVT::i32 && VAHi.getLocVT() == MVT::i32 &&
270-
VALo.getValVT() == MVT::f64 && VAHi.getValVT() == MVT::f64 &&
292+
assert(VA.getLocVT() == MVT::i32 && VAHi.getLocVT() == MVT::i32 &&
293+
VA.getValVT() == MVT::f64 && VAHi.getValVT() == MVT::f64 &&
271294
"unexpected custom value");
272295

273296
Register NewRegs[] = {MRI.createGenericVirtualRegister(LLT::scalar(32)),
@@ -284,7 +307,7 @@ struct RISCVIncomingValueHandler : public CallLowering::IncomingValueHandler {
284307
const_cast<CCValAssign &>(VAHi));
285308
}
286309

287-
assignValueToReg(NewRegs[0], VALo.getLocReg(), VALo);
310+
assignValueToReg(NewRegs[0], VA.getLocReg(), VA);
288311
if (VAHi.isRegLoc())
289312
assignValueToReg(NewRegs[1], VAHi.getLocReg(), VAHi);
290313

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 46 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -19226,6 +19226,19 @@ bool RISCV::CC_RISCV(const DataLayout &DL, RISCVABI::ABI ABI, unsigned ValNo,
1922619226
// similar local variables rather than directly checking against the target
1922719227
// ABI.
1922819228

19229+
ArrayRef<MCPhysReg> ArgGPRs = RISCV::getArgGPRs(ABI);
19230+
19231+
if (UseGPRForF16_F32 && (ValVT == MVT::f16 || ValVT == MVT::bf16 ||
19232+
(ValVT == MVT::f32 && XLen == 64))) {
19233+
Register Reg = State.AllocateReg(ArgGPRs);
19234+
if (Reg) {
19235+
LocVT = XLenVT;
19236+
State.addLoc(
19237+
CCValAssign::getCustomReg(ValNo, ValVT, Reg, LocVT, LocInfo));
19238+
return false;
19239+
}
19240+
}
19241+
1922919242
if (UseGPRForF16_F32 &&
1923019243
(ValVT == MVT::f16 || ValVT == MVT::bf16 || ValVT == MVT::f32)) {
1923119244
LocVT = XLenVT;
@@ -19235,8 +19248,6 @@ bool RISCV::CC_RISCV(const DataLayout &DL, RISCVABI::ABI ABI, unsigned ValNo,
1923519248
LocInfo = CCValAssign::BCvt;
1923619249
}
1923719250

19238-
ArrayRef<MCPhysReg> ArgGPRs = RISCV::getArgGPRs(ABI);
19239-
1924019251
// If this is a variadic argument, the RISC-V calling convention requires
1924119252
// that it is assigned an 'even' or 'aligned' register if it has 8-byte
1924219253
// alignment (RV32) or 16-byte alignment (RV64). An aligned register should
@@ -19483,6 +19494,17 @@ void RISCVTargetLowering::analyzeOutputArgs(
1948319494
static SDValue convertLocVTToValVT(SelectionDAG &DAG, SDValue Val,
1948419495
const CCValAssign &VA, const SDLoc &DL,
1948519496
const RISCVSubtarget &Subtarget) {
19497+
if (VA.needsCustom()) {
19498+
if (VA.getLocVT().isInteger() &&
19499+
(VA.getValVT() == MVT::f16 || VA.getValVT() == MVT::bf16))
19500+
Val = DAG.getNode(RISCVISD::FMV_H_X, DL, VA.getValVT(), Val);
19501+
else if (VA.getLocVT() == MVT::i64 && VA.getValVT() == MVT::f32)
19502+
Val = DAG.getNode(RISCVISD::FMV_W_X_RV64, DL, MVT::f32, Val);
19503+
else
19504+
llvm_unreachable("Unexpected Custom handling.");
19505+
return Val;
19506+
}
19507+
1948619508
switch (VA.getLocInfo()) {
1948719509
default:
1948819510
llvm_unreachable("Unexpected CCValAssign::LocInfo");
@@ -19491,14 +19513,7 @@ static SDValue convertLocVTToValVT(SelectionDAG &DAG, SDValue Val,
1949119513
Val = convertFromScalableVector(VA.getValVT(), Val, DAG, Subtarget);
1949219514
break;
1949319515
case CCValAssign::BCvt:
19494-
if (VA.getLocVT().isInteger() &&
19495-
(VA.getValVT() == MVT::f16 || VA.getValVT() == MVT::bf16)) {
19496-
Val = DAG.getNode(RISCVISD::FMV_H_X, DL, VA.getValVT(), Val);
19497-
} else if (VA.getLocVT() == MVT::i64 && VA.getValVT() == MVT::f32) {
19498-
Val = DAG.getNode(RISCVISD::FMV_W_X_RV64, DL, MVT::f32, Val);
19499-
} else {
19500-
Val = DAG.getNode(ISD::BITCAST, DL, VA.getValVT(), Val);
19501-
}
19516+
Val = DAG.getNode(ISD::BITCAST, DL, VA.getValVT(), Val);
1950219517
break;
1950319518
}
1950419519
return Val;
@@ -19544,6 +19559,17 @@ static SDValue convertValVTToLocVT(SelectionDAG &DAG, SDValue Val,
1954419559
const RISCVSubtarget &Subtarget) {
1954519560
EVT LocVT = VA.getLocVT();
1954619561

19562+
if (VA.needsCustom()) {
19563+
if (LocVT.isInteger() &&
19564+
(VA.getValVT() == MVT::f16 || VA.getValVT() == MVT::bf16))
19565+
Val = DAG.getNode(RISCVISD::FMV_X_ANYEXTH, DL, LocVT, Val);
19566+
else if (LocVT == MVT::i64 && VA.getValVT() == MVT::f32)
19567+
Val = DAG.getNode(RISCVISD::FMV_X_ANYEXTW_RV64, DL, MVT::i64, Val);
19568+
else
19569+
llvm_unreachable("Unexpected Custom handling.");
19570+
return Val;
19571+
}
19572+
1954719573
switch (VA.getLocInfo()) {
1954819574
default:
1954919575
llvm_unreachable("Unexpected CCValAssign::LocInfo");
@@ -19552,14 +19578,7 @@ static SDValue convertValVTToLocVT(SelectionDAG &DAG, SDValue Val,
1955219578
Val = convertToScalableVector(LocVT, Val, DAG, Subtarget);
1955319579
break;
1955419580
case CCValAssign::BCvt:
19555-
if (LocVT.isInteger() &&
19556-
(VA.getValVT() == MVT::f16 || VA.getValVT() == MVT::bf16)) {
19557-
Val = DAG.getNode(RISCVISD::FMV_X_ANYEXTH, DL, LocVT, Val);
19558-
} else if (LocVT == MVT::i64 && VA.getValVT() == MVT::f32) {
19559-
Val = DAG.getNode(RISCVISD::FMV_X_ANYEXTW_RV64, DL, MVT::i64, Val);
19560-
} else {
19561-
Val = DAG.getNode(ISD::BITCAST, DL, LocVT, Val);
19562-
}
19581+
Val = DAG.getNode(ISD::BITCAST, DL, LocVT, Val);
1956319582
break;
1956419583
}
1956519584
return Val;
@@ -19693,8 +19712,14 @@ bool RISCV::CC_RISCV_FastCC(const DataLayout &DL, RISCVABI::ABI ABI,
1969319712
(LocVT == MVT::f64 && Subtarget.is64Bit() &&
1969419713
Subtarget.hasStdExtZdinx())) {
1969519714
if (MCRegister Reg = State.AllocateReg(getFastCCArgGPRs(ABI))) {
19696-
LocInfo = CCValAssign::BCvt;
19715+
if (LocVT.getSizeInBits() != Subtarget.getXLen()) {
19716+
LocVT = Subtarget.getXLenVT();
19717+
State.addLoc(
19718+
CCValAssign::getCustomReg(ValNo, ValVT, Reg, LocVT, LocInfo));
19719+
return false;
19720+
}
1969719721
LocVT = Subtarget.getXLenVT();
19722+
LocInfo = CCValAssign::BCvt;
1969819723
State.addLoc(CCValAssign::getReg(ValNo, ValVT, Reg, LocVT, LocInfo));
1969919724
return false;
1970019725
}
@@ -20337,9 +20362,8 @@ SDValue RISCVTargetLowering::LowerCall(CallLoweringInfo &CLI,
2033720362
Glue = RetValue2.getValue(2);
2033820363
RetValue = DAG.getNode(RISCVISD::BuildPairF64, DL, MVT::f64, RetValue,
2033920364
RetValue2);
20340-
}
20341-
20342-
RetValue = convertLocVTToValVT(DAG, RetValue, VA, DL, Subtarget);
20365+
} else
20366+
RetValue = convertLocVTToValVT(DAG, RetValue, VA, DL, Subtarget);
2034320367

2034420368
InVals.push_back(RetValue);
2034520369
}

llvm/test/CodeGen/RISCV/GlobalISel/irtranslator/calling-conv-half.ll

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1018,7 +1018,6 @@ define half @caller_half_return_stack2(half %x, half %y) nounwind {
10181018
; RV64IF-NEXT: [[ANYEXT5:%[0-9]+]]:_(s32) = G_ANYEXT [[TRUNC1]](s16)
10191019
; RV64IF-NEXT: [[ANYEXT6:%[0-9]+]]:_(s32) = G_ANYEXT [[TRUNC1]](s16)
10201020
; RV64IF-NEXT: [[ANYEXT7:%[0-9]+]]:_(s32) = G_ANYEXT [[TRUNC1]](s16)
1021-
; RV64IF-NEXT: [[ANYEXT8:%[0-9]+]]:_(s32) = G_ANYEXT [[TRUNC]](s16)
10221021
; RV64IF-NEXT: $f10_f = COPY [[ANYEXT]](s32)
10231022
; RV64IF-NEXT: $f11_f = COPY [[ANYEXT1]](s32)
10241023
; RV64IF-NEXT: $f12_f = COPY [[ANYEXT2]](s32)
@@ -1027,14 +1026,14 @@ define half @caller_half_return_stack2(half %x, half %y) nounwind {
10271026
; RV64IF-NEXT: $f15_f = COPY [[ANYEXT5]](s32)
10281027
; RV64IF-NEXT: $f16_f = COPY [[ANYEXT6]](s32)
10291028
; RV64IF-NEXT: $f17_f = COPY [[ANYEXT7]](s32)
1030-
; RV64IF-NEXT: [[ANYEXT9:%[0-9]+]]:_(s64) = G_ANYEXT [[ANYEXT8]](s32)
1031-
; RV64IF-NEXT: $x10 = COPY [[ANYEXT9]](s64)
1029+
; RV64IF-NEXT: [[ANYEXT8:%[0-9]+]]:_(s64) = G_ANYEXT [[TRUNC]](s16)
1030+
; RV64IF-NEXT: $x10 = COPY [[ANYEXT8]](s64)
10321031
; RV64IF-NEXT: PseudoCALL target-flags(riscv-call) @callee_half_return_stack2, csr_ilp32f_lp64f, implicit-def $x1, implicit $f10_f, implicit $f11_f, implicit $f12_f, implicit $f13_f, implicit $f14_f, implicit $f15_f, implicit $f16_f, implicit $f17_f, implicit $x10, implicit-def $f10_f
10331032
; RV64IF-NEXT: ADJCALLSTACKUP 0, 0, implicit-def $x2, implicit $x2
10341033
; RV64IF-NEXT: [[COPY2:%[0-9]+]]:_(s32) = COPY $f10_f
10351034
; RV64IF-NEXT: [[TRUNC2:%[0-9]+]]:_(s16) = G_TRUNC [[COPY2]](s32)
1036-
; RV64IF-NEXT: [[ANYEXT10:%[0-9]+]]:_(s32) = G_ANYEXT [[TRUNC2]](s16)
1037-
; RV64IF-NEXT: $f10_f = COPY [[ANYEXT10]](s32)
1035+
; RV64IF-NEXT: [[ANYEXT9:%[0-9]+]]:_(s32) = G_ANYEXT [[TRUNC2]](s16)
1036+
; RV64IF-NEXT: $f10_f = COPY [[ANYEXT9]](s32)
10381037
; RV64IF-NEXT: PseudoRET implicit $f10_f
10391038
;
10401039
; RV64IZFH-LABEL: name: caller_half_return_stack2

0 commit comments

Comments
 (0)