Skip to content

Commit 9a82d53

Browse files
committed
Fix handling of multiple usage of composite spec constant
Also updated enumeration algorithm, so IDs are not reserved for __spirv_SpecConstantComposite entries anymore.
1 parent 4c9ce32 commit 9a82d53

File tree

3 files changed

+176
-30
lines changed

3 files changed

+176
-30
lines changed

llvm/test/tools/sycl-post-link/composite-spec-constant.ll

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,21 +9,21 @@
99
; CHECK: %[[#NS1:]] = call float @_Z20__spirv_SpecConstantif(i32 [[#ID + 1]], float
1010
; CHECK: %[[#NA0:]] = call %struct._ZTS1A.A @_Z29__spirv_SpecConstantCompositeif(i32 %[[#NS0]], float %[[#NS1]])
1111
;
12-
; CHECK: %[[#NS2:]] = call i32 @_Z20__spirv_SpecConstantii(i32 [[#ID + 3]], i32
13-
; CHECK: %[[#NS3:]] = call float @_Z20__spirv_SpecConstantif(i32 [[#ID + 4]], float
12+
; CHECK: %[[#NS2:]] = call i32 @_Z20__spirv_SpecConstantii(i32 [[#ID + 2]], i32
13+
; CHECK: %[[#NS3:]] = call float @_Z20__spirv_SpecConstantif(i32 [[#ID + 3]], float
1414
; CHECK: %[[#NA1:]] = call %struct._ZTS1A.A @_Z29__spirv_SpecConstantCompositeif(i32 %[[#NS2]], float %[[#NS3]])
1515
;
1616
; CHECK: %[[#NA:]] = call [2 x %struct._ZTS1A.A] @_Z29__spirv_SpecConstantCompositestruct._ZTS1A.Astruct._ZTS1A.A(%struct._ZTS1A.A %[[#NA0]], %struct._ZTS1A.A %[[#NA1]])
1717
;
18-
; CHECK: %[[#B0:]] = call i32 @_Z20__spirv_SpecConstantii(i32 [[#ID + 7]], i32{{.*}})
19-
; CHECK: %[[#B1:]] = call i32 @_Z20__spirv_SpecConstantii(i32 [[#ID + 8]], i32{{.*}})
18+
; CHECK: %[[#B0:]] = call i32 @_Z20__spirv_SpecConstantii(i32 [[#ID + 4]], i32{{.*}})
19+
; CHECK: %[[#B1:]] = call i32 @_Z20__spirv_SpecConstantii(i32 [[#ID + 5]], i32{{.*}})
2020
; CHECK: %[[#BV:]] = call <2 x i32> @_Z29__spirv_SpecConstantCompositeii(i32 %[[#B0]], i32 %[[#B1]])
2121
; CHECK: %[[#B:]] = call %"class._ZTSN2cl4sycl3vecIiLi2EEE.cl::sycl::vec" @_Z29__spirv_SpecConstantCompositeDv2_i(<2 x i32> %[[#BV]])
2222
;
2323
; CHECK: %[[#POD:]] = call %struct._ZTS3POD.POD @"_Z29__spirv_SpecConstantCompositeAstruct._ZTS1A.Aclass._ZTSN2cl4sycl3vecIiLi2EEE.cl::sycl::vec"([2 x %struct._ZTS1A.A] %[[#NA]], %"class._ZTSN2cl4sycl3vecIiLi2EEE.cl::sycl::vec" %[[#B]]), !SYCL_SPEC_CONST_SYM_ID ![[#MD:]]
2424
; CHECK: store %struct._ZTS3POD.POD %[[#POD]]
2525
;
26-
; CHECK: ![[#MD]] = !{!"_ZTS3POD", i32 [[#ID]], i32 [[#ID + 1]], i32 [[#ID + 3]], i32 [[#ID + 4]], i32 [[#ID + 7]], i32 [[#ID + 8]]}
26+
; CHECK: ![[#MD]] = !{!"_ZTS3POD", i32 [[#ID]], i32 [[#ID + 1]], i32 [[#ID + 2]], i32 [[#ID + 3]], i32 [[#ID + 4]], i32 [[#ID + 5]]}
2727

2828
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
2929
target triple = "spir64-unknown-unknown-sycldevice"
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
; RUN: sycl-post-link -spec-const=rt --ir-output-only %s -S -o - \
2+
; RUN: | FileCheck %s --implicit-check-not __sycl_getCompositeSpecConstantValue
3+
;
4+
; This test is intended to check that sycl-post-link tool is capable of handling
5+
; situations when the same composite specialization constants is used more than
6+
; once
7+
;
8+
; CHECK-LABEL: @foo1
9+
; CHECK: call %struct._ZTS3POD.POD @"_Z29__spirv_SpecConstantCompositeAstruct._ZTS1A.Aclass._ZTSN2cl4sycl3vecIiLi2EEE.cl::sycl::vec"({{.*}}), !SYCL_SPEC_CONST_SYM_ID ![[#MD0:]]
10+
; CHECK-LABEL: @_ZTS4Test
11+
; CHECK: call %struct._ZTS3POD.POD @"_Z29__spirv_SpecConstantCompositeAstruct._ZTS1A.Aclass._ZTSN2cl4sycl3vecIiLi2EEE.cl::sycl::vec"({{.*}}), !SYCL_SPEC_CONST_SYM_ID ![[#MD1:]]
12+
; CHECK-LABEL: @foo2
13+
; CHECK: call %struct._ZTS3POD.POD @"_Z29__spirv_SpecConstantCompositeAstruct._ZTS1A.Aclass._ZTSN2cl4sycl3vecIiLi2EEE.cl::sycl::vec"({{.*}}), !SYCL_SPEC_CONST_SYM_ID ![[#MD0:]]
14+
;
15+
; CHECK-DAG: ![[#MD0]] = !{!"_ZTS3PO2", i32 [[#ID:]],
16+
; CHECK-SAME: i32 [[#ID + 1]], i32 [[#ID + 2]], i32 [[#ID + 3]], i32 [[#ID + 4]], i32 [[#ID + 5]]}
17+
; CHECK-DAG: ![[#MD1]] = !{!"_ZTS3POD", i32 [[#ID1:]],
18+
; CHECK-SAME: i32 [[#ID1 + 1]], i32 [[#ID1 + 2]], i32 [[#ID1 + 3]], i32 [[#ID1 + 4]], i32 [[#ID1 + 5]]}
19+
20+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
21+
target triple = "spir64-unknown-unknown-sycldevice"
22+
23+
%struct._ZTS3POD.POD = type { [2 x %struct._ZTS1A.A], %"class._ZTSN2cl4sycl3vecIiLi2EEE.cl::sycl::vec" }
24+
%struct._ZTS1A.A = type { i32, float }
25+
%"class._ZTSN2cl4sycl3vecIiLi2EEE.cl::sycl::vec" = type { <2 x i32> }
26+
%"class._ZTSN2cl4sycl5rangeILi1EEE.cl::sycl::range" = type { %"class._ZTSN2cl4sycl6detail5arrayILi1EEE.cl::sycl::detail::array" }
27+
%"class._ZTSN2cl4sycl6detail5arrayILi1EEE.cl::sycl::detail::array" = type { [1 x i64] }
28+
%"class._ZTSN2cl4sycl2idILi1EEE.cl::sycl::id" = type { %"class._ZTSN2cl4sycl6detail5arrayILi1EEE.cl::sycl::detail::array" }
29+
30+
$_ZTS4Test = comdat any
31+
32+
@__builtin_unique_stable_name._ZNK2cl4sycl6ONEAPI12experimental13spec_constantI3PODS4_E3getIS4_EENSt9enable_ifIXsr3std6is_podIT_EE5valueES8_E4typeEv = private unnamed_addr addrspace(1) constant [9 x i8] c"_ZTS3POD\00", align 1
33+
@__builtin_unique_stable_name.2 = private unnamed_addr addrspace(1) constant [9 x i8] c"_ZTS3PO2\00", align 1
34+
35+
define spir_func void @foo1() {
36+
%ref.tmp.i = alloca %struct._ZTS3POD.POD, align 8
37+
%1 = addrspacecast %struct._ZTS3POD.POD* %ref.tmp.i to %struct._ZTS3POD.POD addrspace(4)*
38+
call spir_func void @_Z36__sycl_getCompositeSpecConstantValueI3PODET_PKc(%struct._ZTS3POD.POD addrspace(4)* sret align 8 %1, i8 addrspace(4)* addrspacecast (i8 addrspace(1)* getelementptr inbounds ([9 x i8], [9 x i8] addrspace(1)* @__builtin_unique_stable_name.2, i64 0, i64 0) to i8 addrspace(4)*)) #4
39+
ret void
40+
}
41+
42+
; Function Attrs: convergent norecurse uwtable
43+
define weak_odr dso_local spir_kernel void @_ZTS4Test(%struct._ZTS3POD.POD addrspace(1)* %_arg_, %"class._ZTSN2cl4sycl5rangeILi1EEE.cl::sycl::range"* byval(%"class._ZTSN2cl4sycl5rangeILi1EEE.cl::sycl::range") align 8 %_arg_1, %"class._ZTSN2cl4sycl5rangeILi1EEE.cl::sycl::range"* byval(%"class._ZTSN2cl4sycl5rangeILi1EEE.cl::sycl::range") align 8 %_arg_2, %"class._ZTSN2cl4sycl2idILi1EEE.cl::sycl::id"* byval(%"class._ZTSN2cl4sycl2idILi1EEE.cl::sycl::id") align 8 %_arg_3) local_unnamed_addr #0 comdat !kernel_arg_buffer_location !4 {
44+
entry:
45+
%ref.tmp.i = alloca %struct._ZTS3POD.POD, align 8
46+
%0 = getelementptr inbounds %"class._ZTSN2cl4sycl2idILi1EEE.cl::sycl::id", %"class._ZTSN2cl4sycl2idILi1EEE.cl::sycl::id"* %_arg_3, i64 0, i32 0, i32 0, i64 0
47+
%1 = load i64, i64* %0, align 8
48+
%add.ptr.i = getelementptr inbounds %struct._ZTS3POD.POD, %struct._ZTS3POD.POD addrspace(1)* %_arg_, i64 %1
49+
%2 = bitcast %struct._ZTS3POD.POD* %ref.tmp.i to i8*
50+
call void @llvm.lifetime.start.p0i8(i64 24, i8* nonnull %2) #3
51+
%3 = addrspacecast %struct._ZTS3POD.POD* %ref.tmp.i to %struct._ZTS3POD.POD addrspace(4)*
52+
call spir_func void @_Z36__sycl_getCompositeSpecConstantValueI3PODET_PKc(%struct._ZTS3POD.POD addrspace(4)* sret align 8 %3, i8 addrspace(4)* addrspacecast (i8 addrspace(1)* getelementptr inbounds ([9 x i8], [9 x i8] addrspace(1)* @__builtin_unique_stable_name._ZNK2cl4sycl6ONEAPI12experimental13spec_constantI3PODS4_E3getIS4_EENSt9enable_ifIXsr3std6is_podIT_EE5valueES8_E4typeEv, i64 0, i64 0) to i8 addrspace(4)*)) #4
53+
%4 = bitcast %struct._ZTS3POD.POD addrspace(1)* %add.ptr.i to i8 addrspace(1)*
54+
%5 = addrspacecast i8 addrspace(1)* %4 to i8 addrspace(4)*
55+
call void @llvm.memcpy.p4i8.p0i8.i64(i8 addrspace(4)* align 8 dereferenceable(24) %5, i8* nonnull align 8 dereferenceable(24) %2, i64 24, i1 false), !tbaa.struct !5
56+
call void @llvm.lifetime.end.p0i8(i64 24, i8* nonnull %2) #3
57+
ret void
58+
}
59+
60+
define spir_func void @foo2() {
61+
%ref.tmp.i = alloca %struct._ZTS3POD.POD, align 8
62+
%1 = addrspacecast %struct._ZTS3POD.POD* %ref.tmp.i to %struct._ZTS3POD.POD addrspace(4)*
63+
call spir_func void @_Z36__sycl_getCompositeSpecConstantValueI3PODET_PKc(%struct._ZTS3POD.POD addrspace(4)* sret align 8 %1, i8 addrspace(4)* addrspacecast (i8 addrspace(1)* getelementptr inbounds ([9 x i8], [9 x i8] addrspace(1)* @__builtin_unique_stable_name.2, i64 0, i64 0) to i8 addrspace(4)*)) #4
64+
ret void
65+
}
66+
67+
; Function Attrs: argmemonly nounwind willreturn
68+
declare void @llvm.lifetime.start.p0i8(i64 immarg, i8* nocapture) #1
69+
70+
; Function Attrs: argmemonly nounwind willreturn
71+
declare void @llvm.lifetime.end.p0i8(i64 immarg, i8* nocapture) #1
72+
73+
; Function Attrs: argmemonly nounwind willreturn
74+
declare void @llvm.memcpy.p4i8.p0i8.i64(i8 addrspace(4)* noalias nocapture writeonly, i8* noalias nocapture readonly, i64, i1 immarg) #1
75+
76+
; Function Attrs: convergent
77+
declare dso_local spir_func void @_Z36__sycl_getCompositeSpecConstantValueI3PODET_PKc(%struct._ZTS3POD.POD addrspace(4)* sret align 8, i8 addrspace(4)*) local_unnamed_addr #2
78+
79+
attributes #0 = { convergent norecurse uwtable "disable-tail-calls"="false" "frame-pointer"="all" "less-precise-fpmad"="false" "min-legal-vector-width"="0" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "sycl-module-id"="../sycl/test/spec_const/composite.cpp" "tune-cpu"="generic" "uniform-work-group-size"="true" "unsafe-fp-math"="false" "use-soft-float"="false" }
80+
attributes #1 = { argmemonly nounwind willreturn }
81+
attributes #2 = { convergent "disable-tail-calls"="false" "frame-pointer"="all" "less-precise-fpmad"="false" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "tune-cpu"="generic" "unsafe-fp-math"="false" "use-soft-float"="false" }
82+
attributes #3 = { nounwind }
83+
attributes #4 = { convergent }
84+
85+
!llvm.module.flags = !{!0}
86+
!opencl.spir.version = !{!1}
87+
!spirv.Source = !{!2}
88+
!llvm.ident = !{!3}
89+
90+
!0 = !{i32 1, !"wchar_size", i32 4}
91+
!1 = !{i32 1, i32 2}
92+
!2 = !{i32 4, i32 100000}
93+
!3 = !{!"clang version 12.0.0 "}
94+
!4 = !{i32 -1, i32 -1, i32 -1, i32 -1}
95+
!5 = !{i64 0, i64 16, !6, i64 16, i64 8, !6}
96+
!6 = !{!7, !7, i64 0}
97+
!7 = !{!"omnipotent char", !8, i64 0}
98+
!8 = !{!"Simple C++ TBAA"}

llvm/tools/sycl-post-link/SpecConstants.cpp

Lines changed: 73 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,8 @@ getScalarSpecConstMetadata(const Instruction *I) {
210210
return std::make_pair(MDSym->getString(), ID);
211211
}
212212

213+
/// Recursively iterates over a composite type in order to collect information
214+
/// about its scalar elements.
213215
void collectCompositeElementsInfoRecursive(
214216
const Type *Ty, unsigned &Index, unsigned &Offset,
215217
std::vector<CompositeSpecConstElementDescriptor> &Result) {
@@ -303,21 +305,52 @@ Instruction *emitSpecConstantComposite(Type *Ty,
303305
return emitCall(Ty, SPIRV_GET_SPEC_CONST_COMPOSITE, Args, InsertBefore);
304306
}
305307

306-
Instruction *
307-
emitSpecConstantRecursive(unsigned &NextID, Type *Ty,
308-
SmallVectorImpl<unsigned> &GeneratedScalarIDs,
309-
Instruction *InsertBefore) {
310-
if (!Ty->isArrayTy() && !Ty->isStructTy() && !Ty->isVectorTy()) {
311-
// assume that this is a scalar
312-
GeneratedScalarIDs.push_back(NextID);
313-
return emitSpecConstant(NextID, Ty, InsertBefore);
308+
/// For specified specialization constant type emits LLVM IR which is required
309+
/// in order to correctly handle it later during LLVM IR -> SPIR-V translation.
310+
///
311+
/// @param Ty [in] Specialization constant type to handle.
312+
/// @param InsertBefore [in] Location in the module where new instructions
313+
/// should be inserted.
314+
/// @param IDs [in,out] List of IDs which are assigned for scalar specialization
315+
/// constants. If \c IsNewSpecConstant is true, this vector is expected to
316+
/// contain a single element with ID of the first spec constant - the rest of
317+
/// generated spec constants will have their IDs generated by incrementing that
318+
/// first ID. If \c IsNewSpecConstant is false, this vector is expected to
319+
/// contain enough elements to assign ID to each scalar element encountered in
320+
/// the specified composite type.
321+
/// @param IsNewSpecConstant [in] Flag to specify whether \c IDs vector should
322+
/// be filled with new IDs or it should be used as-is to replicate an existing
323+
/// spec constant
324+
/// @param [in,out] IsFirstElement Flag indicating whether this function is
325+
/// handling the first scalar element encountered in the specified composite
326+
/// type \c Ty or not.
327+
///
328+
/// @returns Instruction* representing specialization constant in LLVM IR, which
329+
/// is in SPIR-V friendly LLVM IR form.
330+
/// For scalar types it results in a single __spirv_SpecConstant call.
331+
/// For composite types it results in a number of __spirv_SpecConstant calls
332+
/// for each scalar member of the composite plus in a number of
333+
/// __spirvSpecConstantComposite calls for each composite member of the
334+
/// composite (plus for the top-level composite). Also enumerates all
335+
/// encountered scalars and assigns them IDs (or re-uses existing ones).
336+
Instruction *emitSpecConstantRecursiveImpl(Type *Ty, Instruction *InsertBefore,
337+
SmallVectorImpl<unsigned> &IDs,
338+
bool IsNewSpecConstant,
339+
bool &IsFirstElement) {
340+
if (!Ty->isArrayTy() && !Ty->isStructTy() && !Ty->isVectorTy()) { // Scalar
341+
if (IsNewSpecConstant && !IsFirstElement) {
342+
// If it is a new specialization constant, we need to generate IDs for
343+
// scalar elements, starting with the second one.
344+
IDs.push_back(IDs.back() + 1);
345+
}
346+
IsFirstElement = false;
347+
return emitSpecConstant(IDs.back(), Ty, InsertBefore);
314348
}
315349

316350
SmallVector<Instruction *, 8> Elements;
317351
auto LoopIteration = [&](Type *Ty) {
318-
++NextID; // The first NextID is reserved for SpecConstantComposite below
319-
Elements.push_back(emitSpecConstantRecursive(NextID, Ty, GeneratedScalarIDs,
320-
InsertBefore));
352+
Elements.push_back(emitSpecConstantRecursiveImpl(
353+
Ty, InsertBefore, IDs, IsNewSpecConstant, IsFirstElement));
321354
};
322355

323356
if (auto *ArrTy = dyn_cast<ArrayType>(Ty)) {
@@ -339,12 +372,21 @@ emitSpecConstantRecursive(unsigned &NextID, Type *Ty,
339372
return emitSpecConstantComposite(Ty, Elements, InsertBefore);
340373
}
341374

375+
/// Wrapper intended to hide IsFirstElement argument from the caller
376+
Instruction *emitSpecConstantRecursive(Type *Ty, Instruction *InsertBefore,
377+
SmallVectorImpl<unsigned> &IDs,
378+
bool IsNewSpecConstant) {
379+
bool IsFirstElement = true;
380+
return emitSpecConstantRecursiveImpl(Ty, InsertBefore, IDs, IsNewSpecConstant,
381+
IsFirstElement);
382+
}
383+
342384
} // namespace
343385

344386
PreservedAnalyses SpecConstantsPass::run(Module &M,
345387
ModuleAnalysisManager &MAM) {
346388
unsigned NextID = 0;
347-
StringMap<unsigned> IDMap;
389+
StringMap<SmallVector<unsigned, 1>> IDMap;
348390

349391
// Iterate through all declarations of instances of function template
350392
// template <typename T> T __sycl_getSpecConstantValue(const char *ID)
@@ -380,7 +422,7 @@ PreservedAnalyses SpecConstantsPass::run(Module &M,
380422
DelInsts.push_back(CI);
381423
Type *SCTy = CI->getType();
382424
unsigned NameArgNo = 0;
383-
if (IsComposite) { // structs are returned via sret arguments
425+
if (IsComposite) { // structs are returned via sret arguments.
384426
NameArgNo = 1;
385427
auto *PtrTy = cast<PointerType>(CI->getArgOperand(0)->getType());
386428
SCTy = PtrTy->getElementType();
@@ -389,22 +431,28 @@ PreservedAnalyses SpecConstantsPass::run(Module &M,
389431

390432
if (SetValAtRT) {
391433
// 2. Spec constant value will be set at run time - then add the literal
392-
// to a "spec const string literal ID" -> "integer ID" map, uniquing
393-
// the integer ID if this is new literal
394-
auto Ins = IDMap.insert(std::make_pair(SymID, 0));
395-
if (Ins.second)
396-
Ins.first->second = NextID;
397-
unsigned ID = Ins.first->second;
434+
// to a "spec const string literal ID" -> "integer ID" map or
435+
// "composite spec const string literal ID" -> "vector of integer IDs"
436+
// map, uniquing the integer IDs if this is new literal
437+
auto Ins =
438+
IDMap.insert(std::make_pair(SymID, SmallVector<unsigned, 1>{}));
439+
bool IsNewSpecConstant = Ins.second;
440+
auto &IDs = Ins.first->second;
441+
if (IsNewSpecConstant) {
442+
// For any spec constant type there will be always at least one ID
443+
// generatedA.
444+
IDs.push_back(NextID);
445+
}
398446

399447
// 3. Transform to spirv intrinsic _Z*__spirv_SpecConstant* or
400448
// _Z*__spirv_SpecConstantComposite
401-
SmallVector<unsigned, 4> GeneratedIDs;
402-
auto *SPIRVCall = emitSpecConstantRecursive(ID, SCTy, GeneratedIDs, CI);
403-
if (Ins.second) {
449+
auto *SPIRVCall =
450+
emitSpecConstantRecursive(SCTy, CI, IDs, IsNewSpecConstant);
451+
if (IsNewSpecConstant) {
404452
// emitSpecConstantRecursive might emit more than one spec constant
405453
// (because of composite types) and therefore, we need to ajudst
406-
// NextID according to the actual amount of emitted spec constants
407-
NextID += GeneratedIDs.size();
454+
// NextID according to the actual amount of emitted spec constants.
455+
NextID += IDs.size();
408456
}
409457

410458
if (IsComposite) {
@@ -418,7 +466,7 @@ PreservedAnalyses SpecConstantsPass::run(Module &M,
418466

419467
// Mark the instruction with <symbolic_id, int_ids...> list for later
420468
// recollection by collectSpecConstantMetadata method.
421-
setSpecConstSymIDMetadata(SPIRVCall, SymID, GeneratedIDs);
469+
setSpecConstSymIDMetadata(SPIRVCall, SymID, IDs);
422470
// Example of the emitted call when spec constant is integer:
423471
// %6 = call i32 @_Z20__spirv_SpecConstantii(i32 0, i32 0), \
424472
// !SYCL_SPEC_CONST_SYM_ID !22

0 commit comments

Comments
 (0)