Skip to content

Commit 83b2dfd

Browse files
jcranmer-intelsvenvh
authored andcommitted
Make the LowerBitCast pass support opaque pointers.
This is a relatively large change, as the original pass relied on being able to track the initial bitcast <3 x i64>* to <6 x i32>* to know where to start rewriting. Instead, this patch starts at the final invalid extractelement call and works its way backwards as far as necessary to generate correct code.
1 parent 1555107 commit 83b2dfd

File tree

3 files changed

+183
-113
lines changed

3 files changed

+183
-113
lines changed

lib/SPIRV/SPIRVLowerBitCastToNonStandardType.cpp

Lines changed: 130 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -40,104 +40,83 @@
4040
// point types, 2/3/4/8/16-element vector of scalar types").
4141
//
4242
//===----------------------------------------------------------------------===//
43-
#define DEBUG_TYPE "spv-lower-bitcast-to-nonstandard-type"
44-
4543
#include "SPIRVInternal.h"
4644

4745
#include "llvm/IR/IRBuilder.h"
46+
#include "llvm/IR/NoFolder.h"
4847
#include "llvm/IR/PassManager.h"
4948
#include "llvm/Pass.h"
49+
#include "llvm/Transforms/Utils/Local.h"
5050

5151
#include <utility>
5252

53+
#define DEBUG_TYPE "spv-lower-bitcast-to-nonstandard-type"
54+
5355
using namespace llvm;
5456

