Skip to content

Commit 8011a23

Browse files
authored
[mlir][linalg] Support scalable vectorization of linalg.index operations (llvm#96778)
The vectorization of linalg.index operations doesn't support scalable vectors when computing the index vector. This patch fixes this by using the vector.step operation to compute the index vector. Depends on llvm#96776
1 parent 1e7d6d3 commit 8011a23

File tree

3 files changed

+62
-10
lines changed

3 files changed

+62
-10
lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,10 @@ struct VectorizationState {
195195
/// Returns the canonical vector shape used to vectorize the iteration space.
196196
ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }
197197

198+
/// Returns the vector dimensions that are scalable in the canonical vector
199+
/// shape.
200+
ArrayRef<bool> getScalableVecDims() const { return scalableVecDims; }
201+
198202
/// Returns a vector type of the provided `elementType` with the canonical
199203
/// vector shape and the corresponding fixed/scalable dimensions bit. If
200204
/// `dimPermutation` is provided, the canonical vector dimensions are permuted
@@ -694,23 +698,24 @@ static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
694698
return VectorizationResult{VectorizationStatus::Failure, nullptr};
695699
auto loc = indexOp.getLoc();
696700
// Compute the static loop sizes of the index op.
697-
auto targetShape = state.getCanonicalVecShape();
701+
ArrayRef<int64_t> targetShape = state.getCanonicalVecShape();
702+
auto dim = indexOp.getDim();
698703
// Compute a one-dimensional index vector for the index op dimension.
699-
auto constantSeq =
700-
llvm::to_vector(llvm::seq<int64_t>(0, targetShape[indexOp.getDim()]));
701-
auto indexSteps = rewriter.create<arith::ConstantOp>(
702-
loc, rewriter.getIndexVectorAttr(constantSeq));
704+
auto indexVectorType =
705+
VectorType::get({targetShape[dim]}, rewriter.getIndexType(),
706+
state.getScalableVecDims()[dim]);
707+
auto indexSteps = rewriter.create<vector::StepOp>(loc, indexVectorType);
703708
// Return the one-dimensional index vector if it lives in the trailing
704709
// dimension of the iteration space since the vectorization algorithm in this
705710
// case can handle the broadcast.
706-
if (indexOp.getDim() == targetShape.size() - 1)
711+
if (dim == targetShape.size() - 1)
707712
return VectorizationResult{VectorizationStatus::NewOp, indexSteps};
708713
// Otherwise permute the targetShape to move the index dimension last,
709714
// broadcast the one-dimensional index vector to the permuted shape, and
710715
// finally transpose the broadcasted index vector to undo the permutation.
711716
auto permPattern =
712717
llvm::to_vector(llvm::seq<unsigned>(0, targetShape.size()));
713-
std::swap(permPattern[indexOp.getDim()], permPattern.back());
718+
std::swap(permPattern[dim], permPattern.back());
714719
auto permMap =
715720
AffineMap::getPermutationMap(permPattern, linalgOp.getContext());
716721

@@ -719,7 +724,7 @@ static VectorizationResult vectorizeLinalgIndex(RewriterBase &rewriter,
719724
indexSteps);
720725
SmallVector<int64_t> transposition =
721726
llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops()));
722-
std::swap(transposition.back(), transposition[indexOp.getDim()]);
727+
std::swap(transposition.back(), transposition[dim]);
723728
auto transposeOp =
724729
rewriter.create<vector::TransposeOp>(loc, broadCastOp, transposition);
725730
return VectorizationResult{VectorizationStatus::NewOp, transposeOp};

mlir/test/Dialect/Linalg/vectorization-scalable.mlir

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,3 +142,50 @@ module attributes {transform.with_named_sequence} {
142142
}
143143
}
144144

145+
// -----
146+
147+
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
148+
func.func @vectorize_linalg_index(%arg0: tensor<3x3x?xf32>, %arg1: tensor<1x1x?xf32>) -> tensor<1x1x?xf32> {
149+
%0 = linalg.generic {
150+
indexing_maps = [#map],
151+
iterator_types = ["parallel", "parallel", "parallel"]
152+
} outs(%arg1 : tensor<1x1x?xf32>) {
153+
^bb0(%in: f32):
154+
%1 = linalg.index 0 : index
155+
%2 = linalg.index 1 : index
156+
%3 = linalg.index 2 : index
157+
%4 = tensor.extract %arg0[%1, %2, %3] : tensor<3x3x?xf32>
158+
linalg.yield %4 : f32
159+
} -> tensor<1x1x?xf32>
160+
return %0 : tensor<1x1x?xf32>
161+
}
162+
163+
// CHECK-LABEL: @vectorize_linalg_index
164+
// CHECK-SAME: %[[SRC:.*]]: tensor<3x3x?xf32>, %[[DST:.*]]: tensor<1x1x?xf32>
165+
// CHECK-DAG: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<1x1x[4]xf32>
166+
// CHECK-DAG: %[[MASK:.*]] = arith.constant dense<true> : vector<1x1x[4]xi1>
167+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
168+
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
169+
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
170+
// CHECK: %[[DST_DIM2:.*]] = tensor.dim %[[DST]], %[[C2]] : tensor<1x1x?xf32>
171+
// CHECK: %[[DST_MASK:.*]] = vector.create_mask %[[C1]], %[[C1]], %[[DST_DIM2]] : vector<1x1x[4]xi1>
172+
// CHECK: %[[INDEX_VEC:.*]] = vector.step : vector<[4]xindex>
173+
// CHECK: %[[INDEX_VEC_BCAST:.*]] = vector.broadcast %[[INDEX_VEC]] : vector<[4]xindex> to vector<1x1x[4]xindex>
174+
// CHECK: %[[GATHER:.*]] = vector.mask %[[DST_MASK]] { vector.gather %[[SRC]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {{\[}}%[[INDEX_VEC_BCAST]]], %[[MASK]], %[[PASSTHRU]] : tensor<3x3x?xf32>, vector<1x1x[4]xindex>, vector<1x1x[4]xi1>, vector<1x1x[4]xf32> into vector<1x1x[4]xf32> } : vector<1x1x[4]xi1> -> vector<1x1x[4]xf32>
175+
// CHECK: %[[OUT:.*]] = vector.mask %[[DST_MASK]] { vector.transfer_write %[[GATHER]], %[[DST]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x[4]xf32>, tensor<1x1x?xf32> } : vector<1x1x[4]xi1> -> tensor<1x1x?xf32>
176+
// CHECK: return %[[OUT]] : tensor<1x1x?xf32>
177+
178+
module attributes {transform.with_named_sequence} {
179+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
180+
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
181+
transform.structured.vectorize %0 vector_sizes [1, 1, [4]] {vectorize_nd_extract} : !transform.any_op
182+
183+
%func = transform.structured.match ops{["func.func"]} in %arg1
184+
: (!transform.any_op) -> !transform.any_op
185+
transform.apply_patterns to %func {
186+
transform.apply_patterns.canonicalization
187+
transform.apply_patterns.linalg.tiling_canonicalization
188+
} : !transform.any_op
189+
transform.yield
190+
}
191+
}

mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguo
6363
// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32
6464
// CHECK: %[[VAL_10:.*]] = vector.create_mask %[[VAL_5]], %[[VAL_7]] : vector<1x4xi1>
6565
// CHECK: %[[VAL_11:.*]] = vector.mask %[[VAL_10]] { vector.transfer_read %[[VAL_2]]{{\[}}%[[VAL_8]], %[[VAL_8]]], %[[VAL_9]] {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
66-
// CHECK: %[[VAL_12:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
66+
// CHECK: %[[VAL_12:.*]] = vector.step : vector<4xindex>
6767
// CHECK: %[[VAL_13:.*]] = vector.broadcast %[[VAL_1]] : index to vector<4xindex>
6868
// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_12]], %[[VAL_13]] : vector<4xindex>
6969
// CHECK-DAG: %[[VAL_15:.*]] = arith.constant dense<true> : vector<1x4xi1>
@@ -160,7 +160,7 @@ func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_gather(%
160160
// CHECK: %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32
161161
// CHECK: %[[VAL_10:.*]] = vector.create_mask %[[VAL_5]], %[[VAL_7]] : vector<1x4xi1>
162162
// CHECK: %[[VAL_11:.*]] = vector.mask %[[VAL_10]] { vector.transfer_read %[[VAL_2]]{{\[}}%[[VAL_8]], %[[VAL_8]]], %[[VAL_9]] {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
163-
// CHECK: %[[VAL_12:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
163+
// CHECK: %[[VAL_12:.*]] = vector.step : vector<4xindex>
164164
// CHECK: %[[VAL_13:.*]] = vector.broadcast %[[VAL_1]] : index to vector<4xindex>
165165
// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_12]], %[[VAL_13]] : vector<4xindex>
166166
// CHECK: %[[VAL_15:.*]] = arith.constant dense<true> : vector<1x4xi1>

0 commit comments

Comments (0)