Skip to content

[SPIR-V]: Fix creation of constants of array types in SPIRV Backend #96514

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1972,7 +1972,10 @@ static bool buildNDRange(const SPIRV::IncomingCall *Call,
.addDef(GlobalWorkSize)
.addUse(GR->getSPIRVTypeID(SpvFieldTy))
.addUse(GWSPtr);
Const = GR->getOrCreateConsIntArray(0, MIRBuilder, SpvFieldTy);
const SPIRVSubtarget &ST =
cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
Const = GR->getOrCreateConstIntArray(0, Size, *MIRBuilder.getInsertPt(),
SpvFieldTy, *ST.getInstrInfo());
} else {
Const = GR->buildConstantInt(0, MIRBuilder, SpvTy);
}
Expand Down
44 changes: 17 additions & 27 deletions llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ Register SPIRVGlobalRegistry::getOrCreateCompositeOrNull(
Constant *Val, MachineInstr &I, SPIRVType *SpvType,
const SPIRVInstrInfo &TII, Constant *CA, unsigned BitWidth,
unsigned ElemCnt, bool ZeroAsNull) {
// Find a constant vector in DT or build a new one.
// Find a constant vector or array in DT or build a new one.
Register Res = DT.find(CA, CurMF);
// If no values are attached, the composite is null constant.
bool IsNull = Val->isNullValue() && ZeroAsNull;
Expand Down Expand Up @@ -474,20 +474,28 @@ Register SPIRVGlobalRegistry::getOrCreateConstVector(APFloat Val,
ZeroAsNull);
}

Register
SPIRVGlobalRegistry::getOrCreateConsIntArray(uint64_t Val, MachineInstr &I,
SPIRVType *SpvType,
const SPIRVInstrInfo &TII) {
Register SPIRVGlobalRegistry::getOrCreateConstIntArray(
uint64_t Val, size_t Num, MachineInstr &I, SPIRVType *SpvType,
const SPIRVInstrInfo &TII) {
const Type *LLVMTy = getTypeForSPIRVType(SpvType);
assert(LLVMTy->isArrayTy());
const ArrayType *LLVMArrTy = cast<ArrayType>(LLVMTy);
Type *LLVMBaseTy = LLVMArrTy->getElementType();
auto *ConstInt = ConstantInt::get(LLVMBaseTy, Val);
auto *ConstArr =
ConstantArray::get(const_cast<ArrayType *>(LLVMArrTy), {ConstInt});
Constant *CI = ConstantInt::get(LLVMBaseTy, Val);
SPIRVType *SpvBaseTy = getSPIRVTypeForVReg(SpvType->getOperand(1).getReg());
unsigned BW = getScalarOrVectorBitWidth(SpvBaseTy);
return getOrCreateCompositeOrNull(ConstInt, I, SpvType, TII, ConstArr, BW,
// The following is reasonably unique key that is better that [Val]. The naive
// alternative would be something along the lines of:
// SmallVector<Constant *> NumCI(Num, CI);
// Constant *UniqueKey =
// ConstantArray::get(const_cast<ArrayType*>(LLVMArrTy), NumCI);
// that would be a truly unique but dangerous key, because it could lead to
// the creation of constants of arbitrary length (that is, the parameter of
// memset) which were missing in the original module.
Constant *UniqueKey = ConstantStruct::getAnon(
{PoisonValue::get(const_cast<ArrayType *>(LLVMArrTy)),
ConstantInt::get(LLVMBaseTy, Val), ConstantInt::get(LLVMBaseTy, Num)});
return getOrCreateCompositeOrNull(CI, I, SpvType, TII, UniqueKey, BW,
LLVMArrTy->getNumElements());
}

Expand Down Expand Up @@ -545,24 +553,6 @@ SPIRVGlobalRegistry::getOrCreateConsIntVector(uint64_t Val,
SpvType->getOperand(2).getImm());
}

Register
SPIRVGlobalRegistry::getOrCreateConsIntArray(uint64_t Val,
MachineIRBuilder &MIRBuilder,
SPIRVType *SpvType, bool EmitIR) {
const Type *LLVMTy = getTypeForSPIRVType(SpvType);
assert(LLVMTy->isArrayTy());
const ArrayType *LLVMArrTy = cast<ArrayType>(LLVMTy);
Type *LLVMBaseTy = LLVMArrTy->getElementType();
const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val);
auto ConstArr =
ConstantArray::get(const_cast<ArrayType *>(LLVMArrTy), {ConstInt});
SPIRVType *SpvBaseTy = getSPIRVTypeForVReg(SpvType->getOperand(1).getReg());
unsigned BW = getScalarOrVectorBitWidth(SpvBaseTy);
return getOrCreateIntCompositeOrNull(Val, MIRBuilder, SpvType, EmitIR,
ConstArr, BW,
LLVMArrTy->getNumElements());
}

