Skip to content

Commit 9e85ffd

Browse files
uniform creation of vregs; address call lowering & builtins
1 parent a962e08 commit 9e85ffd

File tree

3 files changed

+52
-12
lines changed

3 files changed

+52
-12
lines changed

llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,7 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
404404
for (const auto &Arg : F.args()) {
405405
assert(VRegs[i].size() == 1 && "Formal arg has multiple vregs");
406406
MRI->setRegClass(VRegs[i][0], GR->getRegClass(ArgTypeVRegs[i]));
407+
MRI->setType(VRegs[i][0], GR->getRegType(ArgTypeVRegs[i]));
407408
MIRBuilder.buildInstr(SPIRV::OpFunctionParameter)
408409
.addDef(VRegs[i][0])
409410
.addUse(GR->getSPIRVTypeID(ArgTypeVRegs[i]));
@@ -532,10 +533,17 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
532533
SmallVector<Register, 8> ArgVRegs;
533534
for (auto Arg : Info.OrigArgs) {
534535
assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs");
535-
ArgVRegs.push_back(Arg.Regs[0]);
536-
SPIRVType *SPIRVTy = GR->getOrCreateSPIRVType(Arg.Ty, MIRBuilder);
537-
if (!GR->getSPIRVTypeForVReg(Arg.Regs[0]))
538-
GR->assignSPIRVTypeToVReg(SPIRVTy, Arg.Regs[0], MF);
536+
Register ArgReg = Arg.Regs[0];
537+
ArgVRegs.push_back(ArgReg);
538+
SPIRVType *SpvType = GR->getSPIRVTypeForVReg(ArgReg);
539+
if (!SpvType) {
540+
SpvType = GR->getOrCreateSPIRVType(Arg.Ty, MIRBuilder);
541+
GR->assignSPIRVTypeToVReg(SpvType, ArgReg, MF);
542+
}
543+
if (!MRI->getRegClassOrNull(ArgReg)) {
544+
MRI->setRegClass(ArgReg, GR->getRegClass(SpvType));
545+
MRI->setType(ArgReg, GR->getRegType(SpvType));
546+
}
539547
}
540548
auto instructionSet = canUseOpenCL ? SPIRV::InstructionSet::OpenCL_std
541549
: SPIRV::InstructionSet::GLSL_std_450;

llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1460,6 +1460,12 @@ SPIRVGlobalRegistry::getRegClass(SPIRVType *SpvType) const {
14601460
return &SPIRV::iIDRegClass;
14611461
}
14621462

1463+
inline unsigned getAS(SPIRVType *SpvType) {
1464+
return storageClassToAddressSpace(
1465+
static_cast<SPIRV::StorageClass::StorageClass>(
1466+
SpvType->getOperand(1).getImm()));
1467+
}
1468+
14631469
LLT SPIRVGlobalRegistry::getRegType(SPIRVType *SpvType) const {
14641470
unsigned Opcode = SpvType ? SpvType->getOpcode() : 0;
14651471
switch (Opcode) {
@@ -1468,13 +1474,13 @@ LLT SPIRVGlobalRegistry::getRegType(SPIRVType *SpvType) const {
14681474
case SPIRV::OpTypeBool:
14691475
return LLT::scalar(getScalarOrVectorBitWidth(SpvType));
14701476
case SPIRV::OpTypePointer:
1471-
return LLT::pointer(0, getPointerSize());
1477+
return LLT::pointer(getAS(SpvType), getPointerSize());
14721478
case SPIRV::OpTypeVector: {
14731479
SPIRVType *ElemType = getSPIRVTypeForVReg(SpvType->getOperand(1).getReg());
14741480
LLT ET;
14751481
switch (ElemType ? ElemType->getOpcode() : 0) {
14761482
case SPIRV::OpTypePointer:
1477-
ET = LLT::pointer(0, getPointerSize());
1483+
ET = LLT::pointer(getAS(ElemType), getPointerSize());
14781484
break;
14791485
case SPIRV::OpTypeInt:
14801486
case SPIRV::OpTypeFloat:
@@ -1484,7 +1490,8 @@ LLT SPIRVGlobalRegistry::getRegType(SPIRVType *SpvType) const {
14841490
default:
14851491
ET = LLT::scalar(64);
14861492
}
1487-
return LLT::fixed_vector(2, ET);
1493+
return LLT::fixed_vector(
1494+
static_cast<unsigned>(SpvType->getOperand(2).getImm()), ET);
14881495
}
14891496
}
14901497
return LLT::scalar(64);

llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,7 @@ inline bool getIsFloat(SPIRVType *SpvType, const SPIRVGlobalRegistry &GR) {
327327
->getOpcode() == SPIRV::OpTypeFloat;
328328
}
329329

330+
/*
330331
static std::pair<Register, unsigned>
331332
createNewIdReg(SPIRVType *SpvType, Register SrcReg, MachineRegisterInfo &MRI,
332333
const SPIRVGlobalRegistry &GR) {
@@ -337,16 +338,18 @@ createNewIdReg(SPIRVType *SpvType, Register SrcReg, MachineRegisterInfo &MRI,
337338
LLT SrcLLT = MRI.getType(SrcReg);
338339
bool IsFloat = getIsFloat(SpvType, GR);
339340
auto GetIdOp = IsFloat ? SPIRV::GET_fID : SPIRV::GET_ID;
341+
bool IsVec = SrcLLT.isVector();
342+
unsigned NumElements =
343+
IsVec ? std::max(2U, GR.getScalarOrVectorComponentCount(SpvType)) : 1;
340344
if (SrcLLT.isPointer()) {
341345
unsigned PtrSz = GR.getPointerSize();
342346
NewT = LLT::pointer(0, PtrSz);
343-
bool IsVec = SrcLLT.isVector();
344347
if (IsVec)
345-
NewT = LLT::fixed_vector(2, NewT);
348+
NewT = LLT::fixed_vector(NumElements, NewT);
346349
GetIdOp = IsVec ? SPIRV::GET_vpID : SPIRV::GET_pID;
347-
} else if (SrcLLT.isVector()) {
348-
NewT = LLT::scalar(GR.getScalarOrVectorBitWidth(SpvType));
349-
NewT = LLT::fixed_vector(2, NewT);
350+
} else if (IsVec) {
351+
NewT = LLT::fixed_vector(
352+
NumElements, LLT::scalar(GR.getScalarOrVectorBitWidth(SpvType)));
350353
GetIdOp = IsFloat ? SPIRV::GET_vfID : SPIRV::GET_vID;
351354
} else {
352355
NewT = LLT::scalar(GR.getScalarOrVectorBitWidth(SpvType));
@@ -355,6 +358,28 @@ createNewIdReg(SPIRVType *SpvType, Register SrcReg, MachineRegisterInfo &MRI,
355358
MRI.setRegClass(IdReg, GR.getRegClass(SpvType));
356359
return {IdReg, GetIdOp};
357360
}
361+
*/
362+
static std::pair<Register, unsigned>
363+
createNewIdReg(SPIRVType *SpvType, Register SrcReg, MachineRegisterInfo &MRI,
364+
const SPIRVGlobalRegistry &GR) {
365+
if (!SpvType)
366+
SpvType = GR.getSPIRVTypeForVReg(SrcReg);
367+
const TargetRegisterClass *RC = GR.getRegClass(SpvType);
368+
Register Reg = MRI.createGenericVirtualRegister(GR.getRegType(SpvType));
369+
MRI.setRegClass(Reg, RC);
370+
unsigned GetIdOp = SPIRV::GET_ID;
371+
if (RC == &SPIRV::fIDRegClass)
372+
GetIdOp = SPIRV::GET_fID;
373+
else if (RC == &SPIRV::pIDRegClass)
374+
GetIdOp = SPIRV::GET_pID;
375+
else if (RC == &SPIRV::vfIDRegClass)
376+
GetIdOp = SPIRV::GET_vfID;
377+
else if (RC == &SPIRV::vpIDRegClass)
378+
GetIdOp = SPIRV::GET_vpID;
379+
else if (RC == &SPIRV::vIDRegClass)
380+
GetIdOp = SPIRV::GET_vID;
381+
return {Reg, GetIdOp};
382+
}
358383

359384
// Insert ASSIGN_TYPE instuction between Reg and its definition, set NewReg as
360385
// a dst of the definition, assign SPIRVType to both registers. If SpvType is

0 commit comments

Comments
 (0)