Skip to content

[SPIR-V] Rework usage of virtual registers' types and classes #104104

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVMCCodeEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ static bool hasType(const MCInst &MI, const MCInstrInfo &MII) {
// Check if we define an ID, and take a type as operand 1.
auto &DefOpInfo = MCDesc.operands()[0];
auto &FirstArgOpInfo = MCDesc.operands()[1];
return DefOpInfo.RegClass != SPIRV::TYPERegClassID &&
return DefOpInfo.RegClass >= 0 && FirstArgOpInfo.RegClass >= 0 &&
DefOpInfo.RegClass != SPIRV::TYPERegClassID &&
FirstArgOpInfo.RegClass == SPIRV::TYPERegClassID;
}
return false;
Expand Down
254 changes: 82 additions & 172 deletions llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp

Large diffs are not rendered by default.

27 changes: 18 additions & 9 deletions llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
}

auto MRI = MIRBuilder.getMRI();
Register FuncVReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
Register FuncVReg = MRI->createGenericVirtualRegister(LLT::scalar(64));
MRI->setRegClass(FuncVReg, &SPIRV::iIDRegClass);
if (F.isDeclaration())
GR->add(&F, &MIRBuilder.getMF(), FuncVReg);
Expand Down Expand Up @@ -403,12 +403,14 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
int i = 0;
for (const auto &Arg : F.args()) {
assert(VRegs[i].size() == 1 && "Formal arg has multiple vregs");
MRI->setRegClass(VRegs[i][0], &SPIRV::iIDRegClass);
Register ArgReg = VRegs[i][0];
MRI->setRegClass(ArgReg, GR->getRegClass(ArgTypeVRegs[i]));
MRI->setType(ArgReg, GR->getRegType(ArgTypeVRegs[i]));
MIRBuilder.buildInstr(SPIRV::OpFunctionParameter)
.addDef(VRegs[i][0])
.addDef(ArgReg)
.addUse(GR->getSPIRVTypeID(ArgTypeVRegs[i]));
if (F.isDeclaration())
GR->add(&Arg, &MIRBuilder.getMF(), VRegs[i][0]);
GR->add(&Arg, &MIRBuilder.getMF(), ArgReg);
i++;
}
// Name the function.
Expand Down Expand Up @@ -532,10 +534,17 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
SmallVector<Register, 8> ArgVRegs;
for (auto Arg : Info.OrigArgs) {
assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs");
ArgVRegs.push_back(Arg.Regs[0]);
SPIRVType *SPIRVTy = GR->getOrCreateSPIRVType(Arg.Ty, MIRBuilder);
if (!GR->getSPIRVTypeForVReg(Arg.Regs[0]))
GR->assignSPIRVTypeToVReg(SPIRVTy, Arg.Regs[0], MF);
Register ArgReg = Arg.Regs[0];
ArgVRegs.push_back(ArgReg);
SPIRVType *SpvType = GR->getSPIRVTypeForVReg(ArgReg);
if (!SpvType) {
SpvType = GR->getOrCreateSPIRVType(Arg.Ty, MIRBuilder);
GR->assignSPIRVTypeToVReg(SpvType, ArgReg, MF);
}
if (!MRI->getRegClassOrNull(ArgReg)) {
MRI->setRegClass(ArgReg, GR->getRegClass(SpvType));
MRI->setType(ArgReg, GR->getRegType(SpvType));
}
}
auto instructionSet = canUseOpenCL ? SPIRV::InstructionSet::OpenCL_std
: SPIRV::InstructionSet::GLSL_std_450;
Expand All @@ -557,7 +566,7 @@ bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
for (const Argument &Arg : CF->args()) {
if (MIRBuilder.getDataLayout().getTypeStoreSize(Arg.getType()).isZero())
continue; // Don't handle zero sized types.
Register Reg = MRI->createGenericVirtualRegister(LLT::scalar(32));
Register Reg = MRI->createGenericVirtualRegister(LLT::scalar(64));
MRI->setRegClass(Reg, &SPIRV::iIDRegClass);
ToInsert.push_back({Reg});
VRegArgs.push_back(ToInsert.back());
Expand Down
155 changes: 100 additions & 55 deletions llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,17 +72,14 @@ void SPIRVGlobalRegistry::assignSPIRVTypeToVReg(SPIRVType *SpirvType,
VRegToTypeMap[&MF][VReg] = SpirvType;
}

static Register createTypeVReg(MachineIRBuilder &MIRBuilder) {
auto &MRI = MIRBuilder.getMF().getRegInfo();
auto Res = MRI.createGenericVirtualRegister(LLT::scalar(32));
static Register createTypeVReg(MachineRegisterInfo &MRI) {
auto Res = MRI.createGenericVirtualRegister(LLT::scalar(64));
MRI.setRegClass(Res, &SPIRV::TYPERegClass);
return Res;
}

static Register createTypeVReg(MachineRegisterInfo &MRI) {
auto Res = MRI.createGenericVirtualRegister(LLT::scalar(32));
MRI.setRegClass(Res, &SPIRV::TYPERegClass);
return Res;
inline Register createTypeVReg(MachineIRBuilder &MIRBuilder) {
return createTypeVReg(MIRBuilder.getMF().getRegInfo());
}

SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) {
Expand Down Expand Up @@ -157,26 +154,24 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeVector(uint32_t NumElems,
return MIB;
}

std::tuple<Register, ConstantInt *, bool>
std::tuple<Register, ConstantInt *, bool, unsigned>
SPIRVGlobalRegistry::getOrCreateConstIntReg(uint64_t Val, SPIRVType *SpvType,
MachineIRBuilder *MIRBuilder,
MachineInstr *I,
const SPIRVInstrInfo *TII) {
const IntegerType *LLVMIntTy;
if (SpvType)
LLVMIntTy = cast<IntegerType>(getTypeForSPIRVType(SpvType));
else
LLVMIntTy = IntegerType::getInt32Ty(CurMF->getFunction().getContext());
assert(SpvType);
const IntegerType *LLVMIntTy =
cast<IntegerType>(getTypeForSPIRVType(SpvType));
unsigned BitWidth = getScalarOrVectorBitWidth(SpvType);
bool NewInstr = false;
// Find a constant in DT or build a new one.
ConstantInt *CI = ConstantInt::get(const_cast<IntegerType *>(LLVMIntTy), Val);
Register Res = DT.find(CI, CurMF);
if (!Res.isValid()) {
unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;
// TODO: handle cases where the type is not 32bit wide
// TODO: https://github.com/llvm/llvm-project/issues/88129
LLT LLTy = LLT::scalar(32);
Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
Res =
CurMF->getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth));
CurMF->getRegInfo().setRegClass(Res, &SPIRV::iIDRegClass);
if (MIRBuilder)
assignTypeToVReg(LLVMIntTy, Res, *MIRBuilder);
Expand All @@ -185,35 +180,27 @@ SPIRVGlobalRegistry::getOrCreateConstIntReg(uint64_t Val, SPIRVType *SpvType,
DT.add(CI, CurMF, Res);
NewInstr = true;
}
return std::make_tuple(Res, CI, NewInstr);
return std::make_tuple(Res, CI, NewInstr, BitWidth);
}

std::tuple<Register, ConstantFP *, bool, unsigned>
SPIRVGlobalRegistry::getOrCreateConstFloatReg(APFloat Val, SPIRVType *SpvType,
MachineIRBuilder *MIRBuilder,
MachineInstr *I,
const SPIRVInstrInfo *TII) {
const Type *LLVMFloatTy;
assert(SpvType);
LLVMContext &Ctx = CurMF->getFunction().getContext();
unsigned BitWidth = 32;
if (SpvType)
LLVMFloatTy = getTypeForSPIRVType(SpvType);
else {
LLVMFloatTy = Type::getFloatTy(Ctx);
if (MIRBuilder)
SpvType = getOrCreateSPIRVType(LLVMFloatTy, *MIRBuilder);
}
const Type *LLVMFloatTy = getTypeForSPIRVType(SpvType);
unsigned BitWidth = getScalarOrVectorBitWidth(SpvType);
bool NewInstr = false;
// Find a constant in DT or build a new one.
auto *const CI = ConstantFP::get(Ctx, Val);
Register Res = DT.find(CI, CurMF);
if (!Res.isValid()) {
if (SpvType)
BitWidth = getScalarOrVectorBitWidth(SpvType);
// TODO: handle cases where the type is not 32bit wide
// TODO: https://github.com/llvm/llvm-project/issues/88129
LLT LLTy = LLT::scalar(32);
Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
Res =
CurMF->getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth));
CurMF->getRegInfo().setRegClass(Res, &SPIRV::fIDRegClass);
if (MIRBuilder)
assignTypeToVReg(LLVMFloatTy, Res, *MIRBuilder);
Expand Down Expand Up @@ -269,7 +256,8 @@ Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I,
ConstantInt *CI;
Register Res;
bool New;
std::tie(Res, CI, New) =
unsigned BitWidth;
std::tie(Res, CI, New, BitWidth) =
getOrCreateConstIntReg(Val, SpvType, nullptr, &I, &TII);
// If we have found Res register which is defined by the passed G_CONSTANT
// machine instruction, a new constant instruction should be created.
Expand All @@ -281,7 +269,7 @@ Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I,
MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantI))
.addDef(Res)
.addUse(getSPIRVTypeID(SpvType));
addNumImm(APInt(getScalarOrVectorBitWidth(SpvType), Val), MIB);
addNumImm(APInt(BitWidth, Val), MIB);
} else {
MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull))
.addDef(Res)
Expand All @@ -297,19 +285,17 @@ Register SPIRVGlobalRegistry::buildConstantInt(uint64_t Val,
MachineIRBuilder &MIRBuilder,
SPIRVType *SpvType,
bool EmitIR) {
assert(SpvType);
auto &MF = MIRBuilder.getMF();
const IntegerType *LLVMIntTy;
if (SpvType)
LLVMIntTy = cast<IntegerType>(getTypeForSPIRVType(SpvType));
else
LLVMIntTy = IntegerType::getInt32Ty(MF.getFunction().getContext());
const IntegerType *LLVMIntTy =
cast<IntegerType>(getTypeForSPIRVType(SpvType));
// Find a constant in DT or build a new one.
const auto ConstInt =
ConstantInt::get(const_cast<IntegerType *>(LLVMIntTy), Val);
Register Res = DT.find(ConstInt, &MF);
if (!Res.isValid()) {
unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;
LLT LLTy = LLT::scalar(EmitIR ? BitWidth : 32);
unsigned BitWidth = getScalarOrVectorBitWidth(SpvType);
LLT LLTy = LLT::scalar(BitWidth);
Res = MF.getRegInfo().createGenericVirtualRegister(LLTy);
MF.getRegInfo().setRegClass(Res, &SPIRV::iIDRegClass);
assignTypeToVReg(LLVMIntTy, Res, MIRBuilder,
Expand All @@ -318,18 +304,17 @@ Register SPIRVGlobalRegistry::buildConstantInt(uint64_t Val,
if (EmitIR) {
MIRBuilder.buildConstant(Res, *ConstInt);
} else {
if (!SpvType)
SpvType = getOrCreateSPIRVIntegerType(BitWidth, MIRBuilder);
Register SpvTypeReg = getSPIRVTypeID(SpvType);
MachineInstrBuilder MIB;
if (Val) {
MIB = MIRBuilder.buildInstr(SPIRV::OpConstantI)
.addDef(Res)
.addUse(getSPIRVTypeID(SpvType));
.addUse(SpvTypeReg);
addNumImm(APInt(BitWidth, Val), MIB);
} else {
MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull)
.addDef(Res)
.addUse(getSPIRVTypeID(SpvType));
.addUse(SpvTypeReg);
}
const auto &Subtarget = CurMF->getSubtarget();
constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(),
Expand All @@ -353,7 +338,8 @@ Register SPIRVGlobalRegistry::buildConstantFP(APFloat Val,
const auto ConstFP = ConstantFP::get(Ctx, Val);
Register Res = DT.find(ConstFP, &MF);
if (!Res.isValid()) {
Res = MF.getRegInfo().createGenericVirtualRegister(LLT::scalar(32));
Res = MF.getRegInfo().createGenericVirtualRegister(
LLT::scalar(getScalarOrVectorBitWidth(SpvType)));
MF.getRegInfo().setRegClass(Res, &SPIRV::fIDRegClass);
assignSPIRVTypeToVReg(SpvType, Res, MF);
DT.add(ConstFP, &MF, Res);
Expand Down Expand Up @@ -407,7 +393,7 @@ Register SPIRVGlobalRegistry::getOrCreateCompositeOrNull(

// TODO: handle cases where the type is not 32bit wide
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

todo doesn't seem to make sense anymore.

// TODO: https://github.com/llvm/llvm-project/issues/88129
LLT LLTy = LLT::scalar(32);
LLT LLTy = LLT::scalar(64);
Register SpvVecConst =
CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
CurMF->getRegInfo().setRegClass(SpvVecConst, &SPIRV::iIDRegClass);
Expand Down Expand Up @@ -509,7 +495,7 @@ Register SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull(
getOrCreateSPIRVIntegerType(BitWidth, MIRBuilder);
SpvScalConst = buildConstantInt(Val, MIRBuilder, SpvBaseType, EmitIR);
}
LLT LLTy = EmitIR ? LLT::fixed_vector(ElemCnt, BitWidth) : LLT::scalar(32);
LLT LLTy = EmitIR ? LLT::fixed_vector(ElemCnt, BitWidth) : LLT::scalar(64);
Register SpvVecConst =
CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
CurMF->getRegInfo().setRegClass(SpvVecConst, &SPIRV::iIDRegClass);
Expand Down Expand Up @@ -650,7 +636,6 @@ Register SPIRVGlobalRegistry::buildGlobalVariable(

// Set to Reg the same type as ResVReg has.
auto MRI = MIRBuilder.getMRI();
assert(MRI->getType(ResVReg).isPointer() && "Pointer type is expected");
if (Reg != ResVReg) {
LLT RegLLTy =
LLT::pointer(MRI->getType(ResVReg).getAddressSpace(), getPointerSize());
Expand Down Expand Up @@ -706,8 +691,9 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeArray(uint32_t NumElems,
bool EmitIR) {
assert((ElemType->getOpcode() != SPIRV::OpTypeVoid) &&
"Invalid array element type");
SPIRVType *SpvTypeInt32 = getOrCreateSPIRVIntegerType(32, MIRBuilder);
Register NumElementsVReg =
buildConstantInt(NumElems, MIRBuilder, nullptr, EmitIR);
buildConstantInt(NumElems, MIRBuilder, SpvTypeInt32, EmitIR);
auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeArray)
.addDef(createTypeVReg(MIRBuilder))
.addUse(getSPIRVTypeID(ElemType))
Expand Down Expand Up @@ -1188,14 +1174,15 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeCoopMatr(
if (ResVReg.isValid())
return MIRBuilder.getMF().getRegInfo().getUniqueVRegDef(ResVReg);
ResVReg = createTypeVReg(MIRBuilder);
SPIRVType *SpvTypeInt32 = getOrCreateSPIRVIntegerType(32, MIRBuilder);
SPIRVType *SpirvTy =
MIRBuilder.buildInstr(SPIRV::OpTypeCooperativeMatrixKHR)
.addDef(ResVReg)
.addUse(getSPIRVTypeID(ElemType))
.addUse(buildConstantInt(Scope, MIRBuilder, nullptr, true))
.addUse(buildConstantInt(Rows, MIRBuilder, nullptr, true))
.addUse(buildConstantInt(Columns, MIRBuilder, nullptr, true))
.addUse(buildConstantInt(Use, MIRBuilder, nullptr, true));
.addUse(buildConstantInt(Scope, MIRBuilder, SpvTypeInt32, true))
.addUse(buildConstantInt(Rows, MIRBuilder, SpvTypeInt32, true))
.addUse(buildConstantInt(Columns, MIRBuilder, SpvTypeInt32, true))
.addUse(buildConstantInt(Use, MIRBuilder, SpvTypeInt32, true));
DT.add(ExtensionType, &MIRBuilder.getMF(), ResVReg);
return SpirvTy;
}
Expand Down Expand Up @@ -1386,8 +1373,8 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVArrayType(
if (Reg.isValid())
return getSPIRVTypeForVReg(Reg);
MachineBasicBlock &BB = *I.getParent();
SPIRVType *SpirvType = getOrCreateSPIRVIntegerType(32, I, TII);
Register Len = getOrCreateConstInt(NumElements, I, SpirvType, TII);
SPIRVType *SpvTypeInt32 = getOrCreateSPIRVIntegerType(32, I, TII);
Register Len = getOrCreateConstInt(NumElements, I, SpvTypeInt32, TII);
auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeArray))
.addDef(createTypeVReg(CurMF->getRegInfo()))
.addUse(getSPIRVTypeID(BaseType))
Expand Down Expand Up @@ -1436,7 +1423,7 @@ Register SPIRVGlobalRegistry::getOrCreateUndef(MachineInstr &I,
Register Res = DT.find(UV, CurMF);
if (Res.isValid())
return Res;
LLT LLTy = LLT::scalar(32);
LLT LLTy = LLT::scalar(64);
Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
CurMF->getRegInfo().setRegClass(Res, &SPIRV::iIDRegClass);
assignSPIRVTypeToVReg(SpvType, Res, *CurMF);
Expand All @@ -1451,3 +1438,61 @@ Register SPIRVGlobalRegistry::getOrCreateUndef(MachineInstr &I,
*ST.getRegisterInfo(), *ST.getRegBankInfo());
return Res;
}

const TargetRegisterClass *
SPIRVGlobalRegistry::getRegClass(SPIRVType *SpvType) const {
unsigned Opcode = SpvType->getOpcode();
switch (Opcode) {
case SPIRV::OpTypeFloat:
return &SPIRV::fIDRegClass;
case SPIRV::OpTypePointer:
return &SPIRV::pIDRegClass;
case SPIRV::OpTypeVector: {
SPIRVType *ElemType = getSPIRVTypeForVReg(SpvType->getOperand(1).getReg());
unsigned ElemOpcode = ElemType ? ElemType->getOpcode() : 0;
if (ElemOpcode == SPIRV::OpTypeFloat)
return &SPIRV::vfIDRegClass;
if (ElemOpcode == SPIRV::OpTypePointer)
return &SPIRV::vpIDRegClass;
return &SPIRV::vIDRegClass;
}
}
return &SPIRV::iIDRegClass;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe add an assert for SPIRV::OpTypeInt?

}

inline unsigned getAS(SPIRVType *SpvType) {
return storageClassToAddressSpace(
static_cast<SPIRV::StorageClass::StorageClass>(
SpvType->getOperand(1).getImm()));
}

LLT SPIRVGlobalRegistry::getRegType(SPIRVType *SpvType) const {
unsigned Opcode = SpvType ? SpvType->getOpcode() : 0;
switch (Opcode) {
case SPIRV::OpTypeInt:
case SPIRV::OpTypeFloat:
case SPIRV::OpTypeBool:
return LLT::scalar(getScalarOrVectorBitWidth(SpvType));
case SPIRV::OpTypePointer:
return LLT::pointer(getAS(SpvType), getPointerSize());
case SPIRV::OpTypeVector: {
SPIRVType *ElemType = getSPIRVTypeForVReg(SpvType->getOperand(1).getReg());
LLT ET;
switch (ElemType ? ElemType->getOpcode() : 0) {
case SPIRV::OpTypePointer:
ET = LLT::pointer(getAS(ElemType), getPointerSize());
break;
case SPIRV::OpTypeInt:
case SPIRV::OpTypeFloat:
case SPIRV::OpTypeBool:
ET = LLT::scalar(getScalarOrVectorBitWidth(ElemType));
break;
default:
ET = LLT::scalar(64);
}
return LLT::fixed_vector(
static_cast<unsigned>(SpvType->getOperand(2).getImm()), ET);
}
}
return LLT::scalar(64);
}
Loading