Skip to content

[SPIR-V] Fix illegal OpConstantComposite instruction with non-const constituents in SPIR-V Backend #86352

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

VyacheslavLevytskyy
Copy link
Contributor

This PR fixes illegal use of OpConstantComposite with non-constant constituents. The test attached to the PR is able now to satisfy spirv-val check. Before the fix SPIR-V Backend produced for the attached test case a pattern like

%a = OpVariable %_ptr_CrossWorkgroup_uint CrossWorkgroup %uint_123
%11 = OpConstantComposite %_struct_6 %a %a

so that spirv-val complained with

error: line 25: OpConstantComposite Constituent <id> '10[%a]' is not a constant or undef.
  %11 = OpConstantComposite %_struct_6 %a %a

@llvmbot
Copy link
Member

llvmbot commented Mar 22, 2024

@llvm/pr-subscribers-backend-spir-v

Author: Vyacheslav Levytskyy (VyacheslavLevytskyy)

Changes

This PR fixes illegal use of OpConstantComposite with non-constant constituents. The test attached to the PR is able now to satisfy spirv-val check. Before the fix SPIR-V Backend produced for the attached test case a pattern like

%a = OpVariable %_ptr_CrossWorkgroup_uint CrossWorkgroup %uint_123
%11 = OpConstantComposite %_struct_6 %a %a

so that spirv-val complained with

error: line 25: OpConstantComposite Constituent &lt;id&gt; '10[%a]' is not a constant or undef.
  %11 = OpConstantComposite %_struct_6 %a %a

Full diff: https://github.com/llvm/llvm-project/pull/86352.diff

