@@ -290,6 +290,20 @@ Value *SPIRVToLLVM::mapFunction(SPIRVFunction *BF, Function *F) {
290
290
// %5 = insertelement <3 x i64> %3, i64 %4, i32 2
291
291
// %c = extractelement <3 x i64> %5, i32 idx
292
292
// %d = extractelement <3 x i64> %5, i32 idx
293
+ //
294
+ // Replace the following pattern:
295
+ // %0 = addrspacecast <3 x i64> addrspace(1)* @__spirv_BuiltInWorkgroupSize to
296
+ // <3 x i64> addrspace(4)*
297
+ // %1 = getelementptr <3 x i64>, <3 x i64> addrspace(4)* %0, i64 0, i64 0
298
+ // %2 = load i64, i64 addrspace(4)* %1, align 32
299
+ // With:
300
+ // %0 = call spir_func i64 @_Z13get_global_idj(i32 0) #1
301
+ // %1 = insertelement <3 x i64> undef, i64 %0, i32 0
302
+ // %2 = call spir_func i64 @_Z13get_global_idj(i32 1) #1
303
+ // %3 = insertelement <3 x i64> %1, i64 %2, i32 1
304
+ // %4 = call spir_func i64 @_Z13get_global_idj(i32 2) #1
305
+ // %5 = insertelement <3 x i64> %3, i64 %4, i32 2
306
+ // %6 = extractelement <3 x i64> %5, i32 0
293
307
bool SPIRVToLLVM::transOCLBuiltinFromVariable (GlobalVariable *GV,
294
308
SPIRVBuiltinVariableKind Kind) {
295
309
std::string FuncName;
@@ -300,7 +314,8 @@ bool SPIRVToLLVM::transOCLBuiltinFromVariable(GlobalVariable *GV,
300
314
} else {
301
315
FuncName = std::string (GV->getName ());
302
316
}
303
- Type *ReturnTy = GV->getType ()->getPointerElementType ();
317
+ Type *GVTy = GV->getType ()->getPointerElementType ();
318
+ Type *ReturnTy = GVTy;
304
319
// Some SPIR-V builtin variables are translated to a function with an index
305
320
// argument.
306
321
bool HasIndexArg =
@@ -324,9 +339,9 @@ bool SPIRVToLLVM::transOCLBuiltinFromVariable(GlobalVariable *GV,
324
339
}
325
340
326
341
// Collect instructions in these containers to remove them later.
327
- std::vector<Instruction *> Extracts;
328
342
std::vector<Instruction *> Loads;
329
343
std::vector<Instruction *> Casts;
344
+ std::vector<Instruction *> GEPs;
330
345
331
346
auto Replace = [&](std::vector<Value *> Arg, Instruction *I) {
332
347
auto Call = CallInst::Create (Func, Arg, " " , I);
@@ -342,12 +357,15 @@ bool SPIRVToLLVM::transOCLBuiltinFromVariable(GlobalVariable *GV,
342
357
// with this vector.
343
358
// If HasIndexArg is false, the result of the Load instruction is the value
344
359
// which should be replaced with the Func.
345
- auto FindAndReplace = [&](LoadInst *LD) {
360
+ // Returns true if Load was replaced, false otherwise.
361
+ auto ReplaceIfLoad = [&](User *I) {
362
+ auto *LD = dyn_cast<LoadInst>(I);
363
+ if (!LD)
364
+ return false ;
346
365
std::vector<Value *> Vectors;
347
366
Loads.push_back (LD);
348
367
if (HasIndexArg) {
349
- auto *VecTy = cast<FixedVectorType>(
350
- LD->getPointerOperandType ()->getPointerElementType ());
368
+ auto *VecTy = cast<FixedVectorType>(GVTy);
351
369
Value *EmptyVec = UndefValue::get (VecTy);
352
370
Vectors.push_back (EmptyVec);
353
371
const DebugLoc &DLoc = LD->getDebugLoc ();
@@ -363,10 +381,27 @@ bool SPIRVToLLVM::transOCLBuiltinFromVariable(GlobalVariable *GV,
363
381
Insert->insertAfter (Call);
364
382
Vectors.push_back (Insert);
365
383
}
366
- LD->replaceAllUsesWith (Vectors.back ());
384
+
385
+ Value *Ptr = LD->getPointerOperand ();
386
+
387
+ if (isa<FixedVectorType>(Ptr->getType ()->getPointerElementType ())) {
388
+ LD->replaceAllUsesWith (Vectors.back ());
389
+ } else {
390
+ auto *GEP = dyn_cast<GetElementPtrInst>(Ptr);
391
+ assert (GEP && " Unexpected pattern!" );
392
+ assert (GEP->getNumIndices () == 2 && " Unexpected pattern!" );
393
+ Value *Idx = GEP->getOperand (2 );
394
+ Value *Vec = Vectors.back ();
395
+ auto *NewExtract = ExtractElementInst::Create (Vec, Idx);
396
+ NewExtract->insertAfter (cast<Instruction>(Vec));
397
+ LD->replaceAllUsesWith (NewExtract);
398
+ }
399
+
367
400
} else {
368
401
Replace ({}, LD);
369
402
}
403
+
404
+ return true ;
370
405
};
371
406
372
407
// Go over the GV users, find Load and ExtractElement instructions and
@@ -376,13 +411,19 @@ bool SPIRVToLLVM::transOCLBuiltinFromVariable(GlobalVariable *GV,
376
411
if (auto *ASCast = dyn_cast<AddrSpaceCastInst>(UI)) {
377
412
Casts.push_back (ASCast);
378
413
for (auto *CastUser : ASCast->users ()) {
379
- if (auto *LD = dyn_cast<LoadInst>(CastUser)) {
380
- FindAndReplace (LD);
414
+ if (ReplaceIfLoad (CastUser))
415
+ continue ;
416
+ if (auto *GEP = dyn_cast<GetElementPtrInst>(CastUser)) {
417
+ GEPs.push_back (GEP);
418
+ for (auto *GEPUser : GEP->users ()) {
419
+ if (!ReplaceIfLoad (GEPUser))
420
+ llvm_unreachable (" Unexpected pattern!" );
421
+ }
422
+ } else {
423
+ llvm_unreachable (" Unexpected pattern!" );
381
424
}
382
425
}
383
- } else if (auto *LD = dyn_cast<LoadInst>(UI)) {
384
- FindAndReplace (LD);
385
- } else {
426
+ } else if (!ReplaceIfLoad (UI)) {
386
427
llvm_unreachable (" Unexpected pattern!" );
387
428
}
388
429
}
@@ -394,8 +435,8 @@ bool SPIRVToLLVM::transOCLBuiltinFromVariable(GlobalVariable *GV,
394
435
}
395
436
};
396
437
// Order of erasing is important.
397
- Erase (Extracts);
398
438
Erase (Loads);
439
+ Erase (GEPs);
399
440
Erase (Casts);
400
441
401
442
return true ;
0 commit comments