@@ -322,31 +322,6 @@ void SPIRVRegularizeLLVMBase::expandSYCLTypeUsing(Module *M) {
322
322
expandVIDWithSYCLTypeByValComp (F);
323
323
}
324
324
325
- Value *SPIRVRegularizeLLVMBase::extendBitInstBoolArg (Instruction *II) {
326
- IRBuilder<> Builder (II);
327
- auto *ArgTy = II->getOperand (0 )->getType ();
328
- Type *NewArgType = nullptr ;
329
- if (ArgTy->isIntegerTy ()) {
330
- NewArgType = Builder.getInt32Ty ();
331
- } else if (ArgTy->isVectorTy () &&
332
- cast<VectorType>(ArgTy)->getElementType ()->isIntegerTy ()) {
333
- unsigned NumElements = cast<FixedVectorType>(ArgTy)->getNumElements ();
334
- NewArgType = VectorType::get (Builder.getInt32Ty (), NumElements, false );
335
- } else {
336
- llvm_unreachable (" Unexpected type" );
337
- }
338
- auto *NewBase = Builder.CreateZExt (II->getOperand (0 ), NewArgType);
339
- auto *NewShift = Builder.CreateZExt (II->getOperand (1 ), NewArgType);
340
- switch (II->getOpcode ()) {
341
- case Instruction::LShr:
342
- return Builder.CreateLShr (NewBase, NewShift);
343
- case Instruction::Shl:
344
- return Builder.CreateShl (NewBase, NewShift);
345
- default :
346
- return II;
347
- }
348
- }
349
-
350
325
bool SPIRVRegularizeLLVMBase::runRegularizeLLVM (Module &Module) {
351
326
M = &Module;
352
327
Ctx = &M->getContext ();
@@ -393,19 +368,53 @@ bool SPIRVRegularizeLLVMBase::regularize() {
393
368
}
394
369
}
395
370
396
- // Translator treats i1 as boolean, but bit instructions take
397
- // a scalar/vector integers, so we have to extend such arguments
398
- if (II.isLogicalShift () &&
399
- II.getOperand (0 )->getType ()->isIntOrIntVectorTy (1 )) {
400
- auto *NewInst = extendBitInstBoolArg (&II);
401
- for (auto *U : II.users ()) {
402
- if (cast<Instruction>(U)->getOpcode () == Instruction::ZExt) {
403
- U->dropAllReferences ();
404
- U->replaceAllUsesWith (NewInst);
405
- ToErase.push_back (cast<Instruction>(U));
371
+ if (II.isLogicalShift ()) {
372
+ // Translator treats i1 as boolean, but bit instructions take
373
+ // scalar/vector integers, so we have to extend such arguments.
374
+ // shl i1 %a, %b and lshr i1 %a, %b are now converted to:
375
+ // %0 = select i1 %a, i32 1, i32 0
376
+ // %1 = select i1 %b, i32 1, i32 0
377
+ // %2 = lshr i32 %0, %1
378
+ // if any instruction other than zext was dependent:
379
+ // %3 = icmp ne i32 %2, 0
380
+ // which converts it back to i1 and replaces original result with %3
381
+ // in dependent instructions.
382
+ if (II.getOperand (0 )->getType ()->isIntOrIntVectorTy (1 )) {
383
+ IRBuilder<> Builder (&II);
384
+ Value *CmpNEInst = nullptr ;
385
+ Constant *ConstZero = ConstantInt::get (Builder.getInt32Ty (), 0 );
386
+ Constant *ConstOne = ConstantInt::get (Builder.getInt32Ty (), 1 );
387
+ if (auto *VecTy =
388
+ dyn_cast<FixedVectorType>(II.getOperand (0 )->getType ())) {
389
+ const unsigned NumElements = VecTy->getNumElements ();
390
+ ConstZero = ConstantVector::getSplat (
391
+ ElementCount::getFixed (NumElements), ConstZero);
392
+ ConstOne = ConstantVector::getSplat (
393
+ ElementCount::getFixed (NumElements), ConstOne);
394
+ }
395
+ Value *ExtendedBase =
396
+ Builder.CreateSelect (II.getOperand (0 ), ConstOne, ConstZero);
397
+ Value *ExtendedShift =
398
+ Builder.CreateSelect (II.getOperand (1 ), ConstOne, ConstZero);
399
+ Value *ExtendedShiftedVal =
400
+ Builder.CreateLShr (ExtendedBase, ExtendedShift);
401
+ SmallVector<User *, 8 > Users (II.users ());
402
+ for (User *U : Users) {
403
+ if (auto *UI = dyn_cast<Instruction>(U)) {
404
+ if (UI->getOpcode () == Instruction::ZExt) {
405
+ UI->dropAllReferences ();
406
+ UI->replaceAllUsesWith (ExtendedShiftedVal);
407
+ ToErase.push_back (UI);
408
+ continue ;
409
+ }
410
+ }
411
+ if (!CmpNEInst) {
412
+ CmpNEInst = Builder.CreateICmpNE (ExtendedShiftedVal, ConstZero);
413
+ }
414
+ U->replaceUsesOfWith (&II, CmpNEInst);
406
415
}
416
+ ToErase.push_back (&II);
407
417
}
408
- ToErase.push_back (&II);
409
418
}
410
419
411
420
// Remove optimization info not supported by SPIRV
0 commit comments