Skip to content

Commit a6e72f9

Browse files
authored
[MLIR][Vector] Add Lowering for vector.step (#113655)
Currently, the lowering for vector.step lives under a folder. This is not ideal if we want to do transformation on it and defer the materizaliztion of the constants much later. This commits adds a rewrite pattern that could be used by using `transform.structured.vectorize_children_and_apply_patterns` transform dialect operation. Moreover, the rewriter of vector.step is also now used in -convert-vector-to-llvm pass where it handles scalable and non-scalable types as LLVM expects it. As a consequence of removing the vector.step lowering as its folder, linalg vectorization will keep vector.step intact.
1 parent 10a1ea9 commit a6e72f9

File tree

11 files changed

+97
-48
lines changed

11 files changed

+97
-48
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2946,7 +2946,6 @@ def Vector_StepOp : Vector_Op<"step", [Pure]> {
29462946
%1 = vector.step : vector<[4]xindex> ; [0, 1, .., <vscale * 4 - 1>]
29472947
```
29482948
}];
2949-
let hasFolder = 1;
29502949
let results = (outs VectorOfRankAndType<[1], [Index]>:$result);
29512950
let assemblyFormat = "attr-dict `:` type($result)";
29522951
}

mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,13 @@ void populateVectorTransferPermutationMapLoweringPatterns(
235235
void populateVectorScanLoweringPatterns(RewritePatternSet &patterns,
236236
PatternBenefit benefit = 1);
237237

238+
/// Populate the pattern set with the following patterns:
239+
///
240+
/// [StepToArithConstantOp]
241+
/// Convert vector.step op into arith ops if not using scalable vectors
242+
void populateVectorStepLoweringPatterns(RewritePatternSet &patterns,
243+
PatternBenefit benefit = 1);
244+
238245
/// Populate the pattern set with the following patterns:
239246
///
240247
/// [FlattenGather]

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1865,12 +1865,17 @@ struct VectorFromElementsLowering
18651865
};
18661866

18671867
/// Conversion pattern for vector.step.
1868-
struct VectorStepOpLowering : public ConvertOpToLLVMPattern<vector::StepOp> {
1868+
struct VectorScalableStepOpLowering
1869+
: public ConvertOpToLLVMPattern<vector::StepOp> {
18691870
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
18701871

18711872
LogicalResult
18721873
matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor,
18731874
ConversionPatternRewriter &rewriter) const override {
1875+
auto resultType = cast<VectorType>(stepOp.getType());
1876+
if (!resultType.isScalable()) {
1877+
return failure();
1878+
}
18741879
Type llvmType = typeConverter->convertType(stepOp.getType());
18751880
rewriter.replaceOpWithNewOp<LLVM::StepVectorOp>(stepOp, llvmType);
18761881
return success();
@@ -1886,6 +1891,7 @@ void mlir::populateVectorToLLVMConversionPatterns(
18861891
MLIRContext *ctx = converter.getDialect()->getContext();
18871892
patterns.add<VectorFMAOpNDRewritePattern>(ctx);
18881893
populateVectorInsertExtractStridedSliceTransforms(patterns);
1894+
populateVectorStepLoweringPatterns(patterns);
18891895
patterns.add<VectorReductionOpConversion>(converter, reassociateFPReductions);
18901896
patterns.add<VectorCreateMaskOpRewritePattern>(ctx, force32BitVectorIndices);
18911897
patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
@@ -1903,7 +1909,7 @@ void mlir::populateVectorToLLVMConversionPatterns(
19031909
VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
19041910
MaskedReductionOpConversion, VectorInterleaveOpLowering,
19051911
VectorDeinterleaveOpLowering, VectorFromElementsLowering,
1906-
VectorStepOpLowering>(converter);
1912+
VectorScalableStepOpLowering>(converter);
19071913
// Transfer ops with rank > 1 are handled by VectorToSCF.
19081914
populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
19091915
}

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3488,6 +3488,7 @@ transform::VectorizeChildrenAndApplyPatternsOp::applyToOne(
34883488

34893489
if (getVectorizePadding())
34903490
linalg::populatePadOpVectorizationPatterns(patterns);
3491+
vector::populateVectorStepLoweringPatterns(patterns);
34913492

34923493
TrackingListener listener(state, *this);
34933494
GreedyRewriteConfig config;

mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "mlir/Dialect/SCF/IR/SCF.h"
2828
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
2929
#include "mlir/Dialect/Vector/IR/VectorOps.h"
30+
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
3031
#include "mlir/IR/Matchers.h"
3132

3233
using namespace mlir;
@@ -664,6 +665,7 @@ void mlir::populateSparseVectorizationPatterns(RewritePatternSet &patterns,
664665
bool enableVLAVectorization,
665666
bool enableSIMDIndex32) {
666667
assert(vectorLength > 0);
668+
vector::populateVectorStepLoweringPatterns(patterns);
667669
patterns.add<ForOpRewriter>(patterns.getContext(), vectorLength,
668670
enableVLAVectorization, enableSIMDIndex32);
669671
patterns.add<ReducChainRewriter<vector::InsertElementOp>,

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6423,20 +6423,6 @@ OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
64236423
return SplatElementsAttr::get(getType(), {constOperand});
64246424
}
64256425

6426-
//===----------------------------------------------------------------------===//
6427-
// StepOp
6428-
//===----------------------------------------------------------------------===//
6429-
6430-
OpFoldResult StepOp::fold(FoldAdaptor adaptor) {
6431-
auto resultType = cast<VectorType>(getType());
6432-
if (resultType.isScalable())
6433-
return nullptr;
6434-
SmallVector<APInt> indices;
6435-
for (unsigned i = 0; i < resultType.getNumElements(); i++)
6436-
indices.push_back(APInt(/*width=*/64, i));
6437-
return DenseElementsAttr::get(resultType, indices);
6438-
}
6439-
64406426
//===----------------------------------------------------------------------===//
64416427
// WarpExecuteOnLane0Op
64426428
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
99
LowerVectorMultiReduction.cpp
1010
LowerVectorScan.cpp
1111
LowerVectorShapeCast.cpp
12+
LowerVectorStep.cpp
1213
LowerVectorTransfer.cpp
1314
LowerVectorTranspose.cpp
1415
SubsetOpInterfaceImpl.cpp
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
//===- LowerVectorStep.cpp - Lower 'vector.step' operation ----------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements target-independent rewrites and utilities to lower the
10+
// 'vector.step' operation.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "mlir/Dialect/Arith/IR/Arith.h"
15+
#include "mlir/Dialect/Vector/IR/VectorOps.h"
16+
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
17+
#include "mlir/IR/PatternMatch.h"
18+
19+
#define DEBUG_TYPE "vector-step-lowering"
20+
21+
using namespace mlir;
22+
using namespace mlir::vector;
23+
24+
namespace {
25+
26+
struct StepToArithConstantOpRewrite final : OpRewritePattern<vector::StepOp> {
27+
using OpRewritePattern::OpRewritePattern;
28+
29+
LogicalResult matchAndRewrite(vector::StepOp stepOp,
30+
PatternRewriter &rewriter) const override {
31+
auto resultType = cast<VectorType>(stepOp.getType());
32+
if (resultType.isScalable()) {
33+
return failure();
34+
}
35+
int64_t elementCount = resultType.getNumElements();
36+
SmallVector<APInt> indices =
37+
llvm::map_to_vector(llvm::seq(elementCount),
38+
[](int64_t i) { return APInt(/*width=*/64, i); });
39+
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
40+
stepOp, DenseElementsAttr::get(resultType, indices));
41+
return success();
42+
}
43+
};
44+
} // namespace
45+
46+
void mlir::vector::populateVectorStepLoweringPatterns(
47+
RewritePatternSet &patterns, PatternBenefit benefit) {
48+
patterns.add<StepToArithConstantOpRewrite>(patterns.getContext(), benefit);
49+
}

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3448,3 +3448,13 @@ func.func @vector_step_scalable() -> vector<[4]xindex> {
34483448
%0 = vector.step : vector<[4]xindex>
34493449
return %0 : vector<[4]xindex>
34503450
}
3451+
3452+
// -----
3453+
3454+
// CHECK-LABEL: @vector_step
3455+
// CHECK: %[[CST:.+]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
3456+
// CHECK: return %[[CST]] : vector<4xindex>
3457+
func.func @vector_step() -> vector<4xindex> {
3458+
%0 = vector.step : vector<4xindex>
3459+
return %0 : vector<4xindex>
3460+
}

mlir/test/Dialect/Linalg/vectorization-scalable.mlir

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -144,43 +144,40 @@ module attributes {transform.with_named_sequence} {
144144

145145
// -----
146146

147-
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
148-
func.func @vectorize_linalg_index(%arg0: tensor<3x3x?xf32>, %arg1: tensor<1x1x?xf32>) -> tensor<1x1x?xf32> {
147+
#map = affine_map<(d0) -> (d0)>
148+
func.func @vectorize_linalg_index(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
149149
%0 = linalg.generic {
150150
indexing_maps = [#map],
151-
iterator_types = ["parallel", "parallel", "parallel"]
152-
} outs(%arg1 : tensor<1x1x?xf32>) {
151+
iterator_types = ["parallel"]
152+
} outs(%arg1 : tensor<?xf32>) {
153153
^bb0(%in: f32):
154154
%1 = linalg.index 0 : index
155-
%2 = linalg.index 1 : index
156-
%3 = linalg.index 2 : index
157-
%4 = tensor.extract %arg0[%1, %2, %3] : tensor<3x3x?xf32>
158-
linalg.yield %4 : f32
159-
} -> tensor<1x1x?xf32>
160-
return %0 : tensor<1x1x?xf32>
155+
%2 = tensor.extract %arg0[%1] : tensor<?xf32>
156+
linalg.yield %2 : f32
157+
} -> tensor<?xf32>
158+
return %0 : tensor<?xf32>
161159
}
162160

163161
// CHECK-LABEL: @vectorize_linalg_index
164-
// CHECK-SAME: %[[SRC:.*]]: tensor<3x3x?xf32>, %[[DST:.*]]: tensor<1x1x?xf32>
165-
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
166-
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
167-
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
168-
// CHECK: %[[DST_DIM2:.*]] = tensor.dim %[[DST]], %[[C2]] : tensor<1x1x?xf32>
169-
// CHECK: %[[MASK:.*]] = vector.create_mask %[[C1]], %[[C1]], %[[DST_DIM2]] : vector<1x1x[4]xi1>
170-
// CHECK: %[[INDEX_VEC:.*]] = vector.step : vector<[4]xindex>
171-
// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC]][%c0, %c0, %2], %cst {in_bounds = [true, true, true]} : tensor<3x3x?xf32>, vector<1x1x[4]xf32> } : vector<1x1x[4]xi1> -> vector<1x1x[4]xf32>
172-
// CHECK: %[[OUT:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[DST]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x[4]xf32>, tensor<1x1x?xf32> } : vector<1x1x[4]xi1> -> tensor<1x1x?xf32>
173-
// CHECK: return %[[OUT]] : tensor<1x1x?xf32>
162+
// CHECK-SAME: %[[SRC:.*]]: tensor<?xf32>, %[[DST:.*]]: tensor<?xf32>
163+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
164+
// CHECK: %[[DST_DIM0:.*]] = tensor.dim %[[DST]], %[[C0]] : tensor<?xf32>
165+
// CHECK: %[[MASK:.*]] = vector.create_mask %[[DST_DIM0]] : vector<[4]xi1>
166+
// CHECK-DAG: %[[STEP:.+]] = vector.step : vector<[4]xindex>
167+
// CHECK-DAG: %[[STEP_ELEMENT:.+]] = vector.extractelement %[[STEP]][%c0_i32 : i32] : vector<[4]xindex>
168+
169+
// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC]][%[[STEP_ELEMENT]]], %cst {in_bounds = [true]} : tensor<?xf32>, vector<[4]xf32> } : vector<[4]xi1> -> vector<[4]xf32>
170+
// CHECK: %[[OUT:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[DST]]{{\[}}%[[C0]]] {in_bounds = [true]} : vector<[4]xf32>, tensor<?xf32> } : vector<[4]xi1> -> tensor<?xf32>
171+
// CHECK: return %[[OUT]] : tensor<?xf32>
174172

175173
module attributes {transform.with_named_sequence} {
176174
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
177175
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
178-
transform.structured.vectorize %0 vector_sizes [1, 1, [4]] {vectorize_nd_extract} : !transform.any_op
176+
transform.structured.vectorize %0 vector_sizes [[4]] {vectorize_nd_extract} : !transform.any_op
179177

180178
%func = transform.structured.match ops{["func.func"]} in %arg1
181179
: (!transform.any_op) -> !transform.any_op
182180
transform.apply_patterns to %func {
183-
transform.apply_patterns.canonicalization
184181
transform.apply_patterns.linalg.tiling_canonicalization
185182
} : !transform.any_op
186183
transform.yield

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2748,15 +2748,6 @@ func.func @from_elements_to_splat(%a: f32, %b: f32) -> (vector<2x3xf32>, vector<
27482748
return %0, %1, %2 : vector<2x3xf32>, vector<2x3xf32>, vector<f32>
27492749
}
27502750

2751-
// -----
2752-
2753-
// CHECK-LABEL: @fold_vector_step_to_constant
2754-
// CHECK: %[[CONSTANT:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
2755-
// CHECK: return %[[CONSTANT]] : vector<4xindex>
2756-
func.func @fold_vector_step_to_constant() -> vector<4xindex> {
2757-
%0 = vector.step : vector<4xindex>
2758-
return %0 : vector<4xindex>
2759-
}
27602751

27612752
// -----
27622753

0 commit comments

Comments
 (0)