Skip to content

Commit 6732fee

Browse files
bwlodarczsys-ce-bb
authored andcommitted
Fix SPIRVRegularizeLLVMBase::regularize fix for shl i1 and lshr i1 (#2288)
The translator failed assertion with V->user_empty() during regularize function when shl i1 or lshr i1 result is used. E.g. %2 = shl i1 %0 %1 store %2, ptr addrspace(1) @G.1, align 1 Instruction shl i1 is converted to lshr i32 which arithmetic have the same behavior. Original commit: KhronosGroup/SPIRV-LLVM-Translator@239fbd4
1 parent cccbd9e commit 6732fee

File tree

2 files changed

+102
-36
lines changed

2 files changed

+102
-36
lines changed

llvm-spirv/lib/SPIRV/SPIRVRegularizeLLVM.cpp

Lines changed: 45 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -321,31 +321,6 @@ void SPIRVRegularizeLLVMBase::expandSYCLTypeUsing(Module *M) {
321321
expandVIDWithSYCLTypeByValComp(F);
322322
}
323323

324-
Value *SPIRVRegularizeLLVMBase::extendBitInstBoolArg(Instruction *II) {
325-
IRBuilder<> Builder(II);
326-
auto *ArgTy = II->getOperand(0)->getType();
327-
Type *NewArgType = nullptr;
328-
if (ArgTy->isIntegerTy()) {
329-
NewArgType = Builder.getInt32Ty();
330-
} else if (ArgTy->isVectorTy() &&
331-
cast<VectorType>(ArgTy)->getElementType()->isIntegerTy()) {
332-
unsigned NumElements = cast<FixedVectorType>(ArgTy)->getNumElements();
333-
NewArgType = VectorType::get(Builder.getInt32Ty(), NumElements, false);
334-
} else {
335-
llvm_unreachable("Unexpected type");
336-
}
337-
auto *NewBase = Builder.CreateZExt(II->getOperand(0), NewArgType);
338-
auto *NewShift = Builder.CreateZExt(II->getOperand(1), NewArgType);
339-
switch (II->getOpcode()) {
340-
case Instruction::LShr:
341-
return Builder.CreateLShr(NewBase, NewShift);
342-
case Instruction::Shl:
343-
return Builder.CreateShl(NewBase, NewShift);
344-
default:
345-
return II;
346-
}
347-
}
348-
349324
bool SPIRVRegularizeLLVMBase::runRegularizeLLVM(Module &Module) {
350325
M = &Module;
351326
Ctx = &M->getContext();
@@ -458,19 +433,53 @@ bool SPIRVRegularizeLLVMBase::regularize() {
458433
}
459434
}
460435

461-
// Translator treats i1 as boolean, but bit instructions take
462-
// a scalar/vector integers, so we have to extend such arguments
463-
if (II.isLogicalShift() &&
464-
II.getOperand(0)->getType()->isIntOrIntVectorTy(1)) {
465-
auto *NewInst = extendBitInstBoolArg(&II);
466-
for (auto *U : II.users()) {
467-
if (cast<Instruction>(U)->getOpcode() == Instruction::ZExt) {
468-
U->dropAllReferences();
469-
U->replaceAllUsesWith(NewInst);
470-
ToErase.push_back(cast<Instruction>(U));
436+
if (II.isLogicalShift()) {
437+
// Translator treats i1 as boolean, but bit instructions take
438+
// a scalar/vector integers, so we have to extend such arguments.
439+
// shl i1 %a %b and lshr i1 %a %b are now converted on:
440+
// %0 = select i1 %a, i32 1, i32 0
441+
// %1 = select i1 %b, i32 1, i32 0
442+
// %2 = lshr i32 %0, %1
443+
// if any other instruction other than zext was dependant:
444+
// %3 = icmp ne i32 %2, 0
445+
// which converts it back to i1 and replace original result with %3
446+
// to dependant instructions.
447+
if (II.getOperand(0)->getType()->isIntOrIntVectorTy(1)) {
448+
IRBuilder<> Builder(&II);
449+
Value *CmpNEInst = nullptr;
450+
Constant *ConstZero = ConstantInt::get(Builder.getInt32Ty(), 0);
451+
Constant *ConstOne = ConstantInt::get(Builder.getInt32Ty(), 1);
452+
if (auto *VecTy =
453+
dyn_cast<FixedVectorType>(II.getOperand(0)->getType())) {
454+
const unsigned NumElements = VecTy->getNumElements();
455+
ConstZero = ConstantVector::getSplat(
456+
ElementCount::getFixed(NumElements), ConstZero);
457+
ConstOne = ConstantVector::getSplat(
458+
ElementCount::getFixed(NumElements), ConstOne);
459+
}
460+
Value *ExtendedBase =
461+
Builder.CreateSelect(II.getOperand(0), ConstOne, ConstZero);
462+
Value *ExtendedShift =
463+
Builder.CreateSelect(II.getOperand(1), ConstOne, ConstZero);
464+
Value *ExtendedShiftedVal =
465+
Builder.CreateLShr(ExtendedBase, ExtendedShift);
466+
SmallVector<User *, 8> Users(II.users());
467+
for (User *U : Users) {
468+
if (auto *UI = dyn_cast<Instruction>(U)) {
469+
if (UI->getOpcode() == Instruction::ZExt) {
470+
UI->dropAllReferences();
471+
UI->replaceAllUsesWith(ExtendedShiftedVal);
472+
ToErase.push_back(UI);
473+
continue;
474+
}
475+
}
476+
if (!CmpNEInst) {
477+
CmpNEInst = Builder.CreateICmpNE(ExtendedShiftedVal, ConstZero);
478+
}
479+
U->replaceUsesOfWith(&II, CmpNEInst);
471480
}
481+
ToErase.push_back(&II);
472482
}
473-
ToErase.push_back(&II);
474483
}
475484

476485
// Remove optimization info not supported by SPIRV
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
; RUN: llvm-as %s -o %t.bc
2+
; RUN: llvm-spirv -s %t.bc -o %t.reg.bc
3+
; RUN: llvm-dis %t.reg.bc -o - | FileCheck --check-prefix=CHECK-LLVM %s
4+
5+
target triple = "spir64-unknown-unknown"
6+
7+
@G.0 = addrspace(1) global i1 false
8+
@G.1 = addrspace(1) global i1 true
9+
@G.2 = addrspace(1) global <8 x i1> <i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false, i1 false>
10+
11+
define spir_func void @test_lshr_i1(i1 %a, i1 %b) {
12+
entry:
13+
%0 = lshr i1 %a, %b
14+
; CHECK-LLVM: [[AI32_0:%[0-9]+]] = select i1 %a, i32 1, i32 0
15+
; CHECK-LLVM: [[BI32_0:%[0-9]+]] = select i1 %b, i32 1, i32 0
16+
; CHECK-LLVM: [[LSHRI32_0:%[0-9]+]] = lshr i32 [[AI32_0]], [[BI32_0]]
17+
; CHECK-LLVM: [[TRUNC_0:%[0-9]+]] = icmp ne i32 [[LSHRI32_0]], 0
18+
%1 = zext i1 %0 to i32
19+
%2 = zext i1 %0 to i32
20+
; CHECK-LLVM-NOT zext
21+
; CHECK-LLVM-NOT select
22+
store i1 %0, ptr addrspace(1) @G.0, align 1
23+
; CHECK-LLVM: store i1 [[TRUNC_0]], ptr addrspace(1) @G.0, align 1
24+
ret void
25+
}
26+
27+
define spir_func void @test_shl_i1(i1 %a, i1 %b) {
28+
entry:
29+
%0 = shl i1 %a, %b
30+
; CHECK-LLVM: [[AI32_1:%[0-9]+]] = select i1 %a, i32 1, i32 0
31+
; CHECK-LLVM: [[BI32_1:%[0-9]+]] = select i1 %b, i32 1, i32 0
32+
; CHECK-LLVM: [[LSHR32_1:%[0-9]+]] = lshr i32 [[AI32_1]], [[BI32_1]]
33+
; CHECK-LLVM: [[TRUNC_1:%[0-9]+]] = icmp ne i32 [[LSHR32_1]], 0
34+
%1 = zext i1 %0 to i32
35+
%2 = zext i1 %0 to i32
36+
; CHECK-LLVM-NOT: zext
37+
; CHECK-LLVM-NOT: select
38+
store i1 %0, ptr addrspace(1) @G.1, align 1
39+
; CHECK-LLVM: store i1 [[TRUNC_1]], ptr addrspace(1) @G.1, align 1
40+
ret void
41+
}
42+
43+
define spir_func void @test_shl_vec_i1(<8 x i1> %a, <8 x i1> %b) {
44+
entry:
45+
%0 = shl <8 x i1> %a, %b
46+
; CHECK-LLVM: [[AI32_2:%[0-9]+]] = select <8 x i1> %a, <8 x i32> <i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1>, <8 x i32> zeroinitializer
47+
; CHECK-LLVM: [[BI32_2:%[0-9]+]] = select <8 x i1> %b, <8 x i32> <i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1, i32 1>, <8 x i32> zeroinitializer
48+
; CHECK-LLVM: [[LSHR32_2:%[0-9]+]] = lshr <8 x i32> [[AI32_2]], [[BI32_2]]
49+
; CHECK-LLVM: [[TRUNC_2:%[0-9]+]] = icmp ne <8 x i32> [[LSHR32_2]], zeroinitializer
50+
%1 = zext <8 x i1> %0 to <8 x i32>
51+
%2 = zext <8 x i1> %0 to <8 x i32>
52+
; CHECK-LLVM-NOT: zext
53+
; CHECK-LLVM-NOT: select
54+
store <8 x i1> %0, ptr addrspace(1) @G.2, align 1
55+
; CHECK-LLVM: store <8 x i1> [[TRUNC_2]], ptr addrspace(1) @G.2, align 1
56+
ret void
57+
}

0 commit comments

Comments
 (0)