@@ -321,31 +321,6 @@ void SPIRVRegularizeLLVMBase::expandSYCLTypeUsing(Module *M) {
321
321
expandVIDWithSYCLTypeByValComp (F);
322
322
}
323
323
324
- Value *SPIRVRegularizeLLVMBase::extendBitInstBoolArg (Instruction *II) {
325
- IRBuilder<> Builder (II);
326
- auto *ArgTy = II->getOperand (0 )->getType ();
327
- Type *NewArgType = nullptr ;
328
- if (ArgTy->isIntegerTy ()) {
329
- NewArgType = Builder.getInt32Ty ();
330
- } else if (ArgTy->isVectorTy () &&
331
- cast<VectorType>(ArgTy)->getElementType ()->isIntegerTy ()) {
332
- unsigned NumElements = cast<FixedVectorType>(ArgTy)->getNumElements ();
333
- NewArgType = VectorType::get (Builder.getInt32Ty (), NumElements, false );
334
- } else {
335
- llvm_unreachable (" Unexpected type" );
336
- }
337
- auto *NewBase = Builder.CreateZExt (II->getOperand (0 ), NewArgType);
338
- auto *NewShift = Builder.CreateZExt (II->getOperand (1 ), NewArgType);
339
- switch (II->getOpcode ()) {
340
- case Instruction::LShr:
341
- return Builder.CreateLShr (NewBase, NewShift);
342
- case Instruction::Shl:
343
- return Builder.CreateShl (NewBase, NewShift);
344
- default :
345
- return II;
346
- }
347
- }
348
-
349
324
bool SPIRVRegularizeLLVMBase::runRegularizeLLVM (Module &Module) {
350
325
M = &Module;
351
326
Ctx = &M->getContext ();
@@ -458,19 +433,53 @@ bool SPIRVRegularizeLLVMBase::regularize() {
458
433
}
459
434
}
460
435
461
- // Translator treats i1 as boolean, but bit instructions take
462
- // a scalar/vector integers, so we have to extend such arguments
463
- if (II.isLogicalShift () &&
464
- II.getOperand (0 )->getType ()->isIntOrIntVectorTy (1 )) {
465
- auto *NewInst = extendBitInstBoolArg (&II);
466
- for (auto *U : II.users ()) {
467
- if (cast<Instruction>(U)->getOpcode () == Instruction::ZExt) {
468
- U->dropAllReferences ();
469
- U->replaceAllUsesWith (NewInst);
470
- ToErase.push_back (cast<Instruction>(U));
436
+ if (II.isLogicalShift ()) {
437
+ // Translator treats i1 as boolean, but bit instructions take
438
+ // a scalar/vector integers, so we have to extend such arguments.
439
+ // shl i1 %a %b and lshr i1 %a %b are now converted on:
440
+ // %0 = select i1 %a, i32 1, i32 0
441
+ // %1 = select i1 %b, i32 1, i32 0
442
+ // %2 = lshr i32 %0, %1
443
+ // if any other instruction other than zext was dependant:
444
+ // %3 = icmp ne i32 %2, 0
445
+ // which converts it back to i1 and replace original result with %3
446
+ // to dependant instructions.
447
+ if (II.getOperand (0 )->getType ()->isIntOrIntVectorTy (1 )) {
448
+ IRBuilder<> Builder (&II);
449
+ Value *CmpNEInst = nullptr ;
450
+ Constant *ConstZero = ConstantInt::get (Builder.getInt32Ty (), 0 );
451
+ Constant *ConstOne = ConstantInt::get (Builder.getInt32Ty (), 1 );
452
+ if (auto *VecTy =
453
+ dyn_cast<FixedVectorType>(II.getOperand (0 )->getType ())) {
454
+ const unsigned NumElements = VecTy->getNumElements ();
455
+ ConstZero = ConstantVector::getSplat (
456
+ ElementCount::getFixed (NumElements), ConstZero);
457
+ ConstOne = ConstantVector::getSplat (
458
+ ElementCount::getFixed (NumElements), ConstOne);
459
+ }
460
+ Value *ExtendedBase =
461
+ Builder.CreateSelect (II.getOperand (0 ), ConstOne, ConstZero);
462
+ Value *ExtendedShift =
463
+ Builder.CreateSelect (II.getOperand (1 ), ConstOne, ConstZero);
464
+ Value *ExtendedShiftedVal =
465
+ Builder.CreateLShr (ExtendedBase, ExtendedShift);
466
+ SmallVector<User *, 8 > Users (II.users ());
467
+ for (User *U : Users) {
468
+ if (auto *UI = dyn_cast<Instruction>(U)) {
469
+ if (UI->getOpcode () == Instruction::ZExt) {
470
+ UI->dropAllReferences ();
471
+ UI->replaceAllUsesWith (ExtendedShiftedVal);
472
+ ToErase.push_back (UI);
473
+ continue ;
474
+ }
475
+ }
476
+ if (!CmpNEInst) {
477
+ CmpNEInst = Builder.CreateICmpNE (ExtendedShiftedVal, ConstZero);
478
+ }
479
+ U->replaceUsesOfWith (&II, CmpNEInst);
471
480
}
481
+ ToErase.push_back (&II);
472
482
}
473
- ToErase.push_back (&II);
474
483
}
475
484
476
485
// Remove optimization info not supported by SPIRV
0 commit comments