Skip to content

Commit b0d03cc

Browse files
[SPIR-V] Fix illegal OpConstantComposite instruction with non-const constituents in SPIR-V Backend (#86352)
This PR fixes illegal use of OpConstantComposite with non-constant constituents. The test attached to the PR is able now to satisfy `spirv-val` check. Before the fix SPIR-V Backend produced for the attached test case a pattern like ``` %a = OpVariable %_ptr_CrossWorkgroup_uint CrossWorkgroup %uint_123 %11 = OpConstantComposite %_struct_6 %a %a ``` so that `spirv-val` complained with ``` error: line 25: OpConstantComposite Constituent <id> '10[%a]' is not a constant or undef. %11 = OpConstantComposite %_struct_6 %a %a ```
1 parent 1d250d9 commit b0d03cc

File tree

8 files changed

+96
-17
lines changed

8 files changed

+96
-17
lines changed

llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ void SPIRVGeneralDuplicatesTracker::buildDepsGraph(
3939
prebuildReg2Entry(GT, Reg2Entry);
4040
prebuildReg2Entry(FT, Reg2Entry);
4141
prebuildReg2Entry(AT, Reg2Entry);
42+
prebuildReg2Entry(MT, Reg2Entry);
4243
prebuildReg2Entry(ST, Reg2Entry);
4344

4445
for (auto &Op2E : Reg2Entry) {

llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,7 @@ class SPIRVGeneralDuplicatesTracker {
262262
SPIRVDuplicatesTracker<GlobalVariable> GT;
263263
SPIRVDuplicatesTracker<Function> FT;
264264
SPIRVDuplicatesTracker<Argument> AT;
265+
SPIRVDuplicatesTracker<MachineInstr> MT;
265266
SPIRVDuplicatesTracker<SPIRV::SpecialTypeDescriptor> ST;
266267

267268
// NOTE: using MOs instead of regs to get rid of MF dependency to be able
@@ -306,6 +307,10 @@ class SPIRVGeneralDuplicatesTracker {
306307
AT.add(Arg, MF, R);
307308
}
308309

310+
void add(const MachineInstr *MI, const MachineFunction *MF, Register R) {
311+
MT.add(MI, MF, R);
312+
}
313+
309314
void add(const SPIRV::SpecialTypeDescriptor &TD, const MachineFunction *MF,
310315
Register R) {
311316
ST.add(TD, MF, R);
@@ -337,6 +342,10 @@ class SPIRVGeneralDuplicatesTracker {
337342
return AT.find(const_cast<Argument *>(Arg), MF);
338343
}
339344

345+
Register find(const MachineInstr *MI, const MachineFunction *MF) {
346+
return MT.find(const_cast<MachineInstr *>(MI), MF);
347+
}
348+
340349
Register find(const SPIRV::SpecialTypeDescriptor &TD,
341350
const MachineFunction *MF) {
342351
return ST.find(TD, MF);

llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeVector(uint32_t NumElems,
123123
SPIRVType *ElemType,
124124
MachineIRBuilder &MIRBuilder) {
125125
auto EleOpc = ElemType->getOpcode();
126+
(void)EleOpc;
126127
assert((EleOpc == SPIRV::OpTypeInt || EleOpc == SPIRV::OpTypeFloat ||
127128
EleOpc == SPIRV::OpTypeBool) &&
128129
"Invalid vector element type");

llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,14 @@ class SPIRVGlobalRegistry {
9494
DT.add(Arg, MF, R);
9595
}
9696

97+
void add(const MachineInstr *MI, MachineFunction *MF, Register R) {
98+
DT.add(MI, MF, R);
99+
}
100+
101+
Register find(const MachineInstr *MI, MachineFunction *MF) {
102+
return DT.find(MI, MF);
103+
}
104+
97105
Register find(const Constant *C, MachineFunction *MF) {
98106
return DT.find(C, MF);
99107
}

llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp

Lines changed: 74 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
231231
Register buildZerosVal(const SPIRVType *ResType, MachineInstr &I) const;
232232
Register buildOnesVal(bool AllOnes, const SPIRVType *ResType,
233233
MachineInstr &I) const;
234+
235+
bool wrapIntoSpecConstantOp(MachineInstr &I,
236+
SmallVector<Register> &CompositeArgs) const;
234237
};
235238

236239
} // end anonymous namespace
@@ -1249,6 +1252,24 @@ static unsigned getArrayComponentCount(MachineRegisterInfo *MRI,
12491252
return N;
12501253
}
12511254

1255+
// Return true if the type represents a constant register
1256+
static bool isConstReg(MachineRegisterInfo *MRI, SPIRVType *OpDef) {
1257+
if (OpDef->getOpcode() == SPIRV::ASSIGN_TYPE &&
1258+
OpDef->getOperand(1).isReg()) {
1259+
if (SPIRVType *RefDef = MRI->getVRegDef(OpDef->getOperand(1).getReg()))
1260+
OpDef = RefDef;
1261+
}
1262+
return OpDef->getOpcode() == TargetOpcode::G_CONSTANT ||
1263+
OpDef->getOpcode() == TargetOpcode::G_FCONSTANT;
1264+
}
1265+
1266+
// Return true if the virtual register represents a constant
1267+
static bool isConstReg(MachineRegisterInfo *MRI, Register OpReg) {
1268+
if (SPIRVType *OpDef = MRI->getVRegDef(OpReg))
1269+
return isConstReg(MRI, OpDef);
1270+
return false;
1271+
}
1272+
12521273
bool SPIRVInstructionSelector::selectSplatVector(Register ResVReg,
12531274
const SPIRVType *ResType,
12541275
MachineInstr &I) const {
@@ -1266,16 +1287,7 @@ bool SPIRVInstructionSelector::selectSplatVector(Register ResVReg,
12661287

12671288
// check if we may construct a constant vector
12681289
Register OpReg = I.getOperand(OpIdx).getReg();
1269-
bool IsConst = false;
1270-
if (SPIRVType *OpDef = MRI->getVRegDef(OpReg)) {
1271-
if (OpDef->getOpcode() == SPIRV::ASSIGN_TYPE &&
1272-
OpDef->getOperand(1).isReg()) {
1273-
if (SPIRVType *RefDef = MRI->getVRegDef(OpDef->getOperand(1).getReg()))
1274-
OpDef = RefDef;
1275-
}
1276-
IsConst = OpDef->getOpcode() == TargetOpcode::G_CONSTANT ||
1277-
OpDef->getOpcode() == TargetOpcode::G_FCONSTANT;
1278-
}
1290+
bool IsConst = isConstReg(MRI, OpReg);
12791291

12801292
if (!IsConst && N < 2)
12811293
report_fatal_error(
@@ -1628,6 +1640,48 @@ bool SPIRVInstructionSelector::selectGEP(Register ResVReg,
16281640
return Res.constrainAllUses(TII, TRI, RBI);
16291641
}
16301642

1643+
// Maybe wrap a value into OpSpecConstantOp
1644+
bool SPIRVInstructionSelector::wrapIntoSpecConstantOp(
1645+
MachineInstr &I, SmallVector<Register> &CompositeArgs) const {
1646+
bool Result = true;
1647+
unsigned Lim = I.getNumExplicitOperands();
1648+
for (unsigned i = I.getNumExplicitDefs() + 1; i < Lim; ++i) {
1649+
Register OpReg = I.getOperand(i).getReg();
1650+
SPIRVType *OpDefine = MRI->getVRegDef(OpReg);
1651+
SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpReg);
1652+
if (!OpDefine || !OpType || isConstReg(MRI, OpDefine) ||
1653+
OpDefine->getOpcode() == TargetOpcode::G_ADDRSPACE_CAST) {
1654+
// The case of G_ADDRSPACE_CAST inside spv_const_composite() is processed
1655+
// by selectAddrSpaceCast()
1656+
CompositeArgs.push_back(OpReg);
1657+
continue;
1658+
}
1659+
MachineFunction *MF = I.getMF();
1660+
Register WrapReg = GR.find(OpDefine, MF);
1661+
if (WrapReg.isValid()) {
1662+
CompositeArgs.push_back(WrapReg);
1663+
continue;
1664+
}
1665+
// Create a new register for the wrapper
1666+
WrapReg = MRI->createVirtualRegister(&SPIRV::IDRegClass);
1667+
GR.add(OpDefine, MF, WrapReg);
1668+
CompositeArgs.push_back(WrapReg);
1669+
// Decorate the wrapper register and generate a new instruction
1670+
MRI->setType(WrapReg, LLT::pointer(0, 32));
1671+
GR.assignSPIRVTypeToVReg(OpType, WrapReg, *MF);
1672+
MachineBasicBlock &BB = *I.getParent();
1673+
Result = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpSpecConstantOp))
1674+
.addDef(WrapReg)
1675+
.addUse(GR.getSPIRVTypeID(OpType))
1676+
.addImm(static_cast<uint32_t>(SPIRV::Opcode::Bitcast))
1677+
.addUse(OpReg)
1678+
.constrainAllUses(TII, TRI, RBI);
1679+
if (!Result)
1680+
break;
1681+
}
1682+
return Result;
1683+
}
1684+
16311685
bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
16321686
const SPIRVType *ResType,
16331687
MachineInstr &I) const {
@@ -1666,17 +1720,21 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
16661720
case Intrinsic::spv_const_composite: {
16671721
// If no values are attached, the composite is null constant.
16681722
bool IsNull = I.getNumExplicitDefs() + 1 == I.getNumExplicitOperands();
1669-
unsigned Opcode =
1670-
IsNull ? SPIRV::OpConstantNull : SPIRV::OpConstantComposite;
1723+
// Select a proper instruction.
1724+
unsigned Opcode = SPIRV::OpConstantNull;
1725+
SmallVector<Register> CompositeArgs;
1726+
if (!IsNull) {
1727+
Opcode = SPIRV::OpConstantComposite;
1728+
if (!wrapIntoSpecConstantOp(I, CompositeArgs))
1729+
return false;
1730+
}
16711731
auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode))
16721732
.addDef(ResVReg)
16731733
.addUse(GR.getSPIRVTypeID(ResType));
16741734
// skip type MD node we already used when generated assign.type for this
16751735
if (!IsNull) {
1676-
for (unsigned i = I.getNumExplicitDefs() + 1;
1677-
i < I.getNumExplicitOperands(); ++i) {
1678-
MIB.addUse(I.getOperand(i).getReg());
1679-
}
1736+
for (Register OpReg : CompositeArgs)
1737+
MIB.addUse(OpReg);
16801738
}
16811739
return MIB.constrainAllUses(TII, TRI, RBI);
16821740
}

llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,7 @@ static void processSwitches(MachineFunction &MF, SPIRVGlobalRegistry *GR,
543543
Register Dst = ICMP->getOperand(0).getReg();
544544
MachineOperand &PredOp = ICMP->getOperand(1);
545545
const auto CC = static_cast<CmpInst::Predicate>(PredOp.getPredicate());
546+
(void)CC;
546547
assert((CC == CmpInst::ICMP_EQ || CC == CmpInst::ICMP_ULE) &&
547548
MRI.hasOneUse(Dst) && MRI.hasOneDef(CompareReg));
548549
uint64_t Value = getIConstVal(ICMP->getOperand(3).getReg(), &MRI);

llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1611,3 +1611,4 @@ multiclass OpcodeOperand<bits<32> value> {
16111611
// TODO: implement other mnemonics.
16121612
defm InBoundsPtrAccessChain : OpcodeOperand<70>;
16131613
defm PtrCastToGeneric : OpcodeOperand<121>;
1614+
defm Bitcast : OpcodeOperand<124>;

llvm/test/CodeGen/SPIRV/pointers/struct-opaque-pointers.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s
2-
; TODO: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
2+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
33

44
; CHECK: %[[TyInt8:.*]] = OpTypeInt 8 0
55
; CHECK: %[[TyInt8Ptr:.*]] = OpTypePointer {{[a-zA-Z]+}} %[[TyInt8]]

0 commit comments

Comments
 (0)