40
40
// point types, 2/3/4/8/16-element vector of scalar types").
41
41
//
42
42
// ===----------------------------------------------------------------------===//
43
- #define DEBUG_TYPE " spv-lower-bitcast-to-nonstandard-type"
44
-
45
43
#include " SPIRVInternal.h"
46
44
47
45
#include " llvm/IR/IRBuilder.h"
46
+ #include " llvm/IR/NoFolder.h"
48
47
#include " llvm/IR/PassManager.h"
49
48
#include " llvm/Pass.h"
49
+ #include " llvm/Transforms/Utils/Local.h"
50
50
51
51
#include < utility>
52
52
53
+ #define DEBUG_TYPE " spv-lower-bitcast-to-nonstandard-type"
54
+
53
55
using namespace llvm ;
54
56
55
57
namespace SPIRV {
56
58
57
// NOTE: every line of this helper carries a '-' marker, i.e. this diff view
// shows it as removed.
// Returns Ty viewed as a VectorType. When Ty is a pointer type, the type it
// points to is inspected instead. The result is nullptr when the inspected
// type is not a vector.
- static VectorType *getVectorType (Type *Ty) {
58
- assert (Ty != nullptr && " Expected non-null type" );
59
// Look through a pointer to the type it points at.
- if (auto *ElemTy = dyn_cast<PointerType>(Ty))
60
- Ty = ElemTy->getPointerElementType ();
61
// dyn_cast yields nullptr for any non-vector type.
- return dyn_cast<VectorType>(Ty);
62
- }
59
+ using NFIRBuilder = IRBuilder<NoFolder>;
63
60
64
- /// Since SPIR-V does not support non-standard vector types, instructions using
65
- /// these types should be replaced in a special way to avoid the use of
66
- /// unsupported types.
67
- /// The lowerBitCastToNonStdVec function is designed to eliminate bitcast
68
- /// instructions to unsupported vector types, and should be called whenever such
69
- /// instructions are encountered in the input LLVM IR.
70
- bool lowerBitCastToNonStdVec (Instruction *OldInst, Value *NewInst,
71
- const VectorType *OldVecTy,
72
- std::vector<Instruction *> &InstsToErase,
73
- IRBuilder<> &Builder,
74
- unsigned RecursionDepth = 0 ) {
75
- static constexpr unsigned MaxRecursionDepth = 16 ;
76
- if (RecursionDepth++ > MaxRecursionDepth)
77
- report_fatal_error (
78
- llvm::Twine (
79
- " The depth of recursion exceeds the maximum possible depth" ),
80
- false );
81
-
82
- bool Changed = false ;
83
- VectorType *NewVecTy = getVectorType (NewInst->getType ());
84
- if (NewVecTy) {
85
- Builder.SetInsertPoint (OldInst);
86
- for (auto *U : OldInst->users ()) {
87
- // Handle addrspacecast instruction after bitcast if present
88
- if (auto *ASCastInst = dyn_cast<AddrSpaceCastInst>(U)) {
89
- unsigned DestAS = ASCastInst->getDestAddressSpace ();
90
- auto *NewVecPtrTy = NewVecTy->getPointerTo (DestAS);
91
- // AddrSpaceCast is created explicitly instead of using method
92
- // IRBuilder<>.CreateAddrSpaceCast because IRBuilder doesn't create
93
- // separate instruction for constant values. Whereas SPIR-V translator
94
- // doesn't like several nested instructions in one.
95
- Value *LocalValue = new AddrSpaceCastInst (NewInst, NewVecPtrTy);
96
- Builder.Insert (LocalValue);
97
- Changed |=
98
- lowerBitCastToNonStdVec (ASCastInst, LocalValue, OldVecTy,
99
- InstsToErase, Builder, RecursionDepth);
100
- }
101
- // Handle load instruction which is following the bitcast in the pattern
102
- else if (auto *LI = dyn_cast<LoadInst>(U)) {
103
- Value *LocalValue = Builder.CreateLoad (NewVecTy, NewInst);
104
- Changed |= lowerBitCastToNonStdVec (
105
- LI, LocalValue, OldVecTy, InstsToErase, Builder, RecursionDepth);
106
- }
107
- // Handle extractelement instruction which is following the load
108
- else if (auto *EEI = dyn_cast<ExtractElementInst>(U)) {
109
- uint64_t NumElemsInOldVec = OldVecTy->getElementCount ().getFixedValue ();
110
- uint64_t NumElemsInNewVec = NewVecTy->getElementCount ().getFixedValue ();
111
- uint64_t OldElemIdx =
112
- cast<ConstantInt>(EEI->getIndexOperand ())->getZExtValue ();
113
- uint64_t NewElemIdx =
114
- OldElemIdx / (NumElemsInOldVec / NumElemsInNewVec);
115
- Value *LocalValue = Builder.CreateExtractElement (NewInst, NewElemIdx);
116
- // The trunc instruction truncates the high order bits in value, so it
117
- // may be necessary to shift right high order bits, if required bits are
118
- // not at the end of extracted value
119
- unsigned OldVecElemBitWidth =
120
- cast<IntegerType>(OldVecTy->getElementType ())->getBitWidth ();
121
- unsigned NewVecElemBitWidth =
122
- cast<IntegerType>(NewVecTy->getElementType ())->getBitWidth ();
123
- unsigned BitWidthRatio = NewVecElemBitWidth / OldVecElemBitWidth;
124
- if (auto RequiredBitsIdx =
125
- OldElemIdx % BitWidthRatio != BitWidthRatio - 1 ) {
126
- uint64_t Shift =
127
- OldVecElemBitWidth * (BitWidthRatio - RequiredBitsIdx);
128
- LocalValue = Builder.CreateLShr (LocalValue, Shift);
129
- }
130
- LocalValue =
131
- Builder.CreateTrunc (LocalValue, OldVecTy->getElementType ());
132
- Changed |= lowerBitCastToNonStdVec (
133
- EEI, LocalValue, OldVecTy, InstsToErase, Builder, RecursionDepth);
61
// Re-creates the instruction that defines OldValue so that it yields a value
// of type NewTy, and returns that replacement value.
//
// Three kinds of defining instruction are handled, each looked through
// recursively where needed:
//   * LoadInst          - the load is re-issued with NewTy,
//   * AddrSpaceCastInst - the cast is re-issued on a retyped source pointer,
//   * BitCastInst       - the cast is bypassed or re-issued with NewTy.
// Any other kind of value ends in report_fatal_error.
//
// OldValue      value whose defining instruction is to be retyped.
// NewTy         type the replacement value must have.
// Builder       used to emit the replacement instructions; its insertion
//               point is moved by this function and restored on return by
//               the InsertPointGuard below.
// InstsToErase  receives the instructions that were superseded. This function
//               only queues them; it does not erase anything itself.
+ static Value *removeBitCasts (Value *OldValue, Type *NewTy, NFIRBuilder &Builder,
62
+ std::vector<Instruction *> &InstsToErase) {
63
+ IRBuilderBase::InsertPointGuard Guard (Builder);
64
// Helper: redirects all uses of the superseded instruction, queues it for
// erasure, and hands back the replacement value. Note that the lambda's
// OldValue parameter shadows the enclosing function's parameter of the same
// name.
+ auto RauwBitcasts = [&](Instruction *OldValue, Value *NewValue) {
65
+ // If there's only one use, don't create a bitcast for any uses, since it
66
+ // will be immediately replaced anyways.
67
+ if (OldValue->hasOneUse ()) {
68
+ OldValue->replaceAllUsesWith (UndefValue::get (OldValue->getType ()));
69
+ } else {
70
// Several uses: keep them type-correct by casting the new value back to
// the old type.
+ OldValue->replaceAllUsesWith (
71
+ Builder.CreateBitCast (NewValue, OldValue->getType ()));
72
+ }
73
+ InstsToErase.push_back (OldValue);
74
+ return NewValue;
75
+ };
76
+
77
// Case 1: load. Emit a new load of NewTy in front of the old one.
+ if (auto *LI = dyn_cast<LoadInst>(OldValue)) {
78
+ Builder.SetInsertPoint (LI);
79
+ Value *Pointer = LI->getPointerOperand ();
80
// For a non-opaque pointer, the pointer operand itself is first retyped to
// point at NewTy (same address space) by recursing on it.
+ if (!Pointer->getType ()->isOpaquePointerTy ()) {
81
+ Type *NewPointerTy =
82
+ PointerType::get (NewTy, LI->getPointerAddressSpace ());
83
+ Pointer = removeBitCasts (Pointer, NewPointerTy, Builder, InstsToErase);
84
+ }
85
// Carry over alignment, volatility, atomic ordering and sync scope from the
// original load.
+ LoadInst *NewLI = Builder.CreateAlignedLoad (NewTy, Pointer, LI->getAlign (),
86
+ LI->isVolatile ());
87
+ NewLI->setOrdering (LI->getOrdering ());
88
+ NewLI->setSyncScopeID (LI->getSyncScopeID ());
89
+ return RauwBitcasts (LI, NewLI);
90
+ }
91
+
92
// Case 2: address-space cast. Retype the source pointer so that it has
// NewTy's pointee type in the cast's source address space, then redo the
// address-space cast to NewTy.
+ if (auto *ASCI = dyn_cast<AddrSpaceCastInst>(OldValue)) {
93
+ Builder.SetInsertPoint (ASCI);
94
+ Type *NewSrcTy = PointerType::getWithSamePointeeType (
95
+ cast<PointerType>(NewTy), ASCI->getSrcAddressSpace ());
96
+ Value *Pointer = removeBitCasts (ASCI->getPointerOperand (), NewSrcTy,
97
+ Builder, InstsToErase);
98
+ return RauwBitcasts (ASCI, Builder.CreateAddrSpaceCast (Pointer, NewTy));
99
+ }
100
+
101
// Case 3: bitcast.
+ if (auto *BC = dyn_cast<BitCastInst>(OldValue)) {
102
// The cast's source already has the requested type: bypass the cast and
// return its operand. The cast is only detached and queued for erasure when
// this is its sole use; otherwise it is left in place for its other users.
+ if (BC->getSrcTy () == NewTy) {
103
+ if (BC->hasOneUse ()) {
104
+ BC->replaceAllUsesWith (UndefValue::get (BC->getType ()));
105
+ InstsToErase.push_back (BC);
134
106
}
107
+ return BC->getOperand (0 );
135
108
}
109
// Otherwise re-emit the bitcast from the same operand with NewTy as the
// destination type.
+ Builder.SetInsertPoint (BC);
110
+ return RauwBitcasts (BC, Builder.CreateBitCast (BC->getOperand (0 ), NewTy));
136
111
}
137
// NOTE: the four '-' lines below are the tail of the removed
// lowerBitCastToNonStdVec function as rendered by the diff view; they are
// not part of removeBitCasts.
- InstsToErase.push_back (OldInst);
138
- if (!Changed)
139
- OldInst->replaceAllUsesWith (NewInst);
140
- return true ;
112
+
113
// None of the handled instruction kinds matched.
+ report_fatal_error (" Cannot translate source of bitcast instruction." );
114
// Presumably unreachable (report_fatal_error is not expected to return);
// kept so that every path returns a value.
+ return nullptr ;
115
+ }
116
+
117
// Returns true when the element count of VecTy is rejected by
// isValidVectorSize (declared outside this view).
// NOTE(review): getFixedValue() is used, so this presumably expects a
// fixed-length (non-scalable) vector type - confirm against callers.
+ static bool isNonStdVecType (VectorType *VecTy) {
118
+ uint64_t NumElems = VecTy->getElementCount ().getFixedValue ();
119
+ return !isValidVectorSize (NumElems);
141
120
}
142
121
143
122
class SPIRVLowerBitCastToNonStandardTypePass
@@ -160,41 +139,82 @@ class SPIRVLowerBitCastToNonStandardTypePass
160
139
if (Opts.isAllowedToUseExtension (ExtensionID::SPV_INTEL_vector_compute))
161
140
return PreservedAnalyses::all ();
162
141
163
- std::vector<Instruction *> BCastsToNonStdVec;
164
- std::vector<Instruction *> InstsToErase;
142
+ // The basic pattern we're trying to fix is this InstCombine pattern:
143
+ // trunc (extractelement) -> extractelement (bitcast)
144
+ // (note that the bitcast itself can get propagated back to change the type
145
+ // of load instructions, and even through those to pointer casts, if typed
146
+ // pointers are enabled).
147
+ std::vector<ExtractElementInst *> NonStdVecInsts;
148
+ SmallVector<WeakTrackingVH, 4 > MaybeDeletedInsts;
165
149
for (auto &BB : F)
166
150
for (auto &I : BB) {
167
- auto *BC = dyn_cast<BitCastInst>(&I);
168
- if (!BC)
169
- continue ;
170
- VectorType *SrcVecTy = getVectorType (BC->getSrcTy ());
171
- if (SrcVecTy) {
172
- uint64_t NumElemsInSrcVec =
173
- SrcVecTy->getElementCount ().getFixedValue ();
174
- if (!isValidVectorSize (NumElemsInSrcVec))
175
- report_fatal_error (
176
- llvm::Twine (" Unsupported vector type with the size of: " +
177
- std::to_string (NumElemsInSrcVec)),
178
- false );
179
- }
180
- VectorType *DestVecTy = getVectorType (BC->getDestTy ());
181
- if (DestVecTy) {
182
- uint64_t NumElemsInDestVec =
183
- DestVecTy->getElementCount ().getFixedValue ();
184
- if (!isValidVectorSize (NumElemsInDestVec))
185
- BCastsToNonStdVec.push_back (&I);
151
+ if (auto *EI = dyn_cast<ExtractElementInst>(&I)) {
152
+ if (isNonStdVecType (EI->getVectorOperandType ()))
153
+ NonStdVecInsts.push_back (EI);
154
+ } else if (auto *VT = dyn_cast<VectorType>(I.getType ())) {
155
+ if (isNonStdVecType (VT)) {
156
+ MaybeDeletedInsts.push_back (&I);
157
+ }
186
158
}
187
159
}
188
- IRBuilder<> Builder (F.getContext ());
189
- for (auto &I : BCastsToNonStdVec) {
190
- Value *NewValue = I->getOperand (0 );
191
- VectorType *OldVecTy = getVectorType (I->getType ());
192
- Changed |=
193
- lowerBitCastToNonStdVec (I, NewValue, OldVecTy, InstsToErase, Builder);
160
+
161
+ std::vector<Instruction *> InstsToErase;
162
+ NFIRBuilder Builder (F.getContext ());
163
+ for (auto &I : NonStdVecInsts) {
164
+ VectorType *OldVecTy = I->getVectorOperandType ();
165
+ unsigned OldVecSize = OldVecTy->getElementCount ().getFixedValue ();
166
+
167
+ // Compute the adjustment factor for the new vector size.
168
+ unsigned VecFactor = 2 ;
169
+ while (OldVecSize % VecFactor == 0 &&
170
+ !isValidVectorSize (OldVecSize / VecFactor))
171
+ VecFactor *= 2 ;
172
+ if (OldVecSize % VecFactor != 0 ) {
173
+ report_fatal_error (Twine (" Invalid vector size for fixup: " ) +
174
+ Twine (OldVecSize));
175
+ return PreservedAnalyses::none ();
176
+ }
177
+ unsigned NewElemSize = OldVecTy->getScalarSizeInBits () * VecFactor;
178
+ VectorType *NewVecTy =
179
+ VectorType::get (Type::getIntNTy (F.getContext (), NewElemSize),
180
+ OldVecSize / VecFactor, false );
181
+
182
+ // Adjust the element index as appropriate.
183
+ uint64_t OldElemIdx =
184
+ cast<ConstantInt>(I->getIndexOperand ())->getZExtValue ();
185
+ uint64_t NewElemIdx = OldElemIdx / VecFactor;
186
+ uint64_t ShiftCount = OldElemIdx % VecFactor;
187
+ Builder.SetInsertPoint (I);
188
+ Value *NewVecOp = removeBitCasts (I->getVectorOperand (), NewVecTy, Builder,
189
+ InstsToErase);
190
+ Value *NewExtracted = Builder.CreateExtractElement (NewVecOp, NewElemIdx);
191
+
192
+ // If the element lives in the higher-order bits, shift them down first.
193
+ if (ShiftCount > 0 )
194
+ NewExtracted = Builder.CreateLShr (
195
+ NewExtracted, ShiftCount * OldVecTy->getScalarSizeInBits ());
196
+
197
+ Value *NewValue = Builder.CreateTrunc (NewExtracted, I->getType ());
198
+ I->replaceAllUsesWith (NewValue);
199
+ I->eraseFromParent ();
200
+ Changed = true ;
194
201
}
195
202
196
203
for (auto *I : InstsToErase)
197
- I->eraseFromParent ();
204
+ RecursivelyDeleteTriviallyDeadInstructions (I);
205
+
206
+ // Check if there are any residual unsupported vector types.
207
+ for (auto &VH : MaybeDeletedInsts) {
208
+ // Some vector-valued instructions were replaced with undef values, so if
209
+ // that's what we got, it's still a dead instruction.
210
+ if (VH.pointsToAliveValue () && !isa<UndefValue>(VH)) {
211
+ auto *VT = dyn_cast<VectorType>(VH->getType ());
212
+ report_fatal_error (Twine (" Unsupported vector type with " ) +
213
+ Twine (VT->getElementCount ().getFixedValue ()) +
214
+ Twine (" elements" ),
215
+ false );
216
+ }
217
+ }
198
218
199
219
return Changed ? PreservedAnalyses::none () : PreservedAnalyses::all ();
200
220
}
0 commit comments