Register
SPIRVGlobalRegistry::getOrCreateConstNullPtr(MachineIRBuilder &MIRBuilder,
SPIRVType *SpvType) {
Expand Down
8 changes: 3 additions & 5 deletions llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
Original file line number Diff line number Diff line change
Expand Up @@ -457,13 +457,11 @@ class SPIRVGlobalRegistry {
Register getOrCreateConstVector(APFloat Val, MachineInstr &I,
SPIRVType *SpvType, const SPIRVInstrInfo &TII,
bool ZeroAsNull = true);
Register getOrCreateConsIntArray(uint64_t Val, MachineInstr &I,
SPIRVType *SpvType,
const SPIRVInstrInfo &TII);
Register getOrCreateConstIntArray(uint64_t Val, size_t Num, MachineInstr &I,
SPIRVType *SpvType,
const SPIRVInstrInfo &TII);
Register getOrCreateConsIntVector(uint64_t Val, MachineIRBuilder &MIRBuilder,
SPIRVType *SpvType, bool EmitIR = true);
Register getOrCreateConsIntArray(uint64_t Val, MachineIRBuilder &MIRBuilder,
SPIRVType *SpvType, bool EmitIR = true);
Register getOrCreateConstNullPtr(MachineIRBuilder &MIRBuilder,
SPIRVType *SpvType);
Register buildConstantSampler(Register Res, unsigned AddrMode, unsigned Param,
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -846,7 +846,7 @@ bool SPIRVInstructionSelector::selectMemOperation(Register ResVReg,
unsigned Num = getIConstVal(I.getOperand(2).getReg(), MRI);
SPIRVType *ValTy = GR.getOrCreateSPIRVIntegerType(8, I, TII);
SPIRVType *ArrTy = GR.getOrCreateSPIRVArrayType(ValTy, Num, I, TII);
Register Const = GR.getOrCreateConsIntArray(Val, I, ArrTy, TII);
Register Const = GR.getOrCreateConstIntArray(Val, Num, I, ArrTy, TII);
SPIRVType *VarTy = GR.getOrCreateSPIRVPointerType(
ArrTy, I, TII, SPIRV::StorageClass::UniformConstant);
// TODO: check if we have such GV, add init, use buildGlobalVariable.
Expand Down
6 changes: 5 additions & 1 deletion llvm/lib/Target/SPIRV/SPIRVUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,11 @@ SPIRV::MemorySemantics::MemorySemantics getMemSemantics(AtomicOrdering Ord) {

MachineInstr *getDefInstrMaybeConstant(Register &ConstReg,
const MachineRegisterInfo *MRI) {
MachineInstr *ConstInstr = MRI->getVRegDef(ConstReg);
MachineInstr *MI = MRI->getVRegDef(ConstReg);
MachineInstr *ConstInstr =
MI->getOpcode() == SPIRV::G_TRUNC || MI->getOpcode() == SPIRV::G_ZEXT
? MRI->getVRegDef(MI->getOperand(1).getReg())
: MI;
if (auto *GI = dyn_cast<GIntrinsic>(ConstInstr)) {
if (GI->is(Intrinsic::spv_track_constant)) {
ConstReg = ConstInstr->getOperand(2).getReg();
Expand Down
87 changes: 87 additions & 0 deletions llvm/test/CodeGen/SPIRV/var-uniform-const.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
; RUN: llc -O0 -mtriple=spirv64-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}

; RUN: llc -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s --check-prefix=CHECK-SPIRV
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv32-unknown-unknown %s -o - -filetype=obj | spirv-val %}

; CHECK-SPIRV-DAG: %[[#Char:]] = OpTypeInt 8 0
; CHECK-SPIRV-DAG: %[[#Long:]] = OpTypeInt 64 0
; CHECK-SPIRV-DAG: %[[#Int:]] = OpTypeInt 32 0
; CHECK-SPIRV-DAG: %[[#Size3:]] = OpConstant %[[#Int]] 3
; CHECK-SPIRV-DAG: %[[#Arr3:]] = OpTypeArray %[[#Char]] %[[#Size3]]
; CHECK-SPIRV-DAG: %[[#Size16:]] = OpConstant %[[#Int]] 16
; CHECK-SPIRV-DAG: %[[#Arr16:]] = OpTypeArray %[[#Char]] %[[#Size16]]
; CHECK-SPIRV-DAG: %[[#Const3:]] = OpConstant %[[#Long]] 3
; CHECK-SPIRV-DAG: %[[#One:]] = OpConstant %[[#Char]] 1
; CHECK-SPIRV-DAG: %[[#One3:]] = OpConstantComposite %[[#Arr3]] %[[#One]] %[[#One]] %[[#One]]
; CHECK-SPIRV-DAG: %[[#Zero3:]] = OpConstantNull %[[#Arr3]]
; CHECK-SPIRV-DAG: %[[#Const16:]] = OpConstant %[[#Long]] 16
; CHECK-SPIRV-DAG: %[[#One16:]] = OpConstantComposite %[[#Arr16]] %[[#One]] %[[#One]] %[[#One]] %[[#One]] %[[#One]] %[[#One]] %[[#One]] %[[#One]] %[[#One]] %[[#One]] %[[#One]] %[[#One]] %[[#One]] %[[#One]] %[[#One]] %[[#One]]
; CHECK-SPIRV-DAG: %[[#Zero16:]] = OpConstantNull %[[#Arr16]]

; The first set of functions.
; CHECK-SPIRV-DAG: %[[#PtrArr3:]] = OpTypePointer UniformConstant %[[#Arr3]]
; CHECK-SPIRV-DAG: OpVariable %[[#PtrArr3]] UniformConstant %[[#One3]]
; CHECK-SPIRV-DAG: OpVariable %[[#PtrArr3]] UniformConstant %[[#Zero3]]
; CHECK-SPIRV-DAG: %[[#PtrArr16:]] = OpTypePointer UniformConstant %[[#Arr16]]
; CHECK-SPIRV-DAG: OpVariable %[[#PtrArr16]] UniformConstant %[[#One16]]
; CHECK-SPIRV-DAG: OpVariable %[[#PtrArr16]] UniformConstant %[[#Zero16]]

; The second set of functions.
; CHECK-SPIRV-DAG: OpVariable %[[#PtrArr3]] UniformConstant %[[#One3]]
; CHECK-SPIRV-DAG: OpVariable %[[#PtrArr3]] UniformConstant %[[#Zero3]]
; CHECK-SPIRV-DAG: OpVariable %[[#PtrArr16]] UniformConstant %[[#One16]]
; CHECK-SPIRV-DAG: OpVariable %[[#PtrArr16]] UniformConstant %[[#Zero16]]

%Vec3 = type { <3 x i8> }
%Vec16 = type { <16 x i8> }

; CHECK-SPIRV: OpFunction
; CHECK-SPIRV: OpCopyMemorySized %[[#]] %[[#]] %[[#Const3]] Aligned 4
; CHECK-SPIRV: OpCopyMemorySized %[[#]] %[[#]] %[[#Const3]] Aligned 4
; CHECK-SPIRV: OpFunctionEnd
define spir_kernel void @foo(ptr addrspace(1) noundef align 16 %arg) {
%a1 = getelementptr inbounds %Vec3, ptr addrspace(1) %arg, i64 1
call void @llvm.memset.p1.i64(ptr addrspace(1) align 4 %a1, i8 0, i64 3, i1 false)
%a2 = getelementptr inbounds %Vec3, ptr addrspace(1) %arg, i64 1
call void @llvm.memset.p1.i64(ptr addrspace(1) align 4 %a2, i8 1, i64 3, i1 false)
ret void
}

; CHECK-SPIRV: OpFunction
; CHECK-SPIRV: OpCopyMemorySized %[[#]] %[[#]] %[[#Const16]] Aligned 4
; CHECK-SPIRV: OpCopyMemorySized %[[#]] %[[#]] %[[#Const16]] Aligned 4
; CHECK-SPIRV: OpFunctionEnd
define spir_kernel void @bar(ptr addrspace(1) noundef align 16 %arg) {
%a1 = getelementptr inbounds %Vec16, ptr addrspace(1) %arg, i64 1
call void @llvm.memset.p1.i64(ptr addrspace(1) align 4 %a1, i8 0, i64 16, i1 false)
%a2 = getelementptr inbounds %Vec16, ptr addrspace(1) %arg, i64 1
call void @llvm.memset.p1.i64(ptr addrspace(1) align 4 %a2, i8 1, i64 16, i1 false)
ret void
}

; CHECK-SPIRV: OpFunction
; CHECK-SPIRV: OpCopyMemorySized %[[#]] %[[#]] %[[#Const3]] Aligned 4
; CHECK-SPIRV: OpCopyMemorySized %[[#]] %[[#]] %[[#Const3]] Aligned 4
; CHECK-SPIRV: OpFunctionEnd
define spir_kernel void @foo_2(ptr addrspace(1) noundef align 16 %arg) {
%a1 = getelementptr inbounds %Vec3, ptr addrspace(1) %arg, i64 1
call void @llvm.memset.p1.i64(ptr addrspace(1) align 4 %a1, i8 0, i64 3, i1 false)
%a2 = getelementptr inbounds %Vec3, ptr addrspace(1) %arg, i64 1
call void @llvm.memset.p1.i64(ptr addrspace(1) align 4 %a2, i8 1, i64 3, i1 false)
ret void
}

; CHECK-SPIRV: OpFunction
; CHECK-SPIRV: OpCopyMemorySized %[[#]] %[[#]] %[[#Const16]] Aligned 4
; CHECK-SPIRV: OpCopyMemorySized %[[#]] %[[#]] %[[#Const16]] Aligned 4
; CHECK-SPIRV: OpFunctionEnd
define spir_kernel void @bar_2(ptr addrspace(1) noundef align 16 %arg) {
%a1 = getelementptr inbounds %Vec16, ptr addrspace(1) %arg, i64 1
call void @llvm.memset.p1.i64(ptr addrspace(1) align 4 %a1, i8 0, i64 16, i1 false)
%a2 = getelementptr inbounds %Vec16, ptr addrspace(1) %arg, i64 1
call void @llvm.memset.p1.i64(ptr addrspace(1) align 4 %a2, i8 1, i64 16, i1 false)
ret void
}

declare void @llvm.memset.p1.i64(ptr addrspace(1) nocapture writeonly, i8, i64, i1 immarg)
Loading