8 Files Affected:

  • (modified) llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.cpp (+1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h (+9)
  • (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp (+1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h (+8)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+75-16)
  • (modified) llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp (+1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td (+1)
  • (modified) llvm/test/CodeGen/SPIRV/pointers/struct-opaque-pointers.ll (+1-1)
diff --git a/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.cpp b/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.cpp
index d82fb2df4539a3..7c32bb1968ef58 100644
--- a/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.cpp
@@ -39,6 +39,7 @@ void SPIRVGeneralDuplicatesTracker::buildDepsGraph(
   prebuildReg2Entry(GT, Reg2Entry);
   prebuildReg2Entry(FT, Reg2Entry);
   prebuildReg2Entry(AT, Reg2Entry);
+  prebuildReg2Entry(MT, Reg2Entry);
   prebuildReg2Entry(ST, Reg2Entry);
 
   for (auto &Op2E : Reg2Entry) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h b/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h
index 96cc621791e972..2ec3fb35ca0451 100644
--- a/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h
+++ b/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h
@@ -262,6 +262,7 @@ class SPIRVGeneralDuplicatesTracker {
   SPIRVDuplicatesTracker<GlobalVariable> GT;
   SPIRVDuplicatesTracker<Function> FT;
   SPIRVDuplicatesTracker<Argument> AT;
+  SPIRVDuplicatesTracker<MachineInstr> MT;
   SPIRVDuplicatesTracker<SPIRV::SpecialTypeDescriptor> ST;
 
   // NOTE: using MOs instead of regs to get rid of MF dependency to be able
@@ -306,6 +307,10 @@ class SPIRVGeneralDuplicatesTracker {
     AT.add(Arg, MF, R);
   }
 
+  void add(const MachineInstr *MI, const MachineFunction *MF, Register R) {
+    MT.add(MI, MF, R);
+  }
+
   void add(const SPIRV::SpecialTypeDescriptor &TD, const MachineFunction *MF,
            Register R) {
     ST.add(TD, MF, R);
@@ -337,6 +342,10 @@ class SPIRVGeneralDuplicatesTracker {
     return AT.find(const_cast<Argument *>(Arg), MF);
   }
 
+  Register find(const MachineInstr *MI, const MachineFunction *MF) {
+    return MT.find(const_cast<MachineInstr *>(MI), MF);
+  }
+
   Register find(const SPIRV::SpecialTypeDescriptor &TD,
                 const MachineFunction *MF) {
     return ST.find(TD, MF);
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index 42f8397a3023b1..0e0ca07fc7f86b 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -123,6 +123,7 @@ SPIRVType *SPIRVGlobalRegistry::getOpTypeVector(uint32_t NumElems,
                                                 SPIRVType *ElemType,
                                                 MachineIRBuilder &MIRBuilder) {
   auto EleOpc = ElemType->getOpcode();
+  (void)EleOpc;
   assert((EleOpc == SPIRV::OpTypeInt || EleOpc == SPIRV::OpTypeFloat ||
           EleOpc == SPIRV::OpTypeBool) &&
          "Invalid vector element type");
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index da480b22a525f2..ed0f90ff89ce6e 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -94,6 +94,14 @@ class SPIRVGlobalRegistry {
     DT.add(Arg, MF, R);
   }
 
+  void add(const MachineInstr *MI, MachineFunction *MF, Register R) {
+    DT.add(MI, MF, R);
+  }
+
+  Register find(const MachineInstr *MI, MachineFunction *MF) {
+    return DT.find(MI, MF);
+  }
+
   Register find(const Constant *C, MachineFunction *MF) {
     return DT.find(C, MF);
   }
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 5bb8f6084f9671..f905ee6de17ba6 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -231,6 +231,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
   Register buildZerosVal(const SPIRVType *ResType, MachineInstr &I) const;
   Register buildOnesVal(bool AllOnes, const SPIRVType *ResType,
                         MachineInstr &I) const;
+
+  bool wrapIntoSpecConstantOp(MachineInstr &I,
+                              SmallVector<Register> &CompositeArgs) const;
 };
 
 } // end anonymous namespace
@@ -1245,6 +1248,24 @@ static unsigned getArrayComponentCount(MachineRegisterInfo *MRI,
   return N;
 }
 
+// Return true if the type represents a constant register
+static bool isConstReg(MachineRegisterInfo *MRI, SPIRVType *OpDef) {
+  if (OpDef->getOpcode() == SPIRV::ASSIGN_TYPE &&
+      OpDef->getOperand(1).isReg()) {
+    if (SPIRVType *RefDef = MRI->getVRegDef(OpDef->getOperand(1).getReg()))
+      OpDef = RefDef;
+  }
+  return OpDef->getOpcode() == TargetOpcode::G_CONSTANT ||
+         OpDef->getOpcode() == TargetOpcode::G_FCONSTANT;
+}
+
+// Return true if the virtual register represents a constant
+static bool isConstReg(MachineRegisterInfo *MRI, Register OpReg) {
+  if (SPIRVType *OpDef = MRI->getVRegDef(OpReg))
+    return isConstReg(MRI, OpDef);
+  return false;
+}
+
 bool SPIRVInstructionSelector::selectSplatVector(Register ResVReg,
                                                  const SPIRVType *ResType,
                                                  MachineInstr &I) const {
@@ -1262,16 +1283,7 @@ bool SPIRVInstructionSelector::selectSplatVector(Register ResVReg,
 
   // check if we may construct a constant vector
   Register OpReg = I.getOperand(OpIdx).getReg();
-  bool IsConst = false;
-  if (SPIRVType *OpDef = MRI->getVRegDef(OpReg)) {
-    if (OpDef->getOpcode() == SPIRV::ASSIGN_TYPE &&
-        OpDef->getOperand(1).isReg()) {
-      if (SPIRVType *RefDef = MRI->getVRegDef(OpDef->getOperand(1).getReg()))
-        OpDef = RefDef;
-    }
-    IsConst = OpDef->getOpcode() == TargetOpcode::G_CONSTANT ||
-              OpDef->getOpcode() == TargetOpcode::G_FCONSTANT;
-  }
+  bool IsConst = isConstReg(MRI, OpReg);
 
   if (!IsConst && N < 2)
     report_fatal_error(
@@ -1624,6 +1636,49 @@ bool SPIRVInstructionSelector::selectGEP(Register ResVReg,
   return Res.constrainAllUses(TII, TRI, RBI);
 }
 
+// Maybe wrap a value into OpSpecConstantOp
+bool SPIRVInstructionSelector::wrapIntoSpecConstantOp(
+    MachineInstr &I, SmallVector<Register> &CompositeArgs) const {
+  bool Result = true;
+  unsigned Lim = I.getNumExplicitOperands();
+  for (unsigned i = I.getNumExplicitDefs() + 1; i < Lim; ++i) {
+    Register OpReg = I.getOperand(i).getReg();
+    SPIRVType *OpDefine = MRI->getVRegDef(OpReg);
+    SPIRVType *OpType = GR.getSPIRVTypeForVReg(OpReg);
+    if (!OpDefine || !OpType || isConstReg(MRI, OpDefine) ||
+        OpDefine->getOpcode() == TargetOpcode::G_ADDRSPACE_CAST) {
+      // The case of G_ADDRSPACE_CAST inside spv_const_composite() is processed
+      // by selectAddrSpaceCast()
+      CompositeArgs.push_back(OpReg);
+      continue;
+    }
+    MachineFunction *MF = I.getMF();
+    Register WrapReg = GR.find(OpDefine, MF);
+    if (WrapReg.isValid()) {
+      CompositeArgs.push_back(WrapReg);
+      continue;
+    }
+    // Create a new register for the wrapper
+    WrapReg = MRI->createVirtualRegister(&SPIRV::IDRegClass);
+    GR.add(OpDefine, MF, WrapReg);
+    CompositeArgs.push_back(WrapReg);
+    // Decorate the wrapper register and generate a new instruction
+    MRI->setType(WrapReg, LLT::pointer(0, 32));
+    GR.assignSPIRVTypeToVReg(OpType, WrapReg, *MF);
+    MachineBasicBlock &BB = *I.getParent();
+    Result = BuildMI(BB, I, I.getDebugLoc(),
+                     TII.get(SPIRV::OpSpecConstantOp))
+                 .addDef(WrapReg)
+                 .addUse(GR.getSPIRVTypeID(OpType))
+                 .addImm(static_cast<uint32_t>(SPIRV::Opcode::Bitcast))
+                 .addUse(OpReg)
+                 .constrainAllUses(TII, TRI, RBI);
+    if (!Result)
+      break;
+  }
+  return Result;
+}
+
 bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
                                                const SPIRVType *ResType,
                                                MachineInstr &I) const {
@@ -1662,17 +1717,21 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg,
   case Intrinsic::spv_const_composite: {
     // If no values are attached, the composite is null constant.
     bool IsNull = I.getNumExplicitDefs() + 1 == I.getNumExplicitOperands();
-    unsigned Opcode =
-        IsNull ? SPIRV::OpConstantNull : SPIRV::OpConstantComposite;
+    // Select a proper instruction.
+    unsigned Opcode = SPIRV::OpConstantNull;
+    SmallVector<Register> CompositeArgs;
+    if (!IsNull) {
+      Opcode = SPIRV::OpConstantComposite;
+      if (!wrapIntoSpecConstantOp(I, CompositeArgs))
+        return false;
+    }
     auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(Opcode))
                    .addDef(ResVReg)
                    .addUse(GR.getSPIRVTypeID(ResType));
     // skip type MD node we already used when generated assign.type for this
     if (!IsNull) {
-      for (unsigned i = I.getNumExplicitDefs() + 1;
-           i < I.getNumExplicitOperands(); ++i) {
-        MIB.addUse(I.getOperand(i).getReg());
-      }
+      for (Register OpReg : CompositeArgs)
+        MIB.addUse(OpReg);
     }
     return MIB.constrainAllUses(TII, TRI, RBI);
   }
diff --git a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
index d547f91ba4a565..1f0d8d8cd43a8f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp
@@ -543,6 +543,7 @@ static void processSwitches(MachineFunction &MF, SPIRVGlobalRegistry *GR,
       Register Dst = ICMP->getOperand(0).getReg();
       MachineOperand &PredOp = ICMP->getOperand(1);
       const auto CC = static_cast<CmpInst::Predicate>(PredOp.getPredicate());
+      (void)CC;
       assert((CC == CmpInst::ICMP_EQ || CC == CmpInst::ICMP_ULE) &&
              MRI.hasOneUse(Dst) && MRI.hasOneDef(CompareReg));
       uint64_t Value = getIConstVal(ICMP->getOperand(3).getReg(), &MRI);
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index 8dbbd9049844c8..ff102e318469f4 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -1611,3 +1611,4 @@ multiclass OpcodeOperand<bits<32> value> {
 // TODO: implement other mnemonics.
 defm InBoundsPtrAccessChain : OpcodeOperand<70>;
 defm PtrCastToGeneric : OpcodeOperand<121>;
+defm Bitcast : OpcodeOperand<124>;
diff --git a/llvm/test/CodeGen/SPIRV/pointers/struct-opaque-pointers.ll b/llvm/test/CodeGen/SPIRV/pointers/struct-opaque-pointers.ll
index d426fc4dfd4eec..ce3ab8895a5948 100644
--- a/llvm/test/CodeGen/SPIRV/pointers/struct-opaque-pointers.ll
+++ b/llvm/test/CodeGen/SPIRV/pointers/struct-opaque-pointers.ll
@@ -1,5 +1,5 @@
 ; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s
-; TODO: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
 
 ; CHECK: %[[TyInt8:.*]] = OpTypeInt 8 0
 ; CHECK: %[[TyInt8Ptr:.*]] = OpTypePointer {{[a-zA-Z]+}} %[[TyInt8]]

Copy link

github-actions bot commented Mar 22, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

Copy link

✅ With the latest revision this PR passed the Python code formatter.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants