@@ -195,6 +195,10 @@ struct VectorizationState {
195
195
// / Returns the canonical vector shape used to vectorize the iteration space.
196
196
ArrayRef<int64_t > getCanonicalVecShape () const { return canonicalVecShape; }
197
197
198
+ // / Returns the vector dimensions that are scalable in the canonical vector
199
+ // / shape.
200
+ ArrayRef<bool > getScalableVecDims () const { return scalableVecDims; }
201
+
198
202
// / Returns a vector type of the provided `elementType` with the canonical
199
203
// / vector shape and the corresponding fixed/scalable dimensions bit. If
200
204
// / `dimPermutation` is provided, the canonical vector dimensions are permuted
@@ -694,23 +698,24 @@ static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
694
698
return VectorizationResult{VectorizationStatus::Failure, nullptr };
695
699
auto loc = indexOp.getLoc ();
696
700
// Compute the static loop sizes of the index op.
697
- auto targetShape = state.getCanonicalVecShape ();
701
+ ArrayRef<int64_t > targetShape = state.getCanonicalVecShape ();
702
+ auto dim = indexOp.getDim ();
698
703
// Compute a one-dimensional index vector for the index op dimension.
699
- auto constantSeq =
700
- llvm::to_vector (llvm::seq< int64_t >( 0 , targetShape[indexOp. getDim ()]));
701
- auto indexSteps = rewriter. create <arith::ConstantOp>(
702
- loc, rewriter.getIndexVectorAttr (constantSeq) );
704
+ auto indexVectorType =
705
+ VectorType::get ({ targetShape[dim]}, rewriter. getIndexType (),
706
+ state. getScalableVecDims ()[dim]);
707
+ auto indexSteps = rewriter.create <vector::StepOp>(loc, indexVectorType );
703
708
// Return the one-dimensional index vector if it lives in the trailing
704
709
// dimension of the iteration space since the vectorization algorithm in this
705
710
// case can handle the broadcast.
706
- if (indexOp. getDim () == targetShape.size () - 1 )
711
+ if (dim == targetShape.size () - 1 )
707
712
return VectorizationResult{VectorizationStatus::NewOp, indexSteps};
708
713
// Otherwise permute the targetShape to move the index dimension last,
709
714
// broadcast the one-dimensional index vector to the permuted shape, and
710
715
// finally transpose the broadcasted index vector to undo the permutation.
711
716
auto permPattern =
712
717
llvm::to_vector (llvm::seq<unsigned >(0 , targetShape.size ()));
713
- std::swap (permPattern[indexOp. getDim () ], permPattern.back ());
718
+ std::swap (permPattern[dim ], permPattern.back ());
714
719
auto permMap =
715
720
AffineMap::getPermutationMap (permPattern, linalgOp.getContext ());
716
721
@@ -719,7 +724,7 @@ static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
719
724
indexSteps);
720
725
SmallVector<int64_t > transposition =
721
726
llvm::to_vector<16 >(llvm::seq<int64_t >(0 , linalgOp.getNumLoops ()));
722
- std::swap (transposition.back (), transposition[indexOp. getDim () ]);
727
+ std::swap (transposition.back (), transposition[dim ]);
723
728
auto transposeOp =
724
729
rewriter.create <vector::TransposeOp>(loc, broadCastOp, transposition);
725
730
return VectorizationResult{VectorizationStatus::NewOp, transposeOp};
0 commit comments