Skip to content

Commit 239fbd4

Browse files
authored
Fix SPIRVRegularizeLLVMBase::regularize fix for shl i1 and lshr i1 (#2288)
The translator failed assertion with V->user_empty() during regularize function when shl i1 or lshr i1 result is used. E.g. %2 = shl i1 %0 %1 store %2, ptr addrspace(1) @G.1, align 1 Instruction shl i1 is converted to lshr i32 which arithmetic have the same behavior.
1 parent e8b2018 commit 239fbd4

File tree

2 files changed

+102
-36
lines changed

2 files changed

+102
-36
lines changed

lib/SPIRV/SPIRVRegularizeLLVM.cpp

Lines changed: 45 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -322,31 +322,6 @@ void SPIRVRegularizeLLVMBase::expandSYCLTypeUsing(Module *M) {
322322
expandVIDWithSYCLTypeByValComp(F);
323323
}
324324

325-
Value *SPIRVRegularizeLLVMBase::extendBitInstBoolArg(Instruction *II) {
326-
IRBuilder<> Builder(II);
327-
auto *ArgTy = II->getOperand(0)->getType();
328-
Type *NewArgType = nullptr;
329-
if (ArgTy->isIntegerTy()) {
330-
NewArgType = Builder.getInt32Ty();
331-
} else if (ArgTy->isVectorTy() &&
332-
cast<VectorType>(ArgTy)->getElementType()->isIntegerTy()) {
333-
unsigned NumElements = cast<FixedVectorType>(ArgTy)->getNumElements();
334-
NewArgType = VectorType::get(Builder.getInt32Ty(), NumElements, false);
335-
} else {
336-
llvm_unreachable("Unexpected type");
337-
}
338-
auto *NewBase = Builder.CreateZExt(II->getOperand(0), NewArgType);
339-
auto *NewShift = Builder.CreateZExt(II->getOperand(1), NewArgType);
340-
switch (II->getOpcode()) {
341-
case Instruction::LShr:
342-
return Builder.CreateLShr(NewBase, NewShift);
343-
case Instruction::Shl:
344-
return Builder.CreateShl(NewBase, NewShift);
345-
default:
346-
return II;
347-
}
348-
}
349-
350325
bool SPIRVRegularizeLLVMBase::runRegularizeLLVM(Module &Module) {
351326
M = &Module;
352327
Ctx = &M->getContext();
@@ -393,19 +368,53 @@ bool SPIRVRegularizeLLVMBase::regularize() {
393368
}
394369
}
395370

396-
// Translator treats i1 as boolean, but bit instructions take
397-
// a scalar/vector integers, so we have to extend such arguments
398-
if (II.isLogicalShift() &&
399-
II.getOperand(0)->getType()->isIntOrIntVectorTy(1)) {
400-
auto *NewInst = extendBitInstBoolArg(&II);
401-
for (auto *U : II.users()) {
402-
if (cast<Instruction>(U)->getOpcode() == Instruction::ZExt) {
403-
U->dropAllReferences();
404-
U->replaceAllUsesWith(NewInst);
405-
ToErase.push_back(cast<Instruction>(U));
371+
if (II.isLogicalShift()) {
372+
// Translator treats i1 as boolean, but bit instructions take
373+
// a scalar/vector integers, so we have to extend such arguments.
374+
// shl i1 %a %b and lshr i1 %a %b are now converted on:
375+
// %0 = select i1 %a, i32 1, i32 0
376+
// %1 = select i1 %b, i32 1, i32 0
377+
// %2 = lshr i32 %0, %1
378+
// if any other instruction other than zext was dependant:
379+
// %3 = icmp ne i32 %2, 0
380+
// which converts it back to i1 and replace original result with %3
381+
// to dependant instructions.
382+
if (II.getOperand(0)->getType()->isIntOrIntVectorTy(1)) {
383+
IRBuilder<> Builder(&II);
384+
Value *CmpNEInst = nullptr;
385+
Constant *ConstZero = ConstantInt::get(Builder.getInt32Ty(), 0);
386+
Constant *ConstOne = ConstantInt::get(Builder.getInt32Ty(), 1);
387+
if (auto *VecTy =
388+
dyn_cast<FixedVectorType>(II.getOperand(0)->getType())) {
389+
const unsigned NumElements = VecTy->getNumElements();
390+
ConstZero = ConstantVector::getSplat(
391+
ElementCount::getFixed(NumElements), ConstZero);
392+
ConstOne = ConstantVector::getSplat(
393+
ElementCount::getFixed(NumElements), ConstOne);
394+
}
395+
Value *ExtendedBase =
396+
Builder.CreateSelect(II.getOperand(0), ConstOne, ConstZero);
397+
Value *ExtendedShift =
398+
Builder.CreateSelect(II.getOperand(1), ConstOne, ConstZero);
399+
Value *ExtendedShiftedVal =
400+
Builder.CreateLShr(ExtendedBase, ExtendedShift);
401+
SmallVector<User *, 8> Users(II.users());
402+
for (User *U : Users) {
403+
if (auto *UI = dyn_cast<Instruction>(U)) {
404+
if (UI->getOpcode() == Instruction::ZExt) {
405+
UI->dropAllReferences();
406+
UI->replaceAllUsesWith(ExtendedShiftedVal);
407+
ToErase.push_back(UI);
408+
continue;
409+
}
410+
}
411+
if (!CmpNEInst) {
412+
CmpNEInst = Builder.CreateICmpNE(ExtendedShiftedVal, ConstZero);
413+
}
414+
U->replaceUsesOfWith(&II, CmpNEInst);
406415
}
416+
ToErase.push_back(&II);
407417
}
408-
ToErase.push_back(&II);
409418
}
410419

411420
// Remove optimization info not supported by SPIRV

test/lshr_shl_i1_regularize.ll

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
; RUN: llvm-as %s -o %t.bc
2+
; RUN: llvm-spirv -s %t.bc -o %t.reg.bc
3+
; RUN: llvm-dis %t.reg.bc -o - | FileCheck --check-prefix=CHECK-LLVM %s
4+
5+
target triple = "spir64-unknown-unknown"
6+
7+
@G.0 = addrspace(1) global i1 false
8+
@G.1 = addrspace(1) global i1 true
9+
@G.2 = addrspace(1) global <8 x i1> <i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false>
10+
11+
define spir_func void @test_lshr_i1(i1 %a, i1 %b) {
12+
entry:
13+
%0 = lshr i1 %a, %b
14+
; CHECK-LLVM: [[AI32_0:%[0-9]+]] = select i1 %a, i32 1, i32 0
15+
; CHECK-LLVM: [[BI32_0:%[0-9]+]] = select i1 %b, i32 1, i32 0
16+
; CHECK-LLVM: [[LSHRI32_0:%[0-9]+]] = lshr i32 [[AI32_0]], [[BI32_0]]
17+
; CHECK-LLVM: [[TRUNC_0:%[0-9]+]] = icmp ne i32 [[LSHRI32_0]], 0
18+
%1 = zext i1 %0 to i32
19+
%2 = zext i1 %0 to i32
20+
; CHECK-LLVM-NOT zext
21+
; CHECK-LLVM-NOT select
22+
store i1 %0, ptr addrspace(1) @G.0, align 1
23+
; CHECK-LLVM: store i1 [[TRUNC_0]], ptr addrspace(1) @G.0, align 1
24+
ret void
25+
}
26+
27+
define spir_func void @test_shl_i1(i1 %a, i1 %b) {
28+
entry:
29+
%0 = shl i1 %a, %b
30+
; CHECK-LLVM: [[AI32_1:%[0-9]+]] = select i1 %a, i32 1, i32 0
31+
; CHECK-LLVM: [[BI32_1:%[0-9]+]] = select i1 %b, i32 1, i32 0
32+
; CHECK-LLVM: [[LSHR32_1:%[0-9]+]] = lshr i32 [[AI32_1]], [[BI32_1]]
33+
; CHECK-LLVM: [[TRUNC_1:%[0-9]+]] = icmp ne i32 [[LSHR32_1]], 0
34+
%1 = zext i1 %0 to i32
35+
%2 = zext i1 %0 to i32
36+
; CHECK-LLVM-NOT: zext
37+
; CHECK-LLVM-NOT: select
38+
store i1 %0, ptr addrspace(1) @G.1, align 1
39+
; CHECK-LLVM: store i1 [[TRUNC_1]], ptr addrspace(1) @G.1, align 1
40+
ret void
41+
}
42+
43+
define spir_func void @test_shl_vec_i1(<8 x i1> %a, <8 x i1> %b) {
44+
entry:
45+
%0 = shl <8 x i1> %a, %b
46+
; CHECK-LLVM: [[AI32_2:%[0-9]+]] = select <8 x i1> %a, <8 x i32> <i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1>, <8 x i32> zeroinitializer
47+
; CHECK-LLVM: [[BI32_2:%[0-9]+]] = select <8 x i1> %b, <8 x i32> <i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1>, <8 x i32> zeroinitializer
48+
; CHECK-LLVM: [[LSHR32_2:%[0-9]+]] = lshr <8 x i32> [[AI32_2]], [[BI32_2]]
49+
; CHECK-LLVM: [[TRUNC_2:%[0-9]+]] = icmp ne <8 x i32> [[LSHR32_2]], zeroinitializer
50+
%1 = zext <8 x i1> %0 to <8 x i32>
51+
%2 = zext <8 x i1> %0 to <8 x i32>
52+
; CHECK-LLVM-NOT: zext
53+
; CHECK-LLVM-NOT: select
54+
store <8 x i1> %0, ptr addrspace(1) @G.2, align 1
55+
; CHECK-LLVM: store <8 x i1> [[TRUNC_2]], ptr addrspace(1) @G.2, align 1
56+
ret void
57+
}

0 commit comments

Comments
 (0)