@@ -246,12 +246,12 @@ struct VPCallback {
246
246
// / VPTransformState holds information passed down when "executing" a VPlan,
247
247
// / needed for generating the output IR.
248
248
struct VPTransformState {
249
- VPTransformState (ElementCount VF, unsigned UF, LoopInfo *LI,
249
+ VPTransformState (ElementCount VF, unsigned UF, Loop *OrigLoop, LoopInfo *LI,
250
250
DominatorTree *DT, IRBuilder<> &Builder,
251
251
VectorizerValueMap &ValueMap, InnerLoopVectorizer *ILV,
252
252
VPCallback &Callback)
253
- : VF(VF), UF(UF), Instance(), LI(LI ), DT(DT ), Builder(Builder ),
254
- ValueMap (ValueMap), ILV(ILV), Callback(Callback) {}
253
+ : VF(VF), UF(UF), Instance(), OrigLoop(OrigLoop ), LI(LI ), DT(DT ),
254
+ Builder (Builder), ValueMap(ValueMap), ILV(ILV), Callback(Callback) {}
255
255
256
256
// / The chosen Vectorization and Unroll Factors of the loop being vectorized.
257
257
ElementCount VF;
@@ -269,6 +269,9 @@ struct VPTransformState {
269
269
typedef SmallVector<Value *, 2 > PerPartValuesTy;
270
270
271
271
DenseMap<VPValue *, PerPartValuesTy> PerPartOutput;
272
+
273
+ using ScalarsPerPartValuesTy = SmallVector<SmallVector<Value *, 4 >, 2 >;
274
+ DenseMap<VPValue *, ScalarsPerPartValuesTy> PerPartScalars;
272
275
} Data;
273
276
274
277
// / Get the generated Value for a given VPValue and a given Part. Note that
@@ -285,24 +288,21 @@ struct VPTransformState {
285
288
}
286
289
287
290
// / Get the generated Value for a given VPValue and given Part and Lane.
288
- Value *get (VPValue *Def, const VPIteration &Instance) {
289
- // If the Def is managed directly by VPTransformState, extract the lane from
290
- // the relevant part. Note that currently only VPInstructions and external
291
- // defs are managed by VPTransformState. Other Defs are still created by ILV
292
- // and managed in its ValueMap. For those this method currently just
293
- // delegates the call to ILV below.
294
- if (Data.PerPartOutput .count (Def)) {
295
- auto *VecPart = Data.PerPartOutput [Def][Instance.Part ];
296
- if (!VecPart->getType ()->isVectorTy ()) {
297
- assert (Instance.Lane == 0 && " cannot get lane > 0 for scalar" );
298
- return VecPart;
299
- }
300
- // TODO: Cache created scalar values.
301
- return Builder.CreateExtractElement (VecPart,
302
- Builder.getInt32 (Instance.Lane ));
303
- }
291
+ Value *get (VPValue *Def, const VPIteration &Instance);
304
292
305
- return Callback.getOrCreateScalarValue (VPValue2Value[Def], Instance);
293
+ bool hasVectorValue (VPValue *Def, unsigned Part) {
294
+ auto I = Data.PerPartOutput .find (Def);
295
+ return I != Data.PerPartOutput .end () && Part < I->second .size () &&
296
+ I->second [Part];
297
+ }
298
+
299
+ bool hasScalarValue (VPValue *Def, VPIteration Instance) {
300
+ auto I = Data.PerPartScalars .find (Def);
301
+ if (I == Data.PerPartScalars .end ())
302
+ return false ;
303
+ return Instance.Part < I->second .size () &&
304
+ Instance.Lane < I->second [Instance.Part ].size () &&
305
+ I->second [Instance.Part ][Instance.Lane ];
306
306
}
307
307
308
308
// / Set the generated Value for a given VPValue and a given Part.
@@ -315,6 +315,17 @@ struct VPTransformState {
315
315
}
316
316
void set (VPValue *Def, Value *IRDef, Value *V, unsigned Part);
317
317
318
+ void set (VPValue *Def, Value *V, const VPIteration &Instance) {
319
+ auto Iter = Data.PerPartScalars .insert ({Def, {}});
320
+ auto &PerPartVec = Iter.first ->second ;
321
+ while (PerPartVec.size () <= Instance.Part )
322
+ PerPartVec.emplace_back ();
323
+ auto &Scalars = PerPartVec[Instance.Part ];
324
+ while (Scalars.size () <= Instance.Lane )
325
+ Scalars.push_back (nullptr );
326
+ Scalars[Instance.Lane ] = V;
327
+ }
328
+
318
329
// / Hold state information used when constructing the CFG of the output IR,
319
330
// / traversing the VPBasicBlocks and generating corresponding IR BasicBlocks.
320
331
struct CFGState {
@@ -340,6 +351,9 @@ struct VPTransformState {
340
351
CFGState () = default ;
341
352
} CFG;
342
353
354
+ // / Hold a pointer to the original loop.
355
+ Loop *OrigLoop;
356
+
343
357
// / Hold a pointer to LoopInfo to register new basic blocks in the loop.
344
358
LoopInfo *LI;
345
359
0 commit comments