@@ -721,8 +721,7 @@ struct VectorizationState {
721
721
// / Example:
722
722
// / * 'replaced': induction variable of a loop to be vectorized.
723
723
// / * 'replacement': new induction variable in the new vector loop.
724
- void registerValueScalarReplacement (BlockArgument replaced,
725
- BlockArgument replacement);
724
+ void registerValueScalarReplacement (Value replaced, Value replacement);
726
725
727
726
// / Registers the scalar replacement of a scalar result returned from a
728
727
// / reduction loop. 'replacement' must be scalar.
@@ -854,8 +853,8 @@ void VectorizationState::registerValueVectorReplacementImpl(Value replaced,
854
853
// / Example:
855
854
// / * 'replaced': induction variable of a loop to be vectorized.
856
855
// / * 'replacement': new induction variable in the new vector loop.
857
- void VectorizationState::registerValueScalarReplacement (
858
- BlockArgument replaced, BlockArgument replacement) {
856
+ void VectorizationState::registerValueScalarReplacement (Value replaced,
857
+ Value replacement) {
859
858
registerValueScalarReplacementImpl (replaced, replacement);
860
859
}
861
860
@@ -978,6 +977,28 @@ static arith::ConstantOp vectorizeConstant(arith::ConstantOp constOp,
978
977
return newConstOp;
979
978
}
980
979
980
+ // / We have no need to vectorize affine.apply. However, we still need to
981
+ // / generate it and replace the operands with values in valueScalarReplacement.
982
+ static Operation *vectorizeAffineApplyOp (AffineApplyOp applyOp,
983
+ VectorizationState &state) {
984
+ SmallVector<Value, 8 > updatedOperands;
985
+ for (Value operand : applyOp.getOperands ()) {
986
+ Value updatedOperand = operand;
987
+ if (state.valueScalarReplacement .contains (operand)) {
988
+ updatedOperand = state.valueScalarReplacement .lookupOrDefault (operand);
989
+ }
990
+ updatedOperands.push_back (updatedOperand);
991
+ }
992
+
993
+ auto newApplyOp = state.builder .create <AffineApplyOp>(
994
+ applyOp.getLoc (), applyOp.getAffineMap (), updatedOperands);
995
+
996
+ // Register the new affine.apply result.
997
+ state.registerValueScalarReplacement (applyOp.getResult (),
998
+ newApplyOp.getResult ());
999
+ return newApplyOp;
1000
+ }
1001
+
981
1002
// / Creates a constant vector filled with the neutral elements of the given
982
1003
// / reduction. The scalar type of vector elements will be taken from
983
1004
// / `oldOperand`.
@@ -1493,6 +1514,8 @@ static Operation *vectorizeOneOperation(Operation *op,
1493
1514
return vectorizeAffineYieldOp (yieldOp, state);
1494
1515
if (auto constant = dyn_cast<arith::ConstantOp>(op))
1495
1516
return vectorizeConstant (constant, state);
1517
+ if (auto applyOp = dyn_cast<AffineApplyOp>(op))
1518
+ return vectorizeAffineApplyOp (applyOp, state);
1496
1519
1497
1520
// Other ops with regions are not supported.
1498
1521
if (op->getNumRegions () != 0 )
0 commit comments