Skip to content

Commit f450408

Browse files
spaitsGabor Spaits
andauthored
[GISel][RISCV]Implement indirect parameter passing (llvm#95429)
Some targets like RISC-V pass scalars wider than 2×XLEN bits by reference, so those arguments are replaced in the argument list with an address (See RISC-V ABIs Specification 1.0 section 2.1). This commit implements this indirect parameter passing in GlobalISel. --------- Co-authored-by: Gabor Spaits <[email protected]>
1 parent 5b00758 commit f450408

File tree

4 files changed

+898
-32
lines changed

4 files changed

+898
-32
lines changed

llvm/lib/CodeGen/GlobalISel/CallLowering.cpp

Lines changed: 101 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ void CallLowering::anchor() {}
3636
static void
3737
addFlagsUsingAttrFn(ISD::ArgFlagsTy &Flags,
3838
const std::function<bool(Attribute::AttrKind)> &AttrFn) {
39+
// TODO: There are missing flags. Add them here.
3940
if (AttrFn(Attribute::SExt))
4041
Flags.setSExt();
4142
if (AttrFn(Attribute::ZExt))
@@ -756,6 +757,8 @@ bool CallLowering::handleAssignments(ValueHandler &Handler,
756757
continue;
757758
}
758759

760+
auto AllocaAddressSpace = MF.getDataLayout().getAllocaAddrSpace();
761+
759762
const MVT ValVT = VA.getValVT();
760763
const MVT LocVT = VA.getLocVT();
761764

@@ -764,6 +767,8 @@ bool CallLowering::handleAssignments(ValueHandler &Handler,
764767
const LLT NewLLT = Handler.isIncomingArgumentHandler() ? LocTy : ValTy;
765768
const EVT OrigVT = EVT::getEVT(Args[i].Ty);
766769
const LLT OrigTy = getLLTForType(*Args[i].Ty, DL);
770+
const LLT PointerTy = LLT::pointer(
771+
AllocaAddressSpace, DL.getPointerSizeInBits(AllocaAddressSpace));
767772

768773
// Expected to be multiple regs for a single incoming arg.
769774
// There should be Regs.size() ArgLocs per argument.
@@ -778,31 +783,76 @@ bool CallLowering::handleAssignments(ValueHandler &Handler,
778783
// intermediate values.
779784
Args[i].Regs.resize(NumParts);
780785

781-
// For each split register, create and assign a vreg that will store
782-
// the incoming component of the larger value. These will later be
783-
// merged to form the final vreg.
784-
for (unsigned Part = 0; Part < NumParts; ++Part)
785-
Args[i].Regs[Part] = MRI.createGenericVirtualRegister(NewLLT);
786+
// When we have indirect parameter passing we are receiving a pointer,
787+
// that points to the actual value, so we need one "temporary" pointer.
788+
if (VA.getLocInfo() == CCValAssign::Indirect) {
789+
if (Handler.isIncomingArgumentHandler())
790+
Args[i].Regs[0] = MRI.createGenericVirtualRegister(PointerTy);
791+
} else {
792+
// For each split register, create and assign a vreg that will store
793+
// the incoming component of the larger value. These will later be
794+
// merged to form the final vreg.
795+
for (unsigned Part = 0; Part < NumParts; ++Part)
796+
Args[i].Regs[Part] = MRI.createGenericVirtualRegister(NewLLT);
797+
}
786798
}
787799

788800
assert((j + (NumParts - 1)) < ArgLocs.size() &&
789801
"Too many regs for number of args");
790802

791803
// Coerce into outgoing value types before register assignment.
792-
if (!Handler.isIncomingArgumentHandler() && OrigTy != ValTy) {
804+
if (!Handler.isIncomingArgumentHandler() && OrigTy != ValTy &&
805+
VA.getLocInfo() != CCValAssign::Indirect) {
793806
assert(Args[i].OrigRegs.size() == 1);
794807
buildCopyToRegs(MIRBuilder, Args[i].Regs, Args[i].OrigRegs[0], OrigTy,
795808
ValTy, extendOpFromFlags(Args[i].Flags[0]));
796809
}
797810

811+
bool IndirectParameterPassingHandled = false;
798812
bool BigEndianPartOrdering = TLI->hasBigEndianPartOrdering(OrigVT, DL);
799813
for (unsigned Part = 0; Part < NumParts; ++Part) {
814+
assert((VA.getLocInfo() != CCValAssign::Indirect || Part == 0) &&
815+
"Only the first parameter should be processed when "
816+
"handling indirect passing!");
800817
Register ArgReg = Args[i].Regs[Part];
801818
// There should be Regs.size() ArgLocs per argument.
802819
unsigned Idx = BigEndianPartOrdering ? NumParts - 1 - Part : Part;
803820
CCValAssign &VA = ArgLocs[j + Idx];
804821
const ISD::ArgFlagsTy Flags = Args[i].Flags[Part];
805822

823+
// We found an indirect parameter passing, and we have an
824+
// OutgoingValueHandler as our handler (so we are at the call site or the
825+
// return value). In this case, start the construction of the following
826+
// GMIR, that is responsible for the preparation of indirect parameter
827+
// passing:
828+
//
829+
// %1(indirectly passed type) = The value to pass
830+
// %3(pointer) = G_FRAME_INDEX %stack.0
831+
// G_STORE %1, %3 :: (store (s128), align 8)
832+
//
833+
// After this GMIR, the remaining part of the loop body will decide how
834+
// to get the value to the caller and we break out of the loop.
835+
if (VA.getLocInfo() == CCValAssign::Indirect &&
836+
!Handler.isIncomingArgumentHandler()) {
837+
Align AlignmentForStored = DL.getPrefTypeAlign(Args[i].Ty);
838+
MachineFrameInfo &MFI = MF.getFrameInfo();
839+
// Get some space on the stack for the value, so later we can pass it
840+
// as a reference.
841+
int FrameIdx = MFI.CreateStackObject(OrigTy.getScalarSizeInBits(),
842+
AlignmentForStored, false);
843+
Register PointerToStackReg =
844+
MIRBuilder.buildFrameIndex(PointerTy, FrameIdx).getReg(0);
845+
MachinePointerInfo StackPointerMPO =
846+
MachinePointerInfo::getFixedStack(MF, FrameIdx);
847+
// Store the value in the previously created stack space.
848+
MIRBuilder.buildStore(Args[i].OrigRegs[Part], PointerToStackReg,
849+
StackPointerMPO,
850+
inferAlignFromPtrInfo(MF, StackPointerMPO));
851+
852+
ArgReg = PointerToStackReg;
853+
IndirectParameterPassingHandled = true;
854+
}
855+
806856
if (VA.isMemLoc() && !Flags.isByVal()) {
807857
// Individual pieces may have been spilled to the stack and others
808858
// passed in registers.
@@ -812,14 +862,21 @@ bool CallLowering::handleAssignments(ValueHandler &Handler,
812862
LLT MemTy = Handler.getStackValueStoreType(DL, VA, Flags);
813863

814864
MachinePointerInfo MPO;
815-
Register StackAddr = Handler.getStackAddress(
816-
MemTy.getSizeInBytes(), VA.getLocMemOffset(), MPO, Flags);
817-
818-
Handler.assignValueToAddress(Args[i], Part, StackAddr, MemTy, MPO, VA);
819-
continue;
820-
}
821-
822-
if (VA.isMemLoc() && Flags.isByVal()) {
865+
Register StackAddr =
866+
Handler.getStackAddress(VA.getLocInfo() == CCValAssign::Indirect
867+
? PointerTy.getSizeInBytes()
868+
: MemTy.getSizeInBytes(),
869+
VA.getLocMemOffset(), MPO, Flags);
870+
871+
// Finish the handling of indirect passing from the passers
872+
// (OutgoingParameterHandler) side.
873+
// This branch is needed, so the pointer to the value is loaded onto the
874+
// stack.
875+
if (VA.getLocInfo() == CCValAssign::Indirect)
876+
Handler.assignValueToAddress(ArgReg, StackAddr, PointerTy, MPO, VA);
877+
else
878+
Handler.assignValueToAddress(Args[i], Part, StackAddr, MemTy, MPO, VA);
879+
} else if (VA.isMemLoc() && Flags.isByVal()) {
823880
assert(Args[i].Regs.size() == 1 &&
824881
"didn't expect split byval pointer");
825882

@@ -858,30 +915,45 @@ bool CallLowering::handleAssignments(ValueHandler &Handler,
858915
DstMPO, DstAlign, SrcMPO, SrcAlign,
859916
MemSize, VA);
860917
}
861-
continue;
862-
}
863-
864-
assert(!VA.needsCustom() && "custom loc should have been handled already");
865-
866-
if (i == 0 && !ThisReturnRegs.empty() &&
867-
Handler.isIncomingArgumentHandler() &&
868-
isTypeIsValidForThisReturn(ValVT)) {
918+
} else if (i == 0 && !ThisReturnRegs.empty() &&
919+
Handler.isIncomingArgumentHandler() &&
920+
isTypeIsValidForThisReturn(ValVT)) {
869921
Handler.assignValueToReg(ArgReg, ThisReturnRegs[Part], VA);
870-
continue;
871-
}
872-
873-
if (Handler.isIncomingArgumentHandler())
922+
} else if (Handler.isIncomingArgumentHandler()) {
874923
Handler.assignValueToReg(ArgReg, VA.getLocReg(), VA);
875-
else {
924+
} else {
876925
DelayedOutgoingRegAssignments.emplace_back([=, &Handler]() {
877926
Handler.assignValueToReg(ArgReg, VA.getLocReg(), VA);
878927
});
879928
}
929+
930+
// Finish the handling of indirect parameter passing when receiving
931+
// the value (we are in the called function or the caller when receiving
932+
// the return value).
933+
if (VA.getLocInfo() == CCValAssign::Indirect &&
934+
Handler.isIncomingArgumentHandler()) {
935+
Align Alignment = DL.getABITypeAlign(Args[i].Ty);
936+
MachinePointerInfo MPO = MachinePointerInfo::getUnknownStack(MF);
937+
938+
// Since we are doing indirect parameter passing, we know that the value
939+
// in the temporary register is not the value passed to the function,
940+
// but rather a pointer to that value. Let's load that value into the
941+
// virtual register where the parameter should go.
942+
MIRBuilder.buildLoad(Args[i].OrigRegs[0], Args[i].Regs[0], MPO,
943+
Alignment);
944+
945+
IndirectParameterPassingHandled = true;
946+
}
947+
948+
if (IndirectParameterPassingHandled)
949+
break;
880950
}
881951

882952
// Now that all pieces have been assigned, re-pack the register typed values
883-
// into the original value typed registers.
884-
if (Handler.isIncomingArgumentHandler() && OrigVT != LocVT) {
953+
// into the original value typed registers. This is only necessary, when
954+
// the value was passed in multiple registers, not indirectly.
955+
if (Handler.isIncomingArgumentHandler() && OrigVT != LocVT &&
956+
!IndirectParameterPassingHandled) {
885957
// Merge the split registers into the expected larger result vregs of
886958
// the original call.
887959
buildCopyFromRegs(MIRBuilder, Args[i].OrigRegs, Args[i].Regs, OrigTy,

llvm/lib/Target/RISCV/GISel/RISCVCallLowering.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -341,10 +341,8 @@ static bool isLegalElementTypeForRVV(Type *EltTy,
341341
// TODO: Remove IsLowerArgs argument by adding support for vectors in lowerCall.
342342
static bool isSupportedArgumentType(Type *T, const RISCVSubtarget &Subtarget,
343343
bool IsLowerArgs = false) {
344-
// TODO: Integers larger than 2*XLen are passed indirectly which is not
345-
// supported yet.
346344
if (T->isIntegerTy())
347-
return T->getIntegerBitWidth() <= Subtarget.getXLen() * 2;
345+
return true;
348346
if (T->isHalfTy() || T->isFloatTy() || T->isDoubleTy())
349347
return true;
350348
if (T->isPointerTy())

0 commit comments

Comments
 (0)