
Commit 1ac61c8

[mlir][Vector] Remove vector.extractelement/insertelement from sparse vectorizer (#143270)
This PR is part of the last step to remove the `vector.extractelement` and `vector.insertelement` ops (RFC: https://discourse.llvm.org/t/rfc-psa-remove-vector-extractelement-and-vector-insertelement-ops-in-favor-of-vector-extract-and-vector-insert-ops). It updates the Sparse Vectorizer to use `vector.extract` and `vector.insert` instead of `vector.extractelement` and `vector.insertelement`.
1 parent 902a991 commit 1ac61c8
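
For readers unfamiliar with the two op forms, here is a minimal before/after sketch; the SSA names (%v, %s, %c0, %e0, ...) and the vector<8xf32> type are illustrative, not taken from this change. The deprecated element ops take a single dynamic index operand, while `vector.extract`/`vector.insert` take a position list and spell out the element type:

  // Deprecated element ops (dynamic index operand):
  %e0 = vector.extractelement %v[%c0 : index] : vector<8xf32>
  %w0 = vector.insertelement %s, %v[%c0 : index] : vector<8xf32>

  // Replacement forms used by this patch (static position, explicit element type):
  %e1 = vector.extract %v[0] : f32 from vector<8xf32>
  %w1 = vector.insert %s, %v [0] : f32 into vector<8xf32>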

File tree

5 files changed, +52 -34 lines changed

mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp

Lines changed: 42 additions & 24 deletions
@@ -198,14 +198,14 @@ static Value genVectorReducInit(PatternRewriter &rewriter, Location loc,
   case vector::CombiningKind::ADD:
   case vector::CombiningKind::XOR:
     // Initialize reduction vector to: | 0 | .. | 0 | r |
-    return rewriter.create<vector::InsertElementOp>(
-        loc, r, constantZero(rewriter, loc, vtp),
-        constantIndex(rewriter, loc, 0));
+    return rewriter.create<vector::InsertOp>(loc, r,
+                                             constantZero(rewriter, loc, vtp),
+                                             constantIndex(rewriter, loc, 0));
   case vector::CombiningKind::MUL:
     // Initialize reduction vector to: | 1 | .. | 1 | r |
-    return rewriter.create<vector::InsertElementOp>(
-        loc, r, constantOne(rewriter, loc, vtp),
-        constantIndex(rewriter, loc, 0));
+    return rewriter.create<vector::InsertOp>(loc, r,
+                                             constantOne(rewriter, loc, vtp),
+                                             constantIndex(rewriter, loc, 0));
   case vector::CombiningKind::AND:
   case vector::CombiningKind::OR:
     // Initialize reduction vector to: | r | .. | r | r |
@@ -628,31 +628,49 @@ struct ForOpRewriter : public OpRewritePattern<scf::ForOp> {
   const VL vl;
 };
 
+static LogicalResult cleanReducChain(PatternRewriter &rewriter, Operation *op,
+                                     Value inp) {
+  if (auto redOp = inp.getDefiningOp<vector::ReductionOp>()) {
+    if (auto forOp = redOp.getVector().getDefiningOp<scf::ForOp>()) {
+      if (forOp->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName())) {
+        rewriter.replaceOp(op, redOp.getVector());
+        return success();
+      }
+    }
+  }
+  return failure();
+}
+
 /// Reduction chain cleanup.
 ///   v = for { }
-///   s = vsum(v)         v = for { }
-///   u = expand(s)    -> for (v) { }
+///   s = vsum(v)          v = for { }
+///   u = broadcast(s)  -> for (v) { }
 ///   for (u) { }
-template <typename VectorOp>
-struct ReducChainRewriter : public OpRewritePattern<VectorOp> {
+struct ReducChainBroadcastRewriter
+    : public OpRewritePattern<vector::BroadcastOp> {
 public:
-  using OpRewritePattern<VectorOp>::OpRewritePattern;
+  using OpRewritePattern<vector::BroadcastOp>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(VectorOp op,
+  LogicalResult matchAndRewrite(vector::BroadcastOp op,
                                 PatternRewriter &rewriter) const override {
-    Value inp = op.getSource();
-    if (auto redOp = inp.getDefiningOp<vector::ReductionOp>()) {
-      if (auto forOp = redOp.getVector().getDefiningOp<scf::ForOp>()) {
-        if (forOp->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName())) {
-          rewriter.replaceOp(op, redOp.getVector());
-          return success();
-        }
-      }
-    }
-    return failure();
+    return cleanReducChain(rewriter, op, op.getSource());
   }
 };
 
+/// Reduction chain cleanup.
+///   v = for { }
+///   s = vsum(v)          v = for { }
+///   u = insert(s)     -> for (v) { }
+///   for (u) { }
+struct ReducChainInsertRewriter : public OpRewritePattern<vector::InsertOp> {
+public:
+  using OpRewritePattern<vector::InsertOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::InsertOp op,
+                                PatternRewriter &rewriter) const override {
+    return cleanReducChain(rewriter, op, op.getValueToStore());
+  }
+};
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -668,6 +686,6 @@ void mlir::populateSparseVectorizationPatterns(RewritePatternSet &patterns,
   vector::populateVectorStepLoweringPatterns(patterns);
   patterns.add<ForOpRewriter>(patterns.getContext(), vectorLength,
                               enableVLAVectorization, enableSIMDIndex32);
-  patterns.add<ReducChainRewriter<vector::InsertElementOp>,
-               ReducChainRewriter<vector::BroadcastOp>>(patterns.getContext());
+  patterns.add<ReducChainInsertRewriter, ReducChainBroadcastRewriter>(
+      patterns.getContext());
 }

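The two new rewriters share `cleanReducChain`, which implements the reduction-chain cleanup sketched in the doc comments above. A hedged MLIR sketch of the IR shape it targets (the SSA names, the <add> combining kind, and the vector<8xf64> type are illustrative, not taken from this diff): when the scalar produced by `vector.reduction` over a sparse-emitter loop is only re-inserted (or broadcast) to seed the next vectorized loop, the intermediate scalarization is bypassed and the first loop's vector result is forwarded directly.

  // Before cleanup (illustrative):
  %v = scf.for ... -> (vector<8xf64>) { ... }   // loop tagged by the sparse loop emitter
  %s = vector.reduction <add>, %v : vector<8xf64> into f64
  %u = vector.insert %s, %init [0] : f64 into vector<8xf64>
  %w = scf.for ... iter_args(%acc = %u) -> (vector<8xf64>) { ... }

  // After cleanup: %u (and the reduction feeding it) is replaced by %v.
  %w = scf.for ... iter_args(%acc = %v) -> (vector<8xf64>) { ... }
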
mlir/test/Dialect/SparseTensor/minipipeline_vector.mlir

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@
 // CHECK-NOVEC: }
 //
 // CHECK-VEC-LABEL: func.func @sum_reduction
-// CHECK-VEC: vector.insertelement
+// CHECK-VEC: vector.insert
 // CHECK-VEC: scf.for
 // CHECK-VEC: vector.create_mask
 // CHECK-VEC: vector.maskedload

mlir/test/Dialect/SparseTensor/sparse_vector.mlir

Lines changed: 3 additions & 3 deletions
@@ -241,7 +241,7 @@ func.func @mul_s(%arga: tensor<1024xf32, #SparseVector>,
 // CHECK-VEC16-DAG: %[[c1024:.*]] = arith.constant 1024 : index
 // CHECK-VEC16-DAG: %[[v0:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32>
 // CHECK-VEC16: %[[l:.*]] = memref.load %{{.*}}[] : memref<f32>
-// CHECK-VEC16: %[[r:.*]] = vector.insertelement %[[l]], %[[v0]][%[[c0]] : index] : vector<16xf32>
+// CHECK-VEC16: %[[r:.*]] = vector.insert %[[l]], %[[v0]] [0] : f32 into vector<16xf32>
 // CHECK-VEC16: %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] iter_args(%[[red_in:.*]] = %[[r]]) -> (vector<16xf32>) {
 // CHECK-VEC16: %[[la:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xf32>, vector<16xf32>
 // CHECK-VEC16: %[[lb:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
@@ -258,7 +258,7 @@ func.func @mul_s(%arga: tensor<1024xf32, #SparseVector>,
 // CHECK-VEC16-IDX32-DAG: %[[c1024:.*]] = arith.constant 1024 : index
 // CHECK-VEC16-IDX32-DAG: %[[v0:.*]] = arith.constant dense<0.000000e+00> : vector<16xf32>
 // CHECK-VEC16-IDX32: %[[l:.*]] = memref.load %{{.*}}[] : memref<f32>
-// CHECK-VEC16-IDX32: %[[r:.*]] = vector.insertelement %[[l]], %[[v0]][%[[c0]] : index] : vector<16xf32>
+// CHECK-VEC16-IDX32: %[[r:.*]] = vector.insert %[[l]], %[[v0]] [0] : f32 into vector<16xf32>
 // CHECK-VEC16-IDX32: %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[c16]] iter_args(%[[red_in:.*]] = %[[r]]) -> (vector<16xf32>) {
 // CHECK-VEC16-IDX32: %[[la:.*]] = vector.load %{{.*}}[%[[i]]] : memref<?xf32>, vector<16xf32>
 // CHECK-VEC16-IDX32: %[[lb:.*]] = vector.load %{{.*}}[%[[i]]] : memref<1024xf32>, vector<16xf32>
@@ -278,7 +278,7 @@ func.func @mul_s(%arga: tensor<1024xf32, #SparseVector>,
 // CHECK-VEC4-SVE: %[[l:.*]] = memref.load %{{.*}}[] : memref<f32>
 // CHECK-VEC4-SVE: %[[vscale:.*]] = vector.vscale
 // CHECK-VEC4-SVE: %[[step:.*]] = arith.muli %[[vscale]], %[[c4]] : index
-// CHECK-VEC4-SVE: %[[r:.*]] = vector.insertelement %[[l]], %[[v0]][%[[c0]] : index] : vector<[4]xf32>
+// CHECK-VEC4-SVE: %[[r:.*]] = vector.insert %[[l]], %[[v0]] [0] : f32 into vector<[4]xf32>
 // CHECK-VEC4-SVE: %[[red:.*]] = scf.for %[[i:.*]] = %[[c0]] to %[[c1024]] step %[[step]] iter_args(%[[red_in:.*]] = %[[r]]) -> (vector<[4]xf32>) {
 // CHECK-VEC4-SVE: %[[sub:.*]] = affine.min #[[$map]](%[[c1024]], %[[i]])[%[[step]]]
 // CHECK-VEC4-SVE: %[[mask:.*]] = vector.create_mask %[[sub]] : vector<[4]xi1>

mlir/test/Dialect/SparseTensor/sparse_vector_chain.mlir

Lines changed: 1 addition & 1 deletion
@@ -82,7 +82,7 @@
 // CHECK: %[[VAL_57:.*]] = arith.select %[[VAL_39]], %[[VAL_56]], %[[VAL_32]] : index
 // CHECK: scf.yield %[[VAL_55]], %[[VAL_57]], %[[VAL_58:.*]] : index, index, f64
 // CHECK: } attributes {"Emitted from" = "linalg.generic"}
-// CHECK: %[[VAL_59:.*]] = vector.insertelement %[[VAL_60:.*]]#2, %[[VAL_4]]{{\[}}%[[VAL_6]] : index] : vector<8xf64>
+// CHECK: %[[VAL_59:.*]] = vector.insert %[[VAL_60:.*]]#2, %[[VAL_4]] [0] : f64 into vector<8xf64>
 // CHECK: %[[VAL_61:.*]] = scf.for %[[VAL_62:.*]] = %[[VAL_60]]#0 to %[[VAL_21]] step %[[VAL_3]] iter_args(%[[VAL_63:.*]] = %[[VAL_59]]) -> (vector<8xf64>) {
 // CHECK: %[[VAL_64:.*]] = affine.min #map(%[[VAL_21]], %[[VAL_62]]){{\[}}%[[VAL_3]]]
 // CHECK: %[[VAL_65:.*]] = vector.create_mask %[[VAL_64]] : vector<8xi1>

mlir/test/Dialect/SparseTensor/vectorize_reduction.mlir

Lines changed: 5 additions & 5 deletions
@@ -172,7 +172,7 @@ func.func @sparse_reduction_ori_accumulator_on_rhs(%argx: tensor<i13>,
 // CHECK-ON: %[[VAL_9:.*]] = memref.load %[[VAL_8]][] : memref<i32>
 // CHECK-ON: %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_3]]] : memref<?xindex>
 // CHECK-ON: %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK-ON: %[[VAL_12:.*]] = vector.insertelement %[[VAL_9]], %[[VAL_4]]{{\[}}%[[VAL_3]] : index] : vector<8xi32>
+// CHECK-ON: %[[VAL_12:.*]] = vector.insert %[[VAL_9]], %[[VAL_4]] [0] : i32 into vector<8xi32>
 // CHECK-ON: %[[VAL_13:.*]] = scf.for %[[VAL_14:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_2]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (vector<8xi32>) {
 // CHECK-ON: %[[VAL_16:.*]] = affine.min #map(%[[VAL_11]], %[[VAL_14]]){{\[}}%[[VAL_2]]]
 // CHECK-ON: %[[VAL_17:.*]] = vector.create_mask %[[VAL_16]] : vector<8xi1>
@@ -247,7 +247,7 @@ func.func @sparse_reduction_subi(%argx: tensor<i32>,
 // CHECK-ON: %[[VAL_9:.*]] = memref.load %[[VAL_8]][] : memref<i32>
 // CHECK-ON: %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK-ON: %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK-ON: %[[VAL_12:.*]] = vector.insertelement %[[VAL_9]], %[[VAL_3]]{{\[}}%[[VAL_4]] : index] : vector<8xi32>
+// CHECK-ON: %[[VAL_12:.*]] = vector.insert %[[VAL_9]], %[[VAL_3]] [0] : i32 into vector<8xi32>
 // CHECK-ON: %[[VAL_13:.*]] = scf.for %[[VAL_14:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_2]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (vector<8xi32>) {
 // CHECK-ON: %[[VAL_16:.*]] = affine.min #map(%[[VAL_11]], %[[VAL_14]]){{\[}}%[[VAL_2]]]
 // CHECK-ON: %[[VAL_17:.*]] = vector.create_mask %[[VAL_16]] : vector<8xi1>
@@ -323,7 +323,7 @@ func.func @sparse_reduction_xor(%argx: tensor<i32>,
 // CHECK-ON: %[[VAL_9:.*]] = memref.load %[[VAL_8]][] : memref<i32>
 // CHECK-ON: %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK-ON: %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK-ON: %[[VAL_12:.*]] = vector.insertelement %[[VAL_9]], %[[VAL_3]]{{\[}}%[[VAL_4]] : index] : vector<8xi32>
+// CHECK-ON: %[[VAL_12:.*]] = vector.insert %[[VAL_9]], %[[VAL_3]] [0] : i32 into vector<8xi32>
 // CHECK-ON: %[[VAL_13:.*]] = scf.for %[[VAL_14:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_2]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (vector<8xi32>) {
 // CHECK-ON: %[[VAL_16:.*]] = affine.min #map(%[[VAL_11]], %[[VAL_14]]){{\[}}%[[VAL_2]]]
 // CHECK-ON: %[[VAL_17:.*]] = vector.create_mask %[[VAL_16]] : vector<8xi1>
@@ -399,7 +399,7 @@ func.func @sparse_reduction_addi(%argx: tensor<i32>,
 // CHECK-ON: %[[VAL_9:.*]] = memref.load %[[VAL_8]][] : memref<f32>
 // CHECK-ON: %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK-ON: %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK-ON: %[[VAL_12:.*]] = vector.insertelement %[[VAL_9]], %[[VAL_3]]{{\[}}%[[VAL_4]] : index] : vector<8xf32>
+// CHECK-ON: %[[VAL_12:.*]] = vector.insert %[[VAL_9]], %[[VAL_3]] [0] : f32 into vector<8xf32>
 // CHECK-ON: %[[VAL_13:.*]] = scf.for %[[VAL_14:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_2]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (vector<8xf32>) {
 // CHECK-ON: %[[VAL_16:.*]] = affine.min #map(%[[VAL_11]], %[[VAL_14]]){{\[}}%[[VAL_2]]]
 // CHECK-ON: %[[VAL_17:.*]] = vector.create_mask %[[VAL_16]] : vector<8xi1>
@@ -475,7 +475,7 @@ func.func @sparse_reduction_subf(%argx: tensor<f32>,
 // CHECK-ON: %[[VAL_9:.*]] = memref.load %[[VAL_8]][] : memref<f32>
 // CHECK-ON: %[[VAL_10:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_4]]] : memref<?xindex>
 // CHECK-ON: %[[VAL_11:.*]] = memref.load %[[VAL_6]]{{\[}}%[[VAL_5]]] : memref<?xindex>
-// CHECK-ON: %[[VAL_12:.*]] = vector.insertelement %[[VAL_9]], %[[VAL_3]]{{\[}}%[[VAL_4]] : index] : vector<8xf32>
+// CHECK-ON: %[[VAL_12:.*]] = vector.insert %[[VAL_9]], %[[VAL_3]] [0] : f32 into vector<8xf32>
 // CHECK-ON: %[[VAL_13:.*]] = scf.for %[[VAL_14:.*]] = %[[VAL_10]] to %[[VAL_11]] step %[[VAL_2]] iter_args(%[[VAL_15:.*]] = %[[VAL_12]]) -> (vector<8xf32>) {
 // CHECK-ON: %[[VAL_16:.*]] = affine.min #map(%[[VAL_11]], %[[VAL_14]]){{\[}}%[[VAL_2]]]
 // CHECK-ON: %[[VAL_17:.*]] = vector.create_mask %[[VAL_16]] : vector<8xi1>
