Skip to content

Commit f6aa508

Browse files
[SPIR-V]: Fix creation of constants of array types in SPIRV Backend (#96514)
This PR fixes #96513. The way of creation of array type constant was incorrect: instead of creating [1, 1, 1] or [1, 1, 1, 1, 1, ....] constants, the same [1] constant was always created, substituting original composite constants. This in its turn led to a situation when only one of constants might exist in the code without emitting invalid code, the second constant would be eventually rewritten to the first constant, because a key to address both was an array of a single element (like [1]). This PR fixes the issue and purges from the code unneeded copy/pasted clone of the function that creates an array constant.
1 parent 8395f9c commit f6aa508

File tree

6 files changed

+117
-35
lines changed

6 files changed

+117
-35
lines changed

llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1972,7 +1972,10 @@ static bool buildNDRange(const SPIRV::IncomingCall *Call,
19721972
.addDef(GlobalWorkSize)
19731973
.addUse(GR->getSPIRVTypeID(SpvFieldTy))
19741974
.addUse(GWSPtr);
1975-
Const = GR->getOrCreateConsIntArray(0, MIRBuilder, SpvFieldTy);
1975+
const SPIRVSubtarget &ST =
1976+
cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
1977+
Const = GR->getOrCreateConstIntArray(0, Size, *MIRBuilder.getInsertPt(),
1978+
SpvFieldTy, *ST.getInstrInfo());
19761979
} else {
19771980
Const = GR->buildConstantInt(0, MIRBuilder, SpvTy);
19781981
}

llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp

Lines changed: 17 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -394,7 +394,7 @@ Register SPIRVGlobalRegistry::getOrCreateCompositeOrNull(
394394
Constant *Val, MachineInstr &I, SPIRVType *SpvType,
395395
const SPIRVInstrInfo &TII, Constant *CA, unsigned BitWidth,
396396
unsigned ElemCnt, bool ZeroAsNull) {
397-
// Find a constant vector in DT or build a new one.
397+
// Find a constant vector or array in DT or build a new one.
398398
Register Res = DT.find(CA, CurMF);
399399
// If no values are attached, the composite is null constant.
400400
bool IsNull = Val->isNullValue() && ZeroAsNull;
@@ -474,20 +474,28 @@ Register SPIRVGlobalRegistry::getOrCreateConstVector(APFloat Val,
474474
ZeroAsNull);
475475
}
476476

477-
Register
478-
SPIRVGlobalRegistry::getOrCreateConsIntArray(uint64_t Val, MachineInstr &I,
479-
SPIRVType *SpvType,
480-
const SPIRVInstrInfo &TII) {
477+
Register SPIRVGlobalRegistry::getOrCreateConstIntArray(
478+
uint64_t Val, size_t Num, MachineInstr &I, SPIRVType *SpvType,
479+
const SPIRVInstrInfo &TII) {
481480
const Type *LLVMTy = getTypeForSPIRVType(SpvType);
482481
assert(LLVMTy->isArrayTy());
483482
const ArrayType *LLVMArrTy = cast<ArrayType>(LLVMTy);
484483
Type *LLVMBaseTy = LLVMArrTy->getElementType();
485-
auto *ConstInt = ConstantInt::get(LLVMBaseTy, Val);
486-
auto *ConstArr =
487-
ConstantArray::get(const_cast<ArrayType *>(LLVMArrTy), {ConstInt});
484+
Constant *CI = ConstantInt::get(LLVMBaseTy, Val);
488485
SPIRVType *SpvBaseTy = getSPIRVTypeForVReg(SpvType->getOperand(1).getReg());
489486
unsigned BW = getScalarOrVectorBitWidth(SpvBaseTy);
490-
return getOrCreateCompositeOrNull(ConstInt, I, SpvType, TII, ConstArr, BW,
487+
// The following is reasonably unique key that is better that [Val]. The naive
488+
// alternative would be something along the lines of:
489+
// SmallVector<Constant *> NumCI(Num, CI);
490+
// Constant *UniqueKey =
491+
// ConstantArray::get(const_cast<ArrayType*>(LLVMArrTy), NumCI);
492+
// that would be a truly unique but dangerous key, because it could lead to
493+
// the creation of constants of arbitrary length (that is, the parameter of
494+
// memset) which were missing in the original module.
495+
Constant *UniqueKey = ConstantStruct::getAnon(
496+
{PoisonValue::get(const_cast<ArrayType *>(LLVMArrTy)),
497+
ConstantInt::get(LLVMBaseTy, Val), ConstantInt::get(LLVMBaseTy, Num)});
498+
return getOrCreateCompositeOrNull(CI, I, SpvType, TII, UniqueKey, BW,
491499
LLVMArrTy->getNumElements());
492500
}
493501

@@ -545,24 +553,6 @@ SPIRVGlobalRegistry::getOrCreateConsIntVector(uint64_t Val,
545553
SpvType->getOperand(2).getImm());
546554
}
547555

548-
Register
549-
SPIRVGlobalRegistry::getOrCreateConsIntArray(uint64_t Val,
550-
MachineIRBuilder &MIRBuilder,
551-
SPIRVType *SpvType, bool EmitIR) {
552-
const Type *LLVMTy = getTypeForSPIRVType(SpvType);
553-
assert(LLVMTy->isArrayTy());
554-
const ArrayType *LLVMArrTy = cast<ArrayType>(LLVMTy);
555-
Type *LLVMBaseTy = LLVMArrTy->getElementType();
556-
const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val);
557-
auto ConstArr =
558-
ConstantArray::get(const_cast<ArrayType *>(LLVMArrTy), {ConstInt});
559-
SPIRVType *SpvBaseTy = getSPIRVTypeForVReg(SpvType->getOperand(1).getReg());
560-
unsigned BW = getScalarOrVectorBitWidth(SpvBaseTy);
561-
return getOrCreateIntCompositeOrNull(Val, MIRBuilder, SpvType, EmitIR,
562-
ConstArr, BW,
563-
LLVMArrTy->getNumElements());
564-
}
565-
566556
Register
567557
SPIRVGlobalRegistry::getOrCreateConstNullPtr(MachineIRBuilder &MIRBuilder,
568558
SPIRVType *SpvType) {

llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -457,13 +457,11 @@ class SPIRVGlobalRegistry {
457457
Register getOrCreateConstVector(APFloat Val, MachineInstr &I,
458458
SPIRVType *SpvType, const SPIRVInstrInfo &TII,
459459
bool ZeroAsNull = true);
460-
Register getOrCreateConsIntArray(uint64_t Val, MachineInstr &I,
461-
SPIRVType *SpvType,
462-
const SPIRVInstrInfo &TII);
460+
Register getOrCreateConstIntArray(uint64_t Val, size_t Num, MachineInstr &I,
461+
SPIRVType *SpvType,
462+
const SPIRVInstrInfo &TII);
463463
Register getOrCreateConsIntVector(uint64_t Val, MachineIRBuilder &MIRBuilder,
464464
SPIRVType *SpvType, bool EmitIR = true);
465-
Register getOrCreateConsIntArray(uint64_t Val, MachineIRBuilder &MIRBuilder,
466-
SPIRVType *SpvType, bool EmitIR = true);
467465
Register getOrCreateConstNullPtr(MachineIRBuilder &MIRBuilder,
468466
SPIRVType *SpvType);
469467
Register buildConstantSampler(Register Res, unsigned AddrMode, unsigned Param,

llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -846,7 +846,7 @@ bool SPIRVInstructionSelector::selectMemOperation(Register ResVReg,
846846
unsigned Num = getIConstVal(I.getOperand(2).getReg(), MRI);
847847
SPIRVType *ValTy = GR.getOrCreateSPIRVIntegerType(8, I, TII);
848848
SPIRVType *ArrTy = GR.getOrCreateSPIRVArrayType(ValTy, Num, I, TII);
849-
Register Const = GR.getOrCreateConsIntArray(Val, I, ArrTy, TII);
849+
Register Const = GR.getOrCreateConstIntArray(Val, Num, I, ArrTy, TII);
850850
SPIRVType *VarTy = GR.getOrCreateSPIRVPointerType(
851851
ArrTy, I, TII, SPIRV::StorageClass::UniformConstant);
852852
// TODO: check if we have such GV, add init, use buildGlobalVariable.

llvm/lib/Target/SPIRV/SPIRVUtils.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,11 @@ SPIRV::MemorySemantics::MemorySemantics getMemSemantics(AtomicOrdering Ord) {
253253

254254
MachineInstr *getDefInstrMaybeConstant(Register &ConstReg,
255255
const MachineRegisterInfo *MRI) {
256-
MachineInstr *ConstInstr = MRI->getVRegDef(ConstReg);
256+
MachineInstr *MI = MRI->getVRegDef(ConstReg);
257+
MachineInstr *ConstInstr =
258+
MI->getOpcode() == SPIRV::G_TRUNC || MI->getOpcode() == SPIRV::G_ZEXT
259+
? MRI->getVRegDef(MI->getOperand(1).getReg())
260+
: MI;
257261
if (auto *GI = dyn_cast<GIntrinsic>(ConstInstr)) {
258262
if (GI->is(Intrinsic::spv_track_constant)) {
259263
ConstReg = ConstInstr->getOperand(2).getReg();
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
2+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
3+
4+
; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
5+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}
6+
7+
; CHECK-SPIRV-DAG: %[[#Char:]] = OpTypeInt 8 0
8+
; CHECK-SPIRV-DAG: %[[#Long:]] = OpTypeInt 64 0
9+
; CHECK-SPIRV-DAG: %[[#Int:]] = OpTypeInt 32 0
10+
; CHECK-SPIRV-DAG: %[[#Size3:]] = OpConstant %[[#Int]] 3
11+
; CHECK-SPIRV-DAG: %[[#Arr3:]] = OpTypeArray %[[#Char]] %[[#Size3]]
12+
; CHECK-SPIRV-DAG: %[[#Size16:]] = OpConstant %[[#Int]] 16
13+
; CHECK-SPIRV-DAG: %[[#Arr16:]] = OpTypeArray %[[#Char]] %[[#Size16]]
14+
; CHECK-SPIRV-DAG: %[[#Const3:]] = OpConstant %[[#Long]] 3
15+
; CHECK-SPIRV-DAG: %[[#One:]] = OpConstant %[[#Char]] 1
16+
; CHECK-SPIRV-DAG: %[[#One3:]] = OpConstantComposite %[[#Arr3]] %[[#One]] %[[#One]] %[[#One]]
17+
; CHECK-SPIRV-DAG: %[[#Zero3:]] = OpConstantNull %[[#Arr3]]
18+
; CHECK-SPIRV-DAG: %[[#Const16:]] = OpConstant %[[#Long]] 16
19+
; CHECK-SPIRV-DAG: %[[#One16:]] = OpConstantComposite %[[#Arr16]] %[[#One]] %[[#One]] %[[#One]] %[[#One]] %[[#One]] %[[#One]] %[[#One]] %[[#One]] %[[#One]] %[[#One]] %[[#One]] %[[#One]] %[[#One]] %[[#One]] %[[#One]] %[[#One]]
20+
; CHECK-SPIRV-DAG: %[[#Zero16:]] = OpConstantNull %[[#Arr16]]
21+
22+
; The first set of functions.
23+
; CHECK-SPIRV-DAG: %[[#PtrArr3:]] = OpTypePointer UniformConstant %[[#Arr3]]
24+
; CHECK-SPIRV-DAG: OpVariable %[[#PtrArr3]] UniformConstant %[[#One3]]
25+
; CHECK-SPIRV-DAG: OpVariable %[[#PtrArr3]] UniformConstant %[[#Zero3]]
26+
; CHECK-SPIRV-DAG: %[[#PtrArr16:]] = OpTypePointer UniformConstant %[[#Arr16]]
27+
; CHECK-SPIRV-DAG: OpVariable %[[#PtrArr16]] UniformConstant %[[#One16]]
28+
; CHECK-SPIRV-DAG: OpVariable %[[#PtrArr16]] UniformConstant %[[#Zero16]]
29+
30+
; The second set of functions.
31+
; CHECK-SPIRV-DAG: OpVariable %[[#PtrArr3]] UniformConstant %[[#One3]]
32+
; CHECK-SPIRV-DAG: OpVariable %[[#PtrArr3]] UniformConstant %[[#Zero3]]
33+
; CHECK-SPIRV-DAG: OpVariable %[[#PtrArr16]] UniformConstant %[[#One16]]
34+
; CHECK-SPIRV-DAG: OpVariable %[[#PtrArr16]] UniformConstant %[[#Zero16]]
35+
36+
%Vec3 = type { <3 x i8> }
37+
%Vec16 = type { <16 x i8> }
38+
39+
; CHECK-SPIRV: OpFunction
40+
; CHECK-SPIRV: OpCopyMemorySized %[[#]] %[[#]] %[[#Const3]] Aligned 4
41+
; CHECK-SPIRV: OpCopyMemorySized %[[#]] %[[#]] %[[#Const3]] Aligned 4
42+
; CHECK-SPIRV: OpFunctionEnd
43+
define spir_kernel void @foo(ptr addrspace(1) noundef align 16 %arg) {
44+
%a1 = getelementptr inbounds %Vec3, ptr addrspace(1) %arg, i64 1
45+
call void @llvm.memset.p1.i64(ptr addrspace(1) align 4 %a1, i8 0, i64 3, i1 false)
46+
%a2 = getelementptr inbounds %Vec3, ptr addrspace(1) %arg, i64 1
47+
call void @llvm.memset.p1.i64(ptr addrspace(1) align 4 %a2, i8 1, i64 3, i1 false)
48+
ret void
49+
}
50+
51+
; CHECK-SPIRV: OpFunction
52+
; CHECK-SPIRV: OpCopyMemorySized %[[#]] %[[#]] %[[#Const16]] Aligned 4
53+
; CHECK-SPIRV: OpCopyMemorySized %[[#]] %[[#]] %[[#Const16]] Aligned 4
54+
; CHECK-SPIRV: OpFunctionEnd
55+
define spir_kernel void @bar(ptr addrspace(1) noundef align 16 %arg) {
56+
%a1 = getelementptr inbounds %Vec16, ptr addrspace(1) %arg, i64 1
57+
call void @llvm.memset.p1.i64(ptr addrspace(1) align 4 %a1, i8 0, i64 16, i1 false)
58+
%a2 = getelementptr inbounds %Vec16, ptr addrspace(1) %arg, i64 1
59+
call void @llvm.memset.p1.i64(ptr addrspace(1) align 4 %a2, i8 1, i64 16, i1 false)
60+
ret void
61+
}
62+
63+
; CHECK-SPIRV: OpFunction
64+
; CHECK-SPIRV: OpCopyMemorySized %[[#]] %[[#]] %[[#Const3]] Aligned 4
65+
; CHECK-SPIRV: OpCopyMemorySized %[[#]] %[[#]] %[[#Const3]] Aligned 4
66+
; CHECK-SPIRV: OpFunctionEnd
67+
define spir_kernel void @foo_2(ptr addrspace(1) noundef align 16 %arg) {
68+
%a1 = getelementptr inbounds %Vec3, ptr addrspace(1) %arg, i64 1
69+
call void @llvm.memset.p1.i64(ptr addrspace(1) align 4 %a1, i8 0, i64 3, i1 false)
70+
%a2 = getelementptr inbounds %Vec3, ptr addrspace(1) %arg, i64 1
71+
call void @llvm.memset.p1.i64(ptr addrspace(1) align 4 %a2, i8 1, i64 3, i1 false)
72+
ret void
73+
}
74+
75+
; CHECK-SPIRV: OpFunction
76+
; CHECK-SPIRV: OpCopyMemorySized %[[#]] %[[#]] %[[#Const16]] Aligned 4
77+
; CHECK-SPIRV: OpCopyMemorySized %[[#]] %[[#]] %[[#Const16]] Aligned 4
78+
; CHECK-SPIRV: OpFunctionEnd
79+
define spir_kernel void @bar_2(ptr addrspace(1) noundef align 16 %arg) {
80+
%a1 = getelementptr inbounds %Vec16, ptr addrspace(1) %arg, i64 1
81+
call void @llvm.memset.p1.i64(ptr addrspace(1) align 4 %a1, i8 0, i64 16, i1 false)
82+
%a2 = getelementptr inbounds %Vec16, ptr addrspace(1) %arg, i64 1
83+
call void @llvm.memset.p1.i64(ptr addrspace(1) align 4 %a2, i8 1, i64 16, i1 false)
84+
ret void
85+
}
86+
87+
declare void @llvm.memset.p1.i64(ptr addrspace(1) nocapture writeonly, i8, i64, i1 immarg)

0 commit comments

Comments
 (0)