5557
namespace SPIRV {
5658

57-
static VectorType *getVectorType(Type *Ty) {
58-
assert(Ty != nullptr && "Expected non-null type");
59-
if (auto *ElemTy = dyn_cast<PointerType>(Ty))
60-
Ty = ElemTy->getPointerElementType();
61-
return dyn_cast<VectorType>(Ty);
62-
}
59+
using NFIRBuilder = IRBuilder<NoFolder>;
6360

64-
/// Since SPIR-V does not support non-standard vector types, instructions using
65-
/// these types should be replaced in a special way to avoid using of
66-
/// unsupported types.
67-
/// lowerBitCastToNonStdVec function is designed to avoid using of bitcast to
68-
/// unsupported vector types instructions and should be called if similar
69-
/// instructions have been encountered in input LLVM IR.
70-
bool lowerBitCastToNonStdVec(Instruction *OldInst, Value *NewInst,
71-
const VectorType *OldVecTy,
72-
std::vector<Instruction *> &InstsToErase,
73-
IRBuilder<> &Builder,
74-
unsigned RecursionDepth = 0) {
75-
static constexpr unsigned MaxRecursionDepth = 16;
76-
if (RecursionDepth++ > MaxRecursionDepth)
77-
report_fatal_error(
78-
llvm::Twine(
79-
"The depth of recursion exceeds the maximum possible depth"),
80-
false);
81-
82-
bool Changed = false;
83-
VectorType *NewVecTy = getVectorType(NewInst->getType());
84-
if (NewVecTy) {
85-
Builder.SetInsertPoint(OldInst);
86-
for (auto *U : OldInst->users()) {
87-
// Handle addrspacecast instruction after bitcast if present
88-
if (auto *ASCastInst = dyn_cast<AddrSpaceCastInst>(U)) {
89-
unsigned DestAS = ASCastInst->getDestAddressSpace();
90-
auto *NewVecPtrTy = NewVecTy->getPointerTo(DestAS);
91-
// AddrSpaceCast is created explicitly instead of using method
92-
// IRBuilder<>.CreateAddrSpaceCast because IRBuilder doesn't create
93-
// separate instruction for constant values. Whereas SPIR-V translator
94-
// doesn't like several nested instructions in one.
95-
Value *LocalValue = new AddrSpaceCastInst(NewInst, NewVecPtrTy);
96-
Builder.Insert(LocalValue);
97-
Changed |=
98-
lowerBitCastToNonStdVec(ASCastInst, LocalValue, OldVecTy,
99-
InstsToErase, Builder, RecursionDepth);
100-
}
101-
// Handle load instruction which is following the bitcast in the pattern
102-
else if (auto *LI = dyn_cast<LoadInst>(U)) {
103-
Value *LocalValue = Builder.CreateLoad(NewVecTy, NewInst);
104-
Changed |= lowerBitCastToNonStdVec(
105-
LI, LocalValue, OldVecTy, InstsToErase, Builder, RecursionDepth);
106-
}
107-
// Handle extractelement instruction which is following the load
108-
else if (auto *EEI = dyn_cast<ExtractElementInst>(U)) {
109-
uint64_t NumElemsInOldVec = OldVecTy->getElementCount().getFixedValue();
110-
uint64_t NumElemsInNewVec = NewVecTy->getElementCount().getFixedValue();
111-
uint64_t OldElemIdx =
112-
cast<ConstantInt>(EEI->getIndexOperand())->getZExtValue();
113-
uint64_t NewElemIdx =
114-
OldElemIdx / (NumElemsInOldVec / NumElemsInNewVec);
115-
Value *LocalValue = Builder.CreateExtractElement(NewInst, NewElemIdx);
116-
// The trunc instruction truncates the high order bits in value, so it
117-
// may be necessary to shift right high order bits, if required bits are
118-
// not at the end of extracted value
119-
unsigned OldVecElemBitWidth =
120-
cast<IntegerType>(OldVecTy->getElementType())->getBitWidth();
121-
unsigned NewVecElemBitWidth =
122-
cast<IntegerType>(NewVecTy->getElementType())->getBitWidth();
123-
unsigned BitWidthRatio = NewVecElemBitWidth / OldVecElemBitWidth;
124-
if (auto RequiredBitsIdx =
125-
OldElemIdx % BitWidthRatio != BitWidthRatio - 1) {
126-
uint64_t Shift =
127-
OldVecElemBitWidth * (BitWidthRatio - RequiredBitsIdx);
128-
LocalValue = Builder.CreateLShr(LocalValue, Shift);
129-
}
130-
LocalValue =
131-
Builder.CreateTrunc(LocalValue, OldVecTy->getElementType());
132-
Changed |= lowerBitCastToNonStdVec(
133-
EEI, LocalValue, OldVecTy, InstsToErase, Builder, RecursionDepth);
61+
static Value *removeBitCasts(Value *OldValue, Type *NewTy, NFIRBuilder &Builder,
62+
std::vector<Instruction *> &InstsToErase) {
63+
IRBuilderBase::InsertPointGuard Guard(Builder);
64+
auto RauwBitcasts = [&](Instruction *OldValue, Value *NewValue) {
65+
// If there's only one use, don't create a bitcast for any uses, since it
66+
// will be immediately replaced anyways.
67+
if (OldValue->hasOneUse()) {
68+
OldValue->replaceAllUsesWith(UndefValue::get(OldValue->getType()));
69+
} else {
70+
OldValue->replaceAllUsesWith(
71+
Builder.CreateBitCast(NewValue, OldValue->getType()));
72+
}
73+
InstsToErase.push_back(OldValue);
74+
return NewValue;
75+
};
76+
77+
if (auto *LI = dyn_cast<LoadInst>(OldValue)) {
78+
Builder.SetInsertPoint(LI);
79+
Value *Pointer = LI->getPointerOperand();
80+
if (!Pointer->getType()->isOpaquePointerTy()) {
81+
Type *NewPointerTy =
82+
PointerType::get(NewTy, LI->getPointerAddressSpace());
83+
Pointer = removeBitCasts(Pointer, NewPointerTy, Builder, InstsToErase);
84+
}
85+
LoadInst *NewLI = Builder.CreateAlignedLoad(NewTy, Pointer, LI->getAlign(),
86+
LI->isVolatile());
87+
NewLI->setOrdering(LI->getOrdering());
88+
NewLI->setSyncScopeID(LI->getSyncScopeID());
89+
return RauwBitcasts(LI, NewLI);
90+
}
91+
92+
if (auto *ASCI = dyn_cast<AddrSpaceCastInst>(OldValue)) {
93+
Builder.SetInsertPoint(ASCI);
94+
Type *NewSrcTy = PointerType::getWithSamePointeeType(
95+
cast<PointerType>(NewTy), ASCI->getSrcAddressSpace());
96+
Value *Pointer = removeBitCasts(ASCI->getPointerOperand(), NewSrcTy,
97+
Builder, InstsToErase);
98+
return RauwBitcasts(ASCI, Builder.CreateAddrSpaceCast(Pointer, NewTy));
99+
}
100+
101+
if (auto *BC = dyn_cast<BitCastInst>(OldValue)) {
102+
if (BC->getSrcTy() == NewTy) {
103+
if (BC->hasOneUse()) {
104+
BC->replaceAllUsesWith(UndefValue::get(BC->getType()));
105+
InstsToErase.push_back(BC);
134106
}
107+
return BC->getOperand(0);
135108
}
109+
Builder.SetInsertPoint(BC);
110+
return RauwBitcasts(BC, Builder.CreateBitCast(BC->getOperand(0), NewTy));
136111
}
137-
InstsToErase.push_back(OldInst);
138-
if (!Changed)
139-
OldInst->replaceAllUsesWith(NewInst);
140-
return true;
112+
113+
report_fatal_error("Cannot translate source of bitcast instruction.");
114+
return nullptr;
115+
}
116+
117+
static bool isNonStdVecType(VectorType *VecTy) {
118+
uint64_t NumElems = VecTy->getElementCount().getFixedValue();
119+
return !isValidVectorSize(NumElems);
141120
}
142121

143122
class SPIRVLowerBitCastToNonStandardTypePass
@@ -160,41 +139,82 @@ class SPIRVLowerBitCastToNonStandardTypePass
160139
if (Opts.isAllowedToUseExtension(ExtensionID::SPV_INTEL_vector_compute))
161140
return PreservedAnalyses::all();
162141

163-
std::vector<Instruction *> BCastsToNonStdVec;
164-
std::vector<Instruction *> InstsToErase;
142+
// The basic pattern we're trying to fix is this InstCombine pattern:
143+
// trunc (extractelement) -> extractelement (bitcast)
144+
// (note that the bitcast itself can get propagated back to change the type
145+
// of load instructions, and even through those to pointer casts, if typed
146+
// pointers are enabled.
147+
std::vector<ExtractElementInst *> NonStdVecInsts;
148+
SmallVector<WeakTrackingVH, 4> MaybeDeletedInsts;
165149
for (auto &BB : F)
166150
for (auto &I : BB) {
167-
auto *BC = dyn_cast<BitCastInst>(&I);
168-
if (!BC)
169-
continue;
170-
VectorType *SrcVecTy = getVectorType(BC->getSrcTy());
171-
if (SrcVecTy) {
172-
uint64_t NumElemsInSrcVec =
173-
SrcVecTy->getElementCount().getFixedValue();
174-
if (!isValidVectorSize(NumElemsInSrcVec))
175-
report_fatal_error(
176-
llvm::Twine("Unsupported vector type with the size of: " +
177-
std::to_string(NumElemsInSrcVec)),
178-
false);
179-
}
180-
VectorType *DestVecTy = getVectorType(BC->getDestTy());
181-
if (DestVecTy) {
182-
uint64_t NumElemsInDestVec =
183-
DestVecTy->getElementCount().getFixedValue();
184-
if (!isValidVectorSize(NumElemsInDestVec))
185-
BCastsToNonStdVec.push_back(&I);
151+
if (auto *EI = dyn_cast<ExtractElementInst>(&I)) {
152+
if (isNonStdVecType(EI->getVectorOperandType()))
153+
NonStdVecInsts.push_back(EI);
154+
} else if (auto *VT = dyn_cast<VectorType>(I.getType())) {
155+
if (isNonStdVecType(VT)) {
156+
MaybeDeletedInsts.push_back(&I);
157+
}
186158
}
187159
}
188-
IRBuilder<> Builder(F.getContext());
189-
for (auto &I : BCastsToNonStdVec) {
190-
Value *NewValue = I->getOperand(0);
191-
VectorType *OldVecTy = getVectorType(I->getType());
192-
Changed |=
193-
lowerBitCastToNonStdVec(I, NewValue, OldVecTy, InstsToErase, Builder);
160+
161+
std::vector<Instruction *> InstsToErase;
162+
NFIRBuilder Builder(F.getContext());
163+
for (auto &I : NonStdVecInsts) {
164+
VectorType *OldVecTy = I->getVectorOperandType();
165+
unsigned OldVecSize = OldVecTy->getElementCount().getFixedValue();
166+
167+
// Compute the adjustment factor for the new vector size.
168+
unsigned VecFactor = 2;
169+
while (OldVecSize % VecFactor == 0 &&
170+
!isValidVectorSize(OldVecSize / VecFactor))
171+
VecFactor *= 2;
172+
if (OldVecSize % VecFactor != 0) {
173+
report_fatal_error(Twine("Invalid vector size for fixup: ") +
174+
Twine(OldVecSize));
175+
return PreservedAnalyses::none();
176+
}
177+
unsigned NewElemSize = OldVecTy->getScalarSizeInBits() * VecFactor;
178+
VectorType *NewVecTy =
179+
VectorType::get(Type::getIntNTy(F.getContext(), NewElemSize),
180+
OldVecSize / VecFactor, false);
181+
182+
// Adjust the element index as appropriate.
183+
uint64_t OldElemIdx =
184+
cast<ConstantInt>(I->getIndexOperand())->getZExtValue();
185+
uint64_t NewElemIdx = OldElemIdx / VecFactor;
186+
uint64_t ShiftCount = OldElemIdx % VecFactor;
187+
Builder.SetInsertPoint(I);
188+
Value *NewVecOp = removeBitCasts(I->getVectorOperand(), NewVecTy, Builder,
189+
InstsToErase);
190+
Value *NewExtracted = Builder.CreateExtractElement(NewVecOp, NewElemIdx);
191+
192+
// If the extract does higher-order bits of the value, shift as necessary.
193+
if (ShiftCount > 0)
194+
NewExtracted = Builder.CreateLShr(
195+
NewExtracted, ShiftCount * OldVecTy->getScalarSizeInBits());
196+
197+
Value *NewValue = Builder.CreateTrunc(NewExtracted, I->getType());
198+
I->replaceAllUsesWith(NewValue);
199+
I->eraseFromParent();
200+
Changed = true;
194201
}
195202

196203
for (auto *I : InstsToErase)
197-
I->eraseFromParent();
204+
RecursivelyDeleteTriviallyDeadInstructions(I);
205+
206+
// Check if there are any residual unsupported vector types.
207+
for (auto &VH : MaybeDeletedInsts) {
208+
// Some vector-valued instructions were replaced with undef values, so if
209+
// that's what we got, it's still a dead instruction.
210+
if (VH.pointsToAliveValue() && !isa<UndefValue>(VH)) {
211+
auto *VT = dyn_cast<VectorType>(VH->getType());
212+
report_fatal_error(Twine("Unsupported vector type with ") +
213+
Twine(VT->getElementCount().getFixedValue()) +
214+
Twine(" elements"),
215+
false);
216+
}
217+
}
198218

199219
return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
200220
}
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
; RUN: llvm-as %s -o %t.bc
2+
; RUN: llvm-spirv -s %t.bc -o - | llvm-dis -o - | FileCheck %s --implicit-check-not="<6 x i32>"
3+
4+
; CHECK: [[ASCastInst:%.*]] = addrspacecast ptr addrspace(1) @Id to ptr addrspace(4)
5+
; CHECK: [[LoadInst1:%.*]] = load <3 x i64>, ptr addrspace(4) [[ASCastInst]], align 32
6+
; CHECK: [[LoadInst2:%.*]] = load <3 x i64>, ptr addrspace(4) [[ASCastInst]], align 32
7+
; CHECK: [[ExtrElInst1:%.*]] = extractelement <3 x i64> [[LoadInst1]], i64 0
8+
; CHECK: [[TruncInst1:%.*]] = trunc i64 [[ExtrElInst1]] to i32
9+
; CHECK: [[ExtrElInst2:%.*]] = extractelement <3 x i64> [[LoadInst2]], i64 2
10+
; CHECK: [[LShrInst:%.*]] = lshr i64 [[ExtrElInst2]], 32
11+
; CHECK: [[TruncInst2:%.*]] = trunc i64 [[LShrInst]] to i32
12+
; CHECK: %conv1 = sitofp i32 [[TruncInst1]] to float
13+
; CHECK: %conv2 = sitofp i32 [[TruncInst2]] to float
14+
15+
; ModuleID = 'lower-non-standard-types'
16+
source_filename = "lower-non-standard-types.cpp"
17+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
18+
target triple = "spir64-unknown-unknown"
19+
20+
@Id = external dso_local local_unnamed_addr addrspace(1) constant <3 x i64>, align 32
21+
22+
; Function Attrs: convergent norecurse
23+
define dso_local spir_func void @vmult2() local_unnamed_addr #0 !sycl_explicit_simd !4 !intel_reqd_sub_group_size !6 {
24+
entry:
25+
%0 = load <6 x i32>, ptr addrspace(4) addrspacecast (ptr addrspace(1) @Id to ptr addrspace(4)), align 32
26+
%1 = load <6 x i32>, ptr addrspace(4) addrspacecast (ptr addrspace(1) @Id to ptr addrspace(4)), align 32
27+
%2 = extractelement <6 x i32> %0, i32 0
28+
%3 = extractelement <6 x i32> %1, i32 5
29+
%conv1 = sitofp i32 %2 to float
30+
%conv2 = sitofp i32 %3 to float
31+
ret void
32+
}
33+
34+
attributes #0 = { convergent norecurse "frame-pointer"="all" "min-legal-vector-width"="256" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "sycl-module-id"="lower-external-funcs-with-z.cpp" }
35+
36+
!llvm.module.flags = !{!0, !1}
37+
!opencl.spir.version = !{!2}
38+
!spirv.Source = !{!3}
39+
!opencl.used.extensions = !{!4}
40+
!opencl.used.optional.core.features = !{!4}
41+
!opencl.compiler.options = !{!4}
42+
!llvm.ident = !{!5}
43+
44+
!0 = !{i32 1, !"wchar_size", i32 4}
45+
!1 = !{i32 7, !"frame-pointer", i32 2}
46+
!2 = !{i32 1, i32 2}
47+
!3 = !{i32 0, i32 100000}
48+
!4 = !{}
49+
!5 = !{!"Compiler"}
50+
!6 = !{i32 1}

test/lower-non-standard-types.ll

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33

44
; CHECK: [[ASCastInst:%.*]] = addrspacecast <3 x i64> addrspace(1)* @Id to <3 x i64> addrspace(4)*
55
; CHECK: [[LoadInst1:%.*]] = load <3 x i64>, <3 x i64> addrspace(4)* [[ASCastInst]], align 32
6+
; CHECK: [[LoadInst2:%.*]] = load <3 x i64>, <3 x i64> addrspace(4)* [[ASCastInst]], align 32
67
; CHECK: [[ExtrElInst1:%.*]] = extractelement <3 x i64> [[LoadInst1]], i64 0
78
; CHECK: [[TruncInst1:%.*]] = trunc i64 [[ExtrElInst1]] to i32
8-
; CHECK: [[LoadInst2:%.*]] = load <3 x i64>, <3 x i64> addrspace(4)* [[ASCastInst]], align 32
99
; CHECK: [[ExtrElInst2:%.*]] = extractelement <3 x i64> [[LoadInst2]], i64 2
1010
; CHECK: [[LShrInst:%.*]] = lshr i64 [[ExtrElInst2]], 32
1111
; CHECK: [[TruncInst2:%.*]] = trunc i64 [[LShrInst]] to i32
@@ -24,8 +24,8 @@ define dso_local spir_func void @vmult2() local_unnamed_addr #0 !sycl_explicit_s
2424
entry:
2525
%0 = load <6 x i32>, <6 x i32> addrspace(4)* addrspacecast (<6 x i32> addrspace(1)* bitcast (<3 x i64> addrspace(1)* @Id to <6 x i32> addrspace(1)*) to <6 x i32> addrspace(4)*), align 32
2626
%1 = load <6 x i32>, <6 x i32> addrspace(4)* addrspacecast (<6 x i32> addrspace(1)* bitcast (<3 x i64> addrspace(1)* @Id to <6 x i32> addrspace(1)*) to <6 x i32> addrspace(4)*), align 32
27-
%2 = extractelement <6 x i32> %0, i32 1
28-
%3 = extractelement <6 x i32> %1, i32 4
27+
%2 = extractelement <6 x i32> %0, i32 0
28+
%3 = extractelement <6 x i32> %1, i32 5
2929
%conv1 = sitofp i32 %2 to float
3030
%conv2 = sitofp i32 %3 to float
3131
ret void

0 commit comments

Comments
 (0)