@@ -2231,7 +2231,7 @@ void VPBlendRecipe::print(raw_ostream &O, const Twine &Indent,
2231
2231
void VPReductionRecipe::execute (VPTransformState &State) {
2232
2232
assert (!State.Lane && " Reduction being replicated." );
2233
2233
Value *PrevInChain = State.get (getChainOp (), /* IsScalar*/ true );
2234
- RecurKind Kind = RdxDesc. getRecurrenceKind ();
2234
+ RecurKind Kind = getRecurrenceKind ();
2235
2235
assert (!RecurrenceDescriptor::isAnyOfRecurrenceKind (Kind) &&
2236
2236
" In-loop AnyOf reductions aren't currently supported" );
2237
2237
// Propagate the fast-math flags carried by the underlying instruction.
@@ -2244,8 +2244,7 @@ void VPReductionRecipe::execute(VPTransformState &State) {
2244
2244
VectorType *VecTy = dyn_cast<VectorType>(NewVecOp->getType ());
2245
2245
Type *ElementTy = VecTy ? VecTy->getElementType () : NewVecOp->getType ();
2246
2246
2247
- Value *Start =
2248
- getRecurrenceIdentity (Kind, ElementTy, RdxDesc.getFastMathFlags ());
2247
+ Value *Start = getRecurrenceIdentity (Kind, ElementTy, getFastMathFlags ());
2249
2248
if (State.VF .isVector ())
2250
2249
Start = State.Builder .CreateVectorSplat (VecTy->getElementCount (), Start);
2251
2250
@@ -2260,18 +2259,19 @@ void VPReductionRecipe::execute(VPTransformState &State) {
2260
2259
createOrderedReduction (State.Builder , Kind, NewVecOp, PrevInChain);
2261
2260
else
2262
2261
NewRed = State.Builder .CreateBinOp (
2263
- (Instruction::BinaryOps)RdxDesc.getOpcode (), PrevInChain, NewVecOp);
2262
+ (Instruction::BinaryOps)RecurrenceDescriptor::getOpcode (Kind),
2263
+ PrevInChain, NewVecOp);
2264
2264
PrevInChain = NewRed;
2265
2265
NextInChain = NewRed;
2266
2266
} else {
2267
2267
PrevInChain = State.get (getChainOp (), /* IsScalar*/ true );
2268
2268
NewRed = createSimpleReduction (State.Builder , NewVecOp, Kind);
2269
2269
if (RecurrenceDescriptor::isMinMaxRecurrenceKind (Kind))
2270
- NextInChain = createMinMaxOp (State.Builder , RdxDesc.getRecurrenceKind (),
2271
- NewRed, PrevInChain);
2270
+ NextInChain = createMinMaxOp (State.Builder , Kind, NewRed, PrevInChain);
2272
2271
else
2273
2272
NextInChain = State.Builder .CreateBinOp (
2274
- (Instruction::BinaryOps)RdxDesc.getOpcode (), NewRed, PrevInChain);
2273
+ (Instruction::BinaryOps)RecurrenceDescriptor::getOpcode (Kind), NewRed,
2274
+ PrevInChain);
2275
2275
}
2276
2276
State.set (this , NextInChain, /* IsScalar*/ true );
2277
2277
}
@@ -2282,10 +2282,9 @@ void VPReductionEVLRecipe::execute(VPTransformState &State) {
2282
2282
auto &Builder = State.Builder ;
2283
2283
// Propagate the fast-math flags carried by the underlying instruction.
2284
2284
IRBuilderBase::FastMathFlagGuard FMFGuard (Builder);
2285
- const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor ();
2286
2285
Builder.setFastMathFlags (getFastMathFlags ());
2287
2286
2288
- RecurKind Kind = RdxDesc. getRecurrenceKind ();
2287
+ RecurKind Kind = getRecurrenceKind ();
2289
2288
Value *Prev = State.get (getChainOp (), /* IsScalar*/ true );
2290
2289
Value *VecOp = State.get (getVecOp ());
2291
2290
Value *EVL = State.get (getEVL (), VPLane (0 ));
@@ -2308,18 +2307,19 @@ void VPReductionEVLRecipe::execute(VPTransformState &State) {
2308
2307
if (RecurrenceDescriptor::isMinMaxRecurrenceKind (Kind))
2309
2308
NewRed = createMinMaxOp (Builder, Kind, NewRed, Prev);
2310
2309
else
2311
- NewRed = Builder.CreateBinOp ((Instruction::BinaryOps)RdxDesc.getOpcode (),
2312
- NewRed, Prev);
2310
+ NewRed = Builder.CreateBinOp (
2311
+ (Instruction::BinaryOps)RecurrenceDescriptor::getOpcode (Kind), NewRed,
2312
+ Prev);
2313
2313
}
2314
2314
State.set (this , NewRed, /* IsScalar*/ true );
2315
2315
}
2316
2316
2317
2317
InstructionCost VPReductionRecipe::computeCost (ElementCount VF,
2318
2318
VPCostContext &Ctx) const {
2319
- RecurKind RdxKind = RdxDesc. getRecurrenceKind ();
2319
+ RecurKind RdxKind = getRecurrenceKind ();
2320
2320
Type *ElementTy = Ctx.Types .inferScalarType (this );
2321
2321
auto *VectorTy = cast<VectorType>(toVectorTy (ElementTy, VF));
2322
- unsigned Opcode = RdxDesc. getOpcode ();
2322
+ unsigned Opcode = RecurrenceDescriptor:: getOpcode (RdxKind );
2323
2323
FastMathFlags FMFs = getFastMathFlags ();
2324
2324
2325
2325
// TODO: Support any-of and in-loop reductions.
@@ -2332,9 +2332,6 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
2332
2332
ForceTargetInstructionCost.getNumOccurrences () > 0 ) &&
2333
2333
" In-loop reduction not implemented in VPlan-based cost model currently." );
2334
2334
2335
- assert (ElementTy->getTypeID () == RdxDesc.getRecurrenceType ()->getTypeID () &&
2336
- " Inferred type and recurrence type mismatch." );
2337
-
2338
2335
// Cost = Reduction cost + BinOp cost
2339
2336
InstructionCost Cost =
2340
2337
Ctx.TTI .getArithmeticInstrCost (Opcode, ElementTy, Ctx.CostKind );
@@ -2357,28 +2354,30 @@ void VPReductionRecipe::print(raw_ostream &O, const Twine &Indent,
2357
2354
getChainOp ()->printAsOperand (O, SlotTracker);
2358
2355
O << " +" ;
2359
2356
printFlags (O);
2360
- O << " reduce." << Instruction::getOpcodeName (RdxDesc.getOpcode ()) << " (" ;
2357
+ O << " reduce."
2358
+ << Instruction::getOpcodeName (
2359
+ RecurrenceDescriptor::getOpcode (getRecurrenceKind ()))
2360
+ << " (" ;
2361
2361
getVecOp ()->printAsOperand (O, SlotTracker);
2362
2362
if (isConditional ()) {
2363
2363
O << " , " ;
2364
2364
getCondOp ()->printAsOperand (O, SlotTracker);
2365
2365
}
2366
2366
O << " )" ;
2367
- if (RdxDesc.IntermediateStore )
2368
- O << " (with final reduction value stored in invariant address sank "
2369
- " outside of loop)" ;
2370
2367
}
2371
2368
2372
2369
void VPReductionEVLRecipe::print (raw_ostream &O, const Twine &Indent,
2373
2370
VPSlotTracker &SlotTracker) const {
2374
- const RecurrenceDescriptor &RdxDesc = getRecurrenceDescriptor ();
2375
2371
O << Indent << " REDUCE " ;
2376
2372
printAsOperand (O, SlotTracker);
2377
2373
O << " = " ;
2378
2374
getChainOp ()->printAsOperand (O, SlotTracker);
2379
2375
O << " +" ;
2380
2376
printFlags (O);
2381
- O << " vp.reduce." << Instruction::getOpcodeName (RdxDesc.getOpcode ()) << " (" ;
2377
+ O << " vp.reduce."
2378
+ << Instruction::getOpcodeName (
2379
+ RecurrenceDescriptor::getOpcode (getRecurrenceKind ()))
2380
+ << " (" ;
2382
2381
getVecOp ()->printAsOperand (O, SlotTracker);
2383
2382
O << " , " ;
2384
2383
getEVL ()->printAsOperand (O, SlotTracker);
@@ -2387,9 +2386,6 @@ void VPReductionEVLRecipe::print(raw_ostream &O, const Twine &Indent,
2387
2386
getCondOp ()->printAsOperand (O, SlotTracker);
2388
2387
}
2389
2388
O << " )" ;
2390
- if (RdxDesc.IntermediateStore )
2391
- O << " (with final reduction value stored in invariant address sank "
2392
- " outside of loop)" ;
2393
2389
}
2394
2390
#endif
2395
2391
0 commit comments