Skip to content

Commit 557e327

Browse files
improve how lowering of formal arguments interprets a value of 'kernel_arg_type'
1 parent 2e2b6b5 commit 557e327

File tree

3 files changed

+26
-20
lines changed

3 files changed

+26
-20
lines changed

llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -157,22 +157,23 @@ static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
157157
isSpecialOpaqueType(OriginalArgType))
158158
return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual);
159159

160-
MDString *MDKernelArgType = getOCLKernelArgType(F, ArgIdx);
161-
if (!MDKernelArgType || (!MDKernelArgType->getString().ends_with("*") &&
162-
!MDKernelArgType->getString().ends_with("_t")))
163-
return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual);
164-
165-
if (MDKernelArgType->getString().ends_with("*"))
166-
return GR->getOrCreateSPIRVTypeByName(
167-
MDKernelArgType->getString(), MIRBuilder,
168-
addressSpaceToStorageClass(OriginalArgType->getPointerAddressSpace()));
169-
170-
if (MDKernelArgType->getString().ends_with("_t"))
171-
return GR->getOrCreateSPIRVTypeByName(
172-
"opencl." + MDKernelArgType->getString().str(), MIRBuilder,
173-
SPIRV::StorageClass::Function, ArgAccessQual);
174-
175-
llvm_unreachable("Unable to recognize argument type name.");
160+
SPIRVType *ResArgType = nullptr;
161+
if (MDString *MDKernelArgType = getOCLKernelArgType(F, ArgIdx)) {
162+
StringRef MDTypeStr = MDKernelArgType->getString();
163+
if (MDTypeStr.ends_with("*")) {
164+
ResArgType = GR->getOrCreateSPIRVTypeByName(
165+
MDTypeStr, MIRBuilder,
166+
addressSpaceToStorageClass(
167+
OriginalArgType->getPointerAddressSpace()));
168+
} else if (MDTypeStr.ends_with("_t")) {
169+
ResArgType = GR->getOrCreateSPIRVTypeByName(
170+
"opencl." + MDTypeStr.str(), MIRBuilder,
171+
SPIRV::StorageClass::Function, ArgAccessQual);
172+
}
173+
}
174+
return ResArgType ? ResArgType
175+
: GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder,
176+
ArgAccessQual);
176177
}
177178

178179
static bool isEntryPoint(const Function &F) {

llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -443,8 +443,9 @@ Register SPIRVGlobalRegistry::buildConstantSampler(
443443
SPIRVType *SampTy;
444444
if (SpvType)
445445
SampTy = getOrCreateSPIRVType(getTypeForSPIRVType(SpvType), MIRBuilder);
446-
else
447-
SampTy = getOrCreateSPIRVTypeByName("opencl.sampler_t", MIRBuilder);
446+
else if ((SampTy = getOrCreateSPIRVTypeByName("opencl.sampler_t",
447+
MIRBuilder)) == nullptr)
448+
report_fatal_error("Unable to recognize SPIRV type name: opencl.sampler_t");
448449

449450
auto Sampler =
450451
ResReg.isValid()
@@ -941,6 +942,7 @@ SPIRVGlobalRegistry::checkSpecialInstr(const SPIRV::SpecialTypeDescriptor &TD,
941942
return nullptr;
942943
}
943944

945+
// Returns nullptr if unable to recognize SPIRV type name
944946
// TODO: maybe use tablegen to implement this.
945947
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVTypeByName(
946948
StringRef TypeStr, MachineIRBuilder &MIRBuilder,
@@ -992,8 +994,10 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVTypeByName(
992994
} else if (TypeStr.starts_with("double")) {
993995
Ty = Type::getDoubleTy(Ctx);
994996
TypeStr = TypeStr.substr(strlen("double"));
995-
} else
996-
llvm_unreachable("Unable to recognize SPIRV type name.");
997+
} else {
998+
// Unable to recognize SPIRV type name
999+
return nullptr;
1000+
}
9971001

9981002
auto SpirvTy = getOrCreateSPIRVType(Ty, MIRBuilder, AQ);
9991003

llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ class SPIRVGlobalRegistry {
138138

139139
// Either generate a new OpTypeXXX instruction or return an existing one
140140
// corresponding to the given string containing the name of the builtin type.
141+
// Return nullptr if unable to recognize SPIRV type name from `TypeStr`.
141142
SPIRVType *getOrCreateSPIRVTypeByName(
142143
StringRef TypeStr, MachineIRBuilder &MIRBuilder,
143144
SPIRV::StorageClass::StorageClass SC = SPIRV::StorageClass::Function,

0 commit comments

Comments
 (0)