Skip to content

[SPIR-V]: Fix creation of constants of array types in SPIRV Backend #96514

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

VyacheslavLevytskyy
Copy link
Contributor

This PR fixes #96513.

The way of creation of array type constant was incorrect: instead of creating [1, 1, 1] or [1, 1, 1, 1, 1, ....] constants, the same [1] constant was always created, substituting original composite constants. This in its turn led to a situation when only one of constants might exist in the code without emitting invalid code, the second constant would be eventually rewritten to the first constant, because a key to address both was an array of a single element (like [1]).

This PR fixes the issue and purges from the code unneeded copy/pasted clone of the function that creates an array constant.

@llvmbot
Copy link
Member

llvmbot commented Jun 24, 2024

@llvm/pr-subscribers-backend-spir-v

Author: Vyacheslav Levytskyy (VyacheslavLevytskyy)

Changes

This PR fixes #96513.

The way of creation of array type constant was incorrect: instead of creating [1, 1, 1] or [1, 1, 1, 1, 1, ....] constants, the same [1] constant was always created, substituting original composite constants. This in its turn led to a situation when only one of constants might exist in the code without emitting invalid code, the second constant would be eventually rewritten to the first constant, because a key to address both was an array of a single element (like [1]).

This PR fixes the issue and purges from the code unneeded copy/pasted clone of the function that creates an array constant.


Full diff: https://github.com/llvm/llvm-project/pull/96514.diff

