@@ -210,6 +210,8 @@ getScalarSpecConstMetadata(const Instruction *I) {
210
210
return std::make_pair (MDSym->getString (), ID);
211
211
}
212
212
213
+ // / Recursively iterates over a composite type in order to collect information
214
+ // / about its scalar elements.
213
215
void collectCompositeElementsInfoRecursive (
214
216
const Type *Ty, unsigned &Index, unsigned &Offset,
215
217
std::vector<CompositeSpecConstElementDescriptor> &Result) {
@@ -303,21 +305,52 @@ Instruction *emitSpecConstantComposite(Type *Ty,
303
305
return emitCall (Ty, SPIRV_GET_SPEC_CONST_COMPOSITE, Args, InsertBefore);
304
306
}
305
307
306
- Instruction *
307
- emitSpecConstantRecursive (unsigned &NextID, Type *Ty,
308
- SmallVectorImpl<unsigned > &GeneratedScalarIDs,
309
- Instruction *InsertBefore) {
310
- if (!Ty->isArrayTy () && !Ty->isStructTy () && !Ty->isVectorTy ()) {
311
- // assume that this is a scalar
312
- GeneratedScalarIDs.push_back (NextID);
313
- return emitSpecConstant (NextID, Ty, InsertBefore);
308
+ // / For specified specialization constant type emits LLVM IR which is required
309
+ // / in order to correctly handle it later during LLVM IR -> SPIR-V translation.
310
+ // /
311
+ // / @param Ty [in] Specialization constant type to handle.
312
+ // / @param InsertBefore [in] Location in the module where new instructions
313
+ // / should be inserted.
314
+ // / @param IDs [in,out] List of IDs which are assigned for scalar specialization
315
+ // / constants. If \c IsNewSpecConstant is true, this vector is expected to
316
+ // / contain a single element with ID of the first spec constant - the rest of
317
+ // / generated spec constants will have their IDs generated by incrementing that
318
+ // / first ID. If \c IsNewSpecConstant is false, this vector is expected to
319
+ // / contain enough elements to assign ID to each scalar element encountered in
320
+ // / the specified composite type.
321
+ // / @param IsNewSpecConstant [in] Flag to specify whether \c IDs vector should
322
+ // / be filled with new IDs or it should be used as-is to replicate an existing
323
+ // / spec constant
324
+ // / @param [in,out] IsFirstElement Flag indicating whether this function is
325
+ // / handling the first scalar element encountered in the specified composite
326
+ // / type \c Ty or not.
327
+ // /
328
+ // / @returns Instruction* representing specialization constant in LLVM IR, which
329
+ // / is in SPIR-V friendly LLVM IR form.
330
+ // / For scalar types it results in a single __spirv_SpecConstant call.
331
+ // / For composite types it results in a number of __spirv_SpecConstant calls
332
+ // / for each scalar member of the composite plus in a number of
333
+ // / __spirvSpecConstantComposite calls for each composite member of the
334
+ // / composite (plus for the top-level composite). Also enumerates all
335
+ // / encountered scalars and assigns them IDs (or re-uses existing ones).
336
+ Instruction *emitSpecConstantRecursiveImpl (Type *Ty, Instruction *InsertBefore,
337
+ SmallVectorImpl<unsigned > &IDs,
338
+ bool IsNewSpecConstant,
339
+ bool &IsFirstElement) {
340
+ if (!Ty->isArrayTy () && !Ty->isStructTy () && !Ty->isVectorTy ()) { // Scalar
341
+ if (IsNewSpecConstant && !IsFirstElement) {
342
+ // If it is a new specialization constant, we need to generate IDs for
343
+ // scalar elements, starting with the second one.
344
+ IDs.push_back (IDs.back () + 1 );
345
+ }
346
+ IsFirstElement = false ;
347
+ return emitSpecConstant (IDs.back (), Ty, InsertBefore);
314
348
}
315
349
316
350
SmallVector<Instruction *, 8 > Elements;
317
351
auto LoopIteration = [&](Type *Ty) {
318
- ++NextID; // The first NextID is reserved for SpecConstantComposite below
319
- Elements.push_back (emitSpecConstantRecursive (NextID, Ty, GeneratedScalarIDs,
320
- InsertBefore));
352
+ Elements.push_back (emitSpecConstantRecursiveImpl (
353
+ Ty, InsertBefore, IDs, IsNewSpecConstant, IsFirstElement));
321
354
};
322
355
323
356
if (auto *ArrTy = dyn_cast<ArrayType>(Ty)) {
@@ -339,12 +372,21 @@ emitSpecConstantRecursive(unsigned &NextID, Type *Ty,
339
372
return emitSpecConstantComposite (Ty, Elements, InsertBefore);
340
373
}
341
374
375
+ // / Wrapper intended to hide IsFirstElement argument from the caller
376
+ Instruction *emitSpecConstantRecursive (Type *Ty, Instruction *InsertBefore,
377
+ SmallVectorImpl<unsigned > &IDs,
378
+ bool IsNewSpecConstant) {
379
+ bool IsFirstElement = true ;
380
+ return emitSpecConstantRecursiveImpl (Ty, InsertBefore, IDs, IsNewSpecConstant,
381
+ IsFirstElement);
382
+ }
383
+
342
384
} // namespace
343
385
344
386
PreservedAnalyses SpecConstantsPass::run (Module &M,
345
387
ModuleAnalysisManager &MAM) {
346
388
unsigned NextID = 0 ;
347
- StringMap<unsigned > IDMap;
389
+ StringMap<SmallVector< unsigned , 1 > > IDMap;
348
390
349
391
// Iterate through all declarations of instances of function template
350
392
// template <typename T> T __sycl_getSpecConstantValue(const char *ID)
@@ -380,7 +422,7 @@ PreservedAnalyses SpecConstantsPass::run(Module &M,
380
422
DelInsts.push_back (CI);
381
423
Type *SCTy = CI->getType ();
382
424
unsigned NameArgNo = 0 ;
383
- if (IsComposite) { // structs are returned via sret arguments
425
+ if (IsComposite) { // structs are returned via sret arguments.
384
426
NameArgNo = 1 ;
385
427
auto *PtrTy = cast<PointerType>(CI->getArgOperand (0 )->getType ());
386
428
SCTy = PtrTy->getElementType ();
@@ -389,22 +431,28 @@ PreservedAnalyses SpecConstantsPass::run(Module &M,
389
431
390
432
if (SetValAtRT) {
391
433
// 2. Spec constant value will be set at run time - then add the literal
392
- // to a "spec const string literal ID" -> "integer ID" map, uniquing
393
- // the integer ID if this is new literal
394
- auto Ins = IDMap.insert (std::make_pair (SymID, 0 ));
395
- if (Ins.second )
396
- Ins.first ->second = NextID;
397
- unsigned ID = Ins.first ->second ;
434
+ // to a "spec const string literal ID" -> "integer ID" map or
435
+ // "composite spec const string literal ID" -> "vector of integer IDs"
436
+ // map, uniquing the integer IDs if this is new literal
437
+ auto Ins =
438
+ IDMap.insert (std::make_pair (SymID, SmallVector<unsigned , 1 >{}));
439
+ bool IsNewSpecConstant = Ins.second ;
440
+ auto &IDs = Ins.first ->second ;
441
+ if (IsNewSpecConstant) {
442
+ // For any spec constant type there will be always at least one ID
443
+ // generatedA.
444
+ IDs.push_back (NextID);
445
+ }
398
446
399
447
// 3. Transform to spirv intrinsic _Z*__spirv_SpecConstant* or
400
448
// _Z*__spirv_SpecConstantComposite
401
- SmallVector< unsigned , 4 > GeneratedIDs;
402
- auto *SPIRVCall = emitSpecConstantRecursive (ID, SCTy, GeneratedIDs, CI );
403
- if (Ins. second ) {
449
+ auto *SPIRVCall =
450
+ emitSpecConstantRecursive (SCTy, CI, IDs, IsNewSpecConstant );
451
+ if (IsNewSpecConstant ) {
404
452
// emitSpecConstantRecursive might emit more than one spec constant
405
453
// (because of composite types) and therefore, we need to ajudst
406
- // NextID according to the actual amount of emitted spec constants
407
- NextID += GeneratedIDs .size ();
454
+ // NextID according to the actual amount of emitted spec constants.
455
+ NextID += IDs .size ();
408
456
}
409
457
410
458
if (IsComposite) {
@@ -418,7 +466,7 @@ PreservedAnalyses SpecConstantsPass::run(Module &M,
418
466
419
467
// Mark the instruction with <symbolic_id, int_ids...> list for later
420
468
// recollection by collectSpecConstantMetadata method.
421
- setSpecConstSymIDMetadata (SPIRVCall, SymID, GeneratedIDs );
469
+ setSpecConstSymIDMetadata (SPIRVCall, SymID, IDs );
422
470
// Example of the emitted call when spec constant is integer:
423
471
// %6 = call i32 @_Z20__spirv_SpecConstantii(i32 0, i32 0), \
424
472
// !SYCL_SPEC_CONST_SYM_ID !22
0 commit comments