Skip to content

Commit 489db65

Browse files
[SPIR-V] Emit Alignment decoration for alloca instructions and improve type inference (#118520)
This PR is to fix the following issues: * the SPIR-V Backend didn't generate Alignment decoration for alloca instructions, * we need to use types from demangled function declarations to specify types for opaque pointers.
1 parent b569ec6 commit 489db65

File tree

12 files changed

+271
-114
lines changed

12 files changed

+271
-114
lines changed

llvm/include/llvm/IR/IntrinsicsSPIRV.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ let TargetPrefix = "spv" in {
3636
def int_spv_selection_merge : Intrinsic<[], [llvm_vararg_ty]>;
3737
def int_spv_cmpxchg : Intrinsic<[llvm_i32_ty], [llvm_any_ty, llvm_vararg_ty]>;
3838
def int_spv_unreachable : Intrinsic<[], []>;
39-
def int_spv_alloca : Intrinsic<[llvm_any_ty], []>;
40-
def int_spv_alloca_array : Intrinsic<[llvm_any_ty], [llvm_anyint_ty]>;
39+
def int_spv_alloca : Intrinsic<[llvm_any_ty], [llvm_i8_ty], [ImmArg<ArgIndex<0>>]>;
40+
def int_spv_alloca_array : Intrinsic<[llvm_any_ty], [llvm_anyint_ty, llvm_i8_ty], [ImmArg<ArgIndex<1>>]>;
4141
def int_spv_undef : Intrinsic<[llvm_i32_ty], []>;
4242
def int_spv_inline_asm : Intrinsic<[], [llvm_metadata_ty, llvm_metadata_ty, llvm_vararg_ty]>;
4343

llvm/lib/Target/SPIRV/MCTargetDesc/SPIRVBaseInfo.cpp

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -76,13 +76,15 @@ getSymbolicOperandMnemonic(SPIRV::OperandCategory::OperandCategory Category,
7676
const SPIRV::SymbolicOperand *EnumValueInCategory =
7777
SPIRV::lookupSymbolicOperandByCategory(Category);
7878

79+
auto TableEnd = ArrayRef(SPIRV::SymbolicOperands).end();
7980
while (EnumValueInCategory && EnumValueInCategory->Category == Category) {
8081
if ((EnumValueInCategory->Value != 0) &&
8182
(Value & EnumValueInCategory->Value)) {
8283
Name += Separator + EnumValueInCategory->Mnemonic.str();
8384
Separator = "|";
8485
}
85-
++EnumValueInCategory;
86+
if (++EnumValueInCategory == TableEnd)
87+
break;
8688
}
8789

8890
return Name;
@@ -115,15 +117,16 @@ getSymbolicOperandMaxVersion(SPIRV::OperandCategory::OperandCategory Category,
115117
CapabilityList
116118
getSymbolicOperandCapabilities(SPIRV::OperandCategory::OperandCategory Category,
117119
uint32_t Value) {
120+
CapabilityList Capabilities;
118121
const SPIRV::CapabilityEntry *Capability =
119122
SPIRV::lookupCapabilityByCategoryAndValue(Category, Value);
120-
121-
CapabilityList Capabilities;
123+
auto TableEnd = ArrayRef(SPIRV::CapabilityEntries).end();
122124
while (Capability && Capability->Category == Category &&
123125
Capability->Value == Value) {
124126
Capabilities.push_back(
125127
static_cast<SPIRV::Capability::Capability>(Capability->ReqCapability));
126-
++Capability;
128+
if (++Capability == TableEnd)
129+
break;
127130
}
128131

129132
return Capabilities;
@@ -136,16 +139,15 @@ getCapabilitiesEnabledByExtension(SPIRV::Extension::Extension Extension) {
136139
Extension, SPIRV::OperandCategory::CapabilityOperand);
137140

138141
CapabilityList Capabilities;
142+
auto TableEnd = ArrayRef(SPIRV::ExtensionEntries).end();
139143
while (Entry &&
140144
Entry->Category == SPIRV::OperandCategory::CapabilityOperand) {
141145
// Some capabilities' codes might go not in order.
142-
if (Entry->ReqExtension != Extension) {
143-
++Entry;
144-
continue;
145-
}
146-
Capabilities.push_back(
147-
static_cast<SPIRV::Capability::Capability>(Entry->Value));
148-
++Entry;
146+
if (Entry->ReqExtension == Extension)
147+
Capabilities.push_back(
148+
static_cast<SPIRV::Capability::Capability>(Entry->Value));
149+
if (++Entry == TableEnd)
150+
break;
149151
}
150152

151153
return Capabilities;
@@ -158,11 +160,13 @@ getSymbolicOperandExtensions(SPIRV::OperandCategory::OperandCategory Category,
158160
SPIRV::lookupExtensionByCategoryAndValue(Category, Value);
159161

160162
ExtensionList Extensions;
163+
auto TableEnd = ArrayRef(SPIRV::ExtensionEntries).end();
161164
while (Extension && Extension->Category == Category &&
162165
Extension->Value == Value) {
163166
Extensions.push_back(
164167
static_cast<SPIRV::Extension::Extension>(Extension->ReqExtension));
165-
++Extension;
168+
if (++Extension == TableEnd)
169+
break;
166170
}
167171

168172
return Extensions;

llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2664,16 +2664,7 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
26642664
return false;
26652665
}
26662666

2667-
Type *parseBuiltinCallArgumentBaseType(const StringRef DemangledCall,
2668-
unsigned ArgIdx, LLVMContext &Ctx) {
2669-
SmallVector<StringRef, 10> BuiltinArgsTypeStrs;
2670-
StringRef BuiltinArgs =
2671-
DemangledCall.slice(DemangledCall.find('(') + 1, DemangledCall.find(')'));
2672-
BuiltinArgs.split(BuiltinArgsTypeStrs, ',', -1, false);
2673-
if (ArgIdx >= BuiltinArgsTypeStrs.size())
2674-
return nullptr;
2675-
StringRef TypeStr = BuiltinArgsTypeStrs[ArgIdx].trim();
2676-
2667+
Type *parseBuiltinCallArgumentType(StringRef TypeStr, LLVMContext &Ctx) {
26772668
// Parse strings representing OpenCL builtin types.
26782669
if (hasBuiltinTypePrefix(TypeStr)) {
26792670
// OpenCL builtin types in demangled call strings have the following format:
@@ -2717,6 +2708,29 @@ Type *parseBuiltinCallArgumentBaseType(const StringRef DemangledCall,
27172708
return BaseType;
27182709
}
27192710

2711+
bool parseBuiltinTypeStr(SmallVector<StringRef, 10> &BuiltinArgsTypeStrs,
2712+
const StringRef DemangledCall, LLVMContext &Ctx) {
2713+
auto Pos1 = DemangledCall.find('(');
2714+
if (Pos1 == StringRef::npos)
2715+
return false;
2716+
auto Pos2 = DemangledCall.find(')');
2717+
if (Pos2 == StringRef::npos || Pos1 > Pos2)
2718+
return false;
2719+
DemangledCall.slice(Pos1 + 1, Pos2)
2720+
.split(BuiltinArgsTypeStrs, ',', -1, false);
2721+
return true;
2722+
}
2723+
2724+
Type *parseBuiltinCallArgumentBaseType(const StringRef DemangledCall,
2725+
unsigned ArgIdx, LLVMContext &Ctx) {
2726+
SmallVector<StringRef, 10> BuiltinArgsTypeStrs;
2727+
parseBuiltinTypeStr(BuiltinArgsTypeStrs, DemangledCall, Ctx);
2728+
if (ArgIdx >= BuiltinArgsTypeStrs.size())
2729+
return nullptr;
2730+
StringRef TypeStr = BuiltinArgsTypeStrs[ArgIdx].trim();
2731+
return parseBuiltinCallArgumentType(TypeStr, Ctx);
2732+
}
2733+
27202734
struct BuiltinType {
27212735
StringRef Name;
27222736
uint32_t Opcode;

llvm/lib/Target/SPIRV/SPIRVBuiltins.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ mapBuiltinToOpcode(const StringRef DemangledCall,
5656
/// \p ArgIdx is the index of the argument to parse.
5757
Type *parseBuiltinCallArgumentBaseType(const StringRef DemangledCall,
5858
unsigned ArgIdx, LLVMContext &Ctx);
59+
bool parseBuiltinTypeStr(SmallVector<StringRef, 10> &BuiltinArgsTypeStrs,
60+
const StringRef DemangledCall, LLVMContext &Ctx);
61+
Type *parseBuiltinCallArgumentType(StringRef TypeStr, LLVMContext &Ctx);
5962

6063
/// Translates a string representing a SPIR-V or OpenCL builtin type to a
6164
/// TargetExtType that can be further lowered with lowerBuiltinType().

llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@ bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
316316

317317
if (Arg.hasName())
318318
buildOpName(VRegs[i][0], Arg.getName(), MIRBuilder);
319-
if (isPointerTy(Arg.getType())) {
319+
if (isPointerTyOrWrapper(Arg.getType())) {
320320
auto DerefBytes = static_cast<unsigned>(Arg.getDereferenceableBytes());
321321
if (DerefBytes != 0)
322322
buildOpDecorate(VRegs[i][0], MIRBuilder,

0 commit comments

Comments
 (0)