5 Files Affected:

  • (modified) llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp (+4-1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp (+9-27)
  • (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h (+3-5)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+1-1)
  • (added) llvm/test/CodeGen/SPIRV/var-uniform-const.ll (+55)
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index f5f36075d4a31..71168d2d7dacd 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -1972,7 +1972,10 @@ static bool buildNDRange(const SPIRV::IncomingCall *Call,
           .addDef(GlobalWorkSize)
           .addUse(GR->getSPIRVTypeID(SpvFieldTy))
           .addUse(GWSPtr);
-      Const = GR->getOrCreateConsIntArray(0, MIRBuilder, SpvFieldTy);
+      const SPIRVSubtarget &ST =
+          cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
+      Const = GR->getOrCreateConstIntArray(0, Size, *MIRBuilder.getInsertPt(),
+                                           SpvFieldTy, *ST.getInstrInfo());
     } else {
       Const = GR->buildConstantInt(0, MIRBuilder, SpvTy);
     }
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index b8710d24bff94..55b1b3684a393 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -394,7 +394,7 @@ Register SPIRVGlobalRegistry::getOrCreateCompositeOrNull(
     Constant *Val, MachineInstr &I, SPIRVType *SpvType,
     const SPIRVInstrInfo &TII, Constant *CA, unsigned BitWidth,
     unsigned ElemCnt, bool ZeroAsNull) {
-  // Find a constant vector in DT or build a new one.
+  // Find a constant vector or array in DT or build a new one.
   Register Res = DT.find(CA, CurMF);
   // If no values are attached, the composite is null constant.
   bool IsNull = Val->isNullValue() && ZeroAsNull;
@@ -474,20 +474,20 @@ Register SPIRVGlobalRegistry::getOrCreateConstVector(APFloat Val,
                                     ZeroAsNull);
 }
 
-Register
-SPIRVGlobalRegistry::getOrCreateConsIntArray(uint64_t Val, MachineInstr &I,
-                                             SPIRVType *SpvType,
-                                             const SPIRVInstrInfo &TII) {
+Register SPIRVGlobalRegistry::getOrCreateConstIntArray(
+    uint64_t Val, size_t Num, MachineInstr &I, SPIRVType *SpvType,
+    const SPIRVInstrInfo &TII) {
   const Type *LLVMTy = getTypeForSPIRVType(SpvType);
   assert(LLVMTy->isArrayTy());
   const ArrayType *LLVMArrTy = cast<ArrayType>(LLVMTy);
   Type *LLVMBaseTy = LLVMArrTy->getElementType();
-  auto *ConstInt = ConstantInt::get(LLVMBaseTy, Val);
-  auto *ConstArr =
-      ConstantArray::get(const_cast<ArrayType *>(LLVMArrTy), {ConstInt});
+  Constant *CI = ConstantInt::get(LLVMBaseTy, Val);
+  SmallVector<Constant *> NumCI(Num, CI);
+  Constant *ConstArr =
+      ConstantArray::get(const_cast<ArrayType *>(LLVMArrTy), NumCI);
   SPIRVType *SpvBaseTy = getSPIRVTypeForVReg(SpvType->getOperand(1).getReg());
   unsigned BW = getScalarOrVectorBitWidth(SpvBaseTy);
-  return getOrCreateCompositeOrNull(ConstInt, I, SpvType, TII, ConstArr, BW,
+  return getOrCreateCompositeOrNull(CI, I, SpvType, TII, ConstArr, BW,
                                     LLVMArrTy->getNumElements());
 }
 
@@ -545,24 +545,6 @@ SPIRVGlobalRegistry::getOrCreateConsIntVector(uint64_t Val,
                                        SpvType->getOperand(2).getImm());
 }
 
-Register
-SPIRVGlobalRegistry::getOrCreateConsIntArray(uint64_t Val,
-                                             MachineIRBuilder &MIRBuilder,
-                                             SPIRVType *SpvType, bool EmitIR) {
-  const Type *LLVMTy = getTypeForSPIRVType(SpvType);
-  assert(LLVMTy->isArrayTy());
-  const ArrayType *LLVMArrTy = cast<ArrayType>(LLVMTy);
-  Type *LLVMBaseTy = LLVMArrTy->getElementType();
-  const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val);
-  auto ConstArr =
-      ConstantArray::get(const_cast<ArrayType *>(LLVMArrTy), {ConstInt});
-  SPIRVType *SpvBaseTy = getSPIRVTypeForVReg(SpvType->getOperand(1).getReg());
-  unsigned BW = getScalarOrVectorBitWidth(SpvBaseTy);
-  return getOrCreateIntCompositeOrNull(Val, MIRBuilder, SpvType, EmitIR,
-                                       ConstArr, BW,
-                                       LLVMArrTy->getNumElements());
-}
-
 Register
 SPIRVGlobalRegistry::getOrCreateConstNullPtr(MachineIRBuilder &MIRBuilder,
                                              SPIRVType *SpvType) {
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index 990d3328f6a30..a45e1ccd0717f 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -457,13 +457,11 @@ class SPIRVGlobalRegistry {
   Register getOrCreateConstVector(APFloat Val, MachineInstr &I,
                                   SPIRVType *SpvType, const SPIRVInstrInfo &TII,
                                   bool ZeroAsNull = true);
-  Register getOrCreateConsIntArray(uint64_t Val, MachineInstr &I,
-                                   SPIRVType *SpvType,
-                                   const SPIRVInstrInfo &TII);
+  Register getOrCreateConstIntArray(uint64_t Val, size_t Num, MachineInstr &I,
+                                    SPIRVType *SpvType,
+                                    const SPIRVInstrInfo &TII);
   Register getOrCreateConsIntVector(uint64_t Val, MachineIRBuilder &MIRBuilder,
                                     SPIRVType *SpvType, bool EmitIR = true);
-  Register getOrCreateConsIntArray(uint64_t Val, MachineIRBuilder &MIRBuilder,
-                                   SPIRVType *SpvType, bool EmitIR = true);
   Register getOrCreateConstNullPtr(MachineIRBuilder &MIRBuilder,
                                    SPIRVType *SpvType);
   Register buildConstantSampler(Register Res, unsigned AddrMode, unsigned Param,
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index 41a0d2c5e2f35..f5b6bcd64f480 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -846,7 +846,7 @@ bool SPIRVInstructionSelector::selectMemOperation(Register ResVReg,
     unsigned Num = getIConstVal(I.getOperand(2).getReg(), MRI);
     SPIRVType *ValTy = GR.getOrCreateSPIRVIntegerType(8, I, TII);
     SPIRVType *ArrTy = GR.getOrCreateSPIRVArrayType(ValTy, Num, I, TII);
-    Register Const = GR.getOrCreateConsIntArray(Val, I, ArrTy, TII);
+    Register Const = GR.getOrCreateConstIntArray(Val, Num, I, ArrTy, TII);
     SPIRVType *VarTy = GR.getOrCreateSPIRVPointerType(
         ArrTy, I, TII, SPIRV::StorageClass::UniformConstant);
     // TODO: check if we have such GV, add init, use buildGlobalVariable.
diff --git a/llvm/test/CodeGen/SPIRV/var-uniform-const.ll b/llvm/test/CodeGen/SPIRV/var-uniform-const.ll
new file mode 100644
index 0000000000000..da02b444e1f4a
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/var-uniform-const.ll
@@ -0,0 +1,55 @@
+; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; TODO: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
+; TODO: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-SPIRV-DAG: %[[#Char:]] = OpTypeInt 8 0
+; CHECK-SPIRV-DAG: %[[#Long:]] = OpTypeInt 64 0
+; CHECK-SPIRV-DAG: %[[#Int:]] = OpTypeInt 32 0
+; CHECK-SPIRV-DAG: %[[#Size3:]] = OpConstant %[[#Int]] 3
+; CHECK-SPIRV-DAG: %[[#Arr3:]] = OpTypeArray %[[#Char]] %[[#Size3]]
+; CHECK-SPIRV-DAG: %[[#Size16:]] = OpConstant %[[#Int]] 16
+; CHECK-SPIRV-DAG: %[[#Arr16:]] = OpTypeArray %[[#Char]] %[[#Size16]]
+; CHECK-SPIRV-DAG: %[[#Const3:]] = OpConstant %[[#Long]] 3
+; CHECK-SPIRV-DAG: %[[#One:]] = OpConstant %[[#Char]] 1
+; CHECK-SPIRV-DAG: %[[#One3:]] = OpConstantComposite %[[#Arr3]] %[[#One]] %[[#One]] %[[#One]]
+; CHECK-SPIRV-DAG: %[[#Zero3:]] = OpConstantNull %[[#Arr3]]
+; CHECK-SPIRV-DAG: %[[#Const16:]] = OpConstant %[[#Long]] 16
+; CHECK-SPIRV-DAG: %[[#One16:]] = OpConstantComposite %[[#Arr16]] %[[#One]] %[[#One]] %[[#One]] %[[#One]] %[[#One]] %[[#One]] %[[#One]] %[[#One]] %[[#One]] %[[#One]] %[[#One]] %[[#One]] %[[#One]] %[[#One]] %[[#One]] %[[#One]]
+; CHECK-SPIRV-DAG: %[[#Zero16:]] = OpConstantNull %[[#Arr16]]
+; CHECK-SPIRV-DAG: %[[#PtrArr3:]] = OpTypePointer UniformConstant %[[#Arr3]]
+; CHECK-SPIRV-DAG: OpVariable %[[#PtrArr3]] UniformConstant %[[#One3]]
+; CHECK-SPIRV-DAG: OpVariable %[[#PtrArr3]] UniformConstant %[[#Zero3]]
+; CHECK-SPIRV-DAG: %[[#PtrArr16:]] = OpTypePointer UniformConstant %[[#Arr16]]
+; CHECK-SPIRV-DAG: OpVariable %[[#PtrArr16]] UniformConstant %[[#One16]]
+; CHECK-SPIRV-DAG: OpVariable %[[#PtrArr16]] UniformConstant %[[#Zero16]]
+; CHECK-SPIRV: OpFunction
+; CHECK-SPIRV: OpCopyMemorySized %[[#]] %[[#]] %[[#Const3]] Aligned 4
+; CHECK-SPIRV: OpCopyMemorySized %[[#]] %[[#]] %[[#Const3]] Aligned 4
+; CHECK-SPIRV: OpFunctionEnd
+; CHECK-SPIRV: OpFunction
+; CHECK-SPIRV: OpCopyMemorySized %[[#]] %[[#]] %[[#Const16]] Aligned 4
+; CHECK-SPIRV: OpCopyMemorySized %[[#]] %[[#]] %[[#Const16]] Aligned 4
+; CHECK-SPIRV: OpFunctionEnd
+
+%Vec3 = type { <3 x i8> }
+%Vec16 = type { <16 x i8> }
+
+define spir_kernel void @foo(ptr addrspace(1) noundef align 16 %arg) {
+  %a1 = getelementptr inbounds %Vec3, ptr addrspace(1) %arg, i64 1
+  call void @llvm.memset.p1.i64(ptr addrspace(1) align 4 %a1, i8 0, i64 3, i1 false)
+  %a2 = getelementptr inbounds %Vec3, ptr addrspace(1) %arg, i64 1
+  call void @llvm.memset.p1.i64(ptr addrspace(1) align 4 %a2, i8 1, i64 3, i1 false)
+  ret void
+}
+
+define spir_kernel void @bar(ptr addrspace(1) noundef align 16 %arg) {
+  %a1 = getelementptr inbounds %Vec16, ptr addrspace(1) %arg, i64 1
+  call void @llvm.memset.p1.i64(ptr addrspace(1) align 4 %a1, i8 0, i64 16, i1 false)
+  %a2 = getelementptr inbounds %Vec16, ptr addrspace(1) %arg, i64 1
+  call void @llvm.memset.p1.i64(ptr addrspace(1) align 4 %a2, i8 1, i64 16, i1 false)
+  ret void
+}
+
+declare void @llvm.memset.p1.i64(ptr addrspace(1) nocapture writeonly, i8, i64, i1 immarg)

Copy link
Member

@michalpaszkowski michalpaszkowski left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@michalpaszkowski
Copy link
Member

For unknown reason (to me), the previously (and still currently) passing OpenCL CTS test relationals/shuffle_array_cast has significantly shorter run time after this PR (~50% difference).

@VyacheslavLevytskyy VyacheslavLevytskyy merged commit f6aa508 into llvm:main Jun 25, 2024
8 checks passed
AlexisPerry pushed a commit to llvm-project-tlp/llvm-project that referenced this pull request Jul 9, 2024
…lvm#96514)

This PR fixes llvm#96513.

The way of creation of array type constant was incorrect: instead of
creating [1, 1, 1] or [1, 1, 1, 1, 1, ....] constants, the same [1]
constant was always created, substituting original composite constants.
This in its turn led to a situation when only one of constants might
exist in the code without emitting invalid code, the second constant
would be eventually rewritten to the first constant, because a key to
address both was an array of a single element (like [1]).

This PR fixes the issue and purges from the code unneeded copy/pasted
clone of the function that creates an array constant.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
3 participants