Skip to content

Commit b5132b7

Browse files
[SPIR-V] Improve type inference: fix types of return values in call lowering (#116609)
Goals of the PR are: * to ensure that correct types are applied to virtual registers which were used as return values in call lowering. A reproducer is attached as a new test case, before the PR it fails because spirv-val considers output invalid due to wrong result/operand types in OpPhi's; * improve type inference by speeding up postprocessing of types: by limiting iterations by checking what remains to process, and processing each instruction just once for any number of operands with uncomplete types; * improve type inference by more accurate work with uncomplete types (pass uncomplete property to dependent operands, ensure consistency of uncomplete-types data structure); * change processing order and add traversing of PHI nodes when type inference apply instructions results to specify/update/cast operands type (fixes an issue with OpPhi's result type mismatch with operand types).
1 parent 820403c commit b5132b7

24 files changed

+1968
-303
lines changed

llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp

Lines changed: 6 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -469,12 +469,8 @@ static Register buildLoadInst(SPIRVType *BaseType, Register PtrRegister,
469469
MachineIRBuilder &MIRBuilder,
470470
SPIRVGlobalRegistry *GR, LLT LowLevelType,
471471
Register DestinationReg = Register(0)) {
472-
MachineRegisterInfo *MRI = MIRBuilder.getMRI();
473-
if (!DestinationReg.isValid()) {
474-
DestinationReg = MRI->createVirtualRegister(&SPIRV::iIDRegClass);
475-
MRI->setType(DestinationReg, LLT::scalar(64));
476-
GR->assignSPIRVTypeToVReg(BaseType, DestinationReg, MIRBuilder.getMF());
477-
}
472+
if (!DestinationReg.isValid())
473+
DestinationReg = createVirtualRegister(BaseType, GR, MIRBuilder);
478474
// TODO: consider using correct address space and alignment (p0 is canonical
479475
// type for selection though).
480476
MachinePointerInfo PtrInfo = MachinePointerInfo();
@@ -2151,7 +2147,7 @@ static bool buildEnqueueKernel(const SPIRV::IncomingCall *Call,
21512147
const SPIRVType *PointerSizeTy = GR->getOrCreateSPIRVPointerType(
21522148
Int32Ty, MIRBuilder, SPIRV::StorageClass::Function);
21532149
for (unsigned I = 0; I < LocalSizeNum; ++I) {
2154-
Register Reg = MRI->createVirtualRegister(&SPIRV::iIDRegClass);
2150+
Register Reg = MRI->createVirtualRegister(&SPIRV::pIDRegClass);
21552151
MRI->setType(Reg, LLType);
21562152
GR->assignSPIRVTypeToVReg(PointerSizeTy, Reg, MIRBuilder.getMF());
21572153
auto GEPInst = MIRBuilder.buildIntrinsic(
@@ -2539,23 +2535,11 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
25392535
SPIRVGlobalRegistry *GR) {
25402536
LLVM_DEBUG(dbgs() << "Lowering builtin call: " << DemangledCall << "\n");
25412537

2542-
// SPIR-V type and return register.
2543-
Register ReturnRegister = OrigRet;
2544-
SPIRVType *ReturnType = nullptr;
2545-
if (OrigRetTy && !OrigRetTy->isVoidTy()) {
2546-
ReturnType = GR->assignTypeToVReg(OrigRetTy, ReturnRegister, MIRBuilder);
2547-
if (!MIRBuilder.getMRI()->getRegClassOrNull(ReturnRegister))
2548-
MIRBuilder.getMRI()->setRegClass(ReturnRegister,
2549-
GR->getRegClass(ReturnType));
2550-
} else if (OrigRetTy && OrigRetTy->isVoidTy()) {
2551-
ReturnRegister = MIRBuilder.getMRI()->createVirtualRegister(&IDRegClass);
2552-
MIRBuilder.getMRI()->setType(ReturnRegister, LLT::scalar(64));
2553-
ReturnType = GR->assignTypeToVReg(OrigRetTy, ReturnRegister, MIRBuilder);
2554-
}
2555-
25562538
// Lookup the builtin in the TableGen records.
2539+
SPIRVType *SpvType = GR->getSPIRVTypeForVReg(OrigRet);
2540+
assert(SpvType && "Inconsistent return register: expected valid type info");
25572541
std::unique_ptr<const IncomingCall> Call =
2558-
lookupBuiltin(DemangledCall, Set, ReturnRegister, ReturnType, Args);
2542+
lookupBuiltin(DemangledCall, Set, OrigRet, SpvType, Args);
25592543

25602544
if (!Call) {
25612545
LLVM_DEBUG(dbgs() << "Builtin record was not found!\n");

llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -539,6 +539,23 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
539539

540540
if (isFunctionDecl && !DemangledName.empty() &&
541541
(canUseGLSL || canUseOpenCL)) {
542+
if (ResVReg.isValid()) {
543+
if (!GR->getSPIRVTypeForVReg(ResVReg)) {
544+
const Type *RetTy = OrigRetTy;
545+
if (auto *PtrRetTy = dyn_cast<PointerType>(OrigRetTy)) {
546+
const Value *OrigValue = Info.OrigRet.OrigValue;
547+
if (!OrigValue)
548+
OrigValue = Info.CB;
549+
if (OrigValue)
550+
if (Type *ElemTy = GR->findDeducedElementType(OrigValue))
551+
RetTy =
552+
TypedPointerType::get(ElemTy, PtrRetTy->getAddressSpace());
553+
}
554+
setRegClassType(ResVReg, RetTy, GR, MIRBuilder);
555+
}
556+
} else {
557+
ResVReg = createVirtualRegister(OrigRetTy, GR, MIRBuilder);
558+
}
542559
SmallVector<Register, 8> ArgVRegs;
543560
for (auto Arg : Info.OrigArgs) {
544561
assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs");

llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.cpp

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -69,16 +69,31 @@ void SPIRVGeneralDuplicatesTracker::buildDepsGraph(
6969
MachineOperand *RegOp = &VRegDef->getOperand(0);
7070
if (Reg2Entry.count(RegOp) == 0 &&
7171
(MI->getOpcode() != SPIRV::OpVariable || i != 3)) {
72-
std::string DiagMsg;
73-
raw_string_ostream OS(DiagMsg);
74-
OS << "Unexpected pattern while building a dependency "
75-
"graph.\nInstruction: ";
76-
MI->print(OS);
77-
OS << "Operand: ";
78-
Op.print(OS);
79-
OS << "\nOperand definition: ";
80-
VRegDef->print(OS);
81-
report_fatal_error(DiagMsg.c_str());
72+
// try to repair the unexpected code pattern
73+
bool IsFixed = false;
74+
if (VRegDef->getOpcode() == TargetOpcode::G_CONSTANT &&
75+
RegOp->isReg() && MRI.getType(RegOp->getReg()).isScalar()) {
76+
const Constant *C = VRegDef->getOperand(1).getCImm();
77+
add(C, MI->getParent()->getParent(), RegOp->getReg());
78+
auto Iter = CT.Storage.find(C);
79+
if (Iter != CT.Storage.end()) {
80+
SPIRV::DTSortableEntry &MissedEntry = Iter->second;
81+
Reg2Entry[RegOp] = &MissedEntry;
82+
IsFixed = true;
83+
}
84+
}
85+
if (!IsFixed) {
86+
std::string DiagMsg;
87+
raw_string_ostream OS(DiagMsg);
88+
OS << "Unexpected pattern while building a dependency "
89+
"graph.\nInstruction: ";
90+
MI->print(OS);
91+
OS << "Operand: ";
92+
Op.print(OS);
93+
OS << "\nOperand definition: ";
94+
VRegDef->print(OS);
95+
report_fatal_error(DiagMsg.c_str());
96+
}
8297
}
8398
if (Reg2Entry.count(RegOp))
8499
E->addDep(Reg2Entry[RegOp]);

0 commit comments

Comments
 (0)