@@ -711,18 +711,16 @@ struct VectorizationState {
711
711
BlockArgument replacement);
712
712
713
713
// / Registers the scalar replacement of a scalar value. 'replacement' must be
714
- // / scalar. Both values must be block arguments. Operation results should be
715
- // / replaced using the 'registerOp*' utilitites.
714
+ // / scalar.
716
715
// /
717
716
// / This utility is used to register the replacement of block arguments
718
- // / that are within the loop to be vectorized and will continue being scalar
719
- // / within the vector loop.
717
+ // / or affine.apply results that are within the loop be vectorized and will
718
+ // / continue being scalar within the vector loop.
720
719
// /
721
720
// / Example:
722
721
// / * 'replaced': induction variable of a loop to be vectorized.
723
722
// / * 'replacement': new induction variable in the new vector loop.
724
- void registerValueScalarReplacement (BlockArgument replaced,
725
- BlockArgument replacement);
723
+ void registerValueScalarReplacement (Value replaced, Value replacement);
726
724
727
725
// / Registers the scalar replacement of a scalar result returned from a
728
726
// / reduction loop. 'replacement' must be scalar.
@@ -772,7 +770,6 @@ struct VectorizationState {
772
770
// / Internal implementation to map input scalar values to new vector or scalar
773
771
// / values.
774
772
void registerValueVectorReplacementImpl (Value replaced, Value replacement);
775
- void registerValueScalarReplacementImpl (Value replaced, Value replacement);
776
773
};
777
774
778
775
} // namespace
@@ -844,19 +841,22 @@ void VectorizationState::registerValueVectorReplacementImpl(Value replaced,
844
841
}
845
842
846
843
// / Registers the scalar replacement of a scalar value. 'replacement' must be
847
- // / scalar. Both values must be block arguments. Operation results should be
848
- // / replaced using the 'registerOp*' utilitites.
844
+ // / scalar.
849
845
// /
850
846
// / This utility is used to register the replacement of block arguments
851
- // / that are within the loop to be vectorized and will continue being scalar
852
- // / within the vector loop.
847
+ // / or affine.apply results that are within the loop be vectorized and will
848
+ // / continue being scalar within the vector loop.
853
849
// /
854
850
// / Example:
855
851
// / * 'replaced': induction variable of a loop to be vectorized.
856
852
// / * 'replacement': new induction variable in the new vector loop.
857
- void VectorizationState::registerValueScalarReplacement (
858
- BlockArgument replaced, BlockArgument replacement) {
859
- registerValueScalarReplacementImpl (replaced, replacement);
853
+ void VectorizationState::registerValueScalarReplacement (Value replaced,
854
+ Value replacement) {
855
+ assert (!valueScalarReplacement.contains (replaced) &&
856
+ " Scalar value replacement already registered" );
857
+ assert (!isa<VectorType>(replacement.getType ()) &&
858
+ " Expected scalar type in scalar replacement" );
859
+ valueScalarReplacement.map (replaced, replacement);
860
860
}
861
861
862
862
// / Registers the scalar replacement of a scalar result returned from a
@@ -879,15 +879,6 @@ void VectorizationState::registerLoopResultScalarReplacement(
879
879
loopResultScalarReplacement[replaced] = replacement;
880
880
}
881
881
882
- void VectorizationState::registerValueScalarReplacementImpl (Value replaced,
883
- Value replacement) {
884
- assert (!valueScalarReplacement.contains (replaced) &&
885
- " Scalar value replacement already registered" );
886
- assert (!isa<VectorType>(replacement.getType ()) &&
887
- " Expected scalar type in scalar replacement" );
888
- valueScalarReplacement.map (replaced, replacement);
889
- }
890
-
891
882
// / Returns in 'replacedVals' the scalar replacement for values in 'inputVals'.
892
883
void VectorizationState::getScalarValueReplacementsFor (
893
884
ValueRange inputVals, SmallVectorImpl<Value> &replacedVals) {
@@ -978,6 +969,33 @@ static arith::ConstantOp vectorizeConstant(arith::ConstantOp constOp,
978
969
return newConstOp;
979
970
}
980
971
972
+ // / We have no need to vectorize affine.apply. However, we still need to
973
+ // / generate it and replace the operands with values in valueScalarReplacement.
974
+ static Operation *vectorizeAffineApplyOp (AffineApplyOp applyOp,
975
+ VectorizationState &state) {
976
+ SmallVector<Value, 8 > updatedOperands;
977
+ for (Value operand : applyOp.getOperands ()) {
978
+ if (state.valueVectorReplacement .contains (operand)) {
979
+ LLVM_DEBUG (
980
+ dbgs () << " \n [early-vect]+++++ affine.apply on vector operand\n " );
981
+ return nullptr ;
982
+ } else {
983
+ Value updatedOperand = state.valueScalarReplacement .lookupOrNull (operand);
984
+ if (!updatedOperand)
985
+ updatedOperand = operand;
986
+ updatedOperands.push_back (updatedOperand);
987
+ }
988
+ }
989
+
990
+ auto newApplyOp = state.builder .create <AffineApplyOp>(
991
+ applyOp.getLoc (), applyOp.getAffineMap (), updatedOperands);
992
+
993
+ // Register the new affine.apply result.
994
+ state.registerValueScalarReplacement (applyOp.getResult (),
995
+ newApplyOp.getResult ());
996
+ return newApplyOp;
997
+ }
998
+
981
999
// / Creates a constant vector filled with the neutral elements of the given
982
1000
// / reduction. The scalar type of vector elements will be taken from
983
1001
// / `oldOperand`.
@@ -1493,6 +1511,8 @@ static Operation *vectorizeOneOperation(Operation *op,
1493
1511
return vectorizeAffineYieldOp (yieldOp, state);
1494
1512
if (auto constant = dyn_cast<arith::ConstantOp>(op))
1495
1513
return vectorizeConstant (constant, state);
1514
+ if (auto applyOp = dyn_cast<AffineApplyOp>(op))
1515
+ return vectorizeAffineApplyOp (applyOp, state);
1496
1516
1497
1517
// Other ops with regions are not supported.
1498
1518
if (op->getNumRegions () != 0 )
0 commit comments