
Commit 80c6dbe

[mlir][Vector] Remove more special case uses for extractelement/insertelement
1 parent 4fa5ab3 commit 80c6dbe

9 files changed, +32 -59 lines changed


mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 2 additions & 0 deletions
@@ -718,6 +718,7 @@ def Vector_ExtractOp :
   let results = (outs AnyType:$result);

   let builders = [
+    OpBuilder<(ins "Value":$source)>,
     OpBuilder<(ins "Value":$source, "int64_t":$position)>,
     OpBuilder<(ins "Value":$source, "OpFoldResult":$position)>,
     OpBuilder<(ins "Value":$source, "ArrayRef<int64_t>":$position)>,
@@ -913,6 +914,7 @@ def Vector_InsertOp :
   let results = (outs AnyVectorOfAnyRank:$result);

   let builders = [
+    OpBuilder<(ins "Value":$source, "Value":$dest)>,
     OpBuilder<(ins "Value":$source, "Value":$dest, "int64_t":$position)>,
     OpBuilder<(ins "Value":$source, "Value":$dest, "OpFoldResult":$position)>,
     OpBuilder<(ins "Value":$source, "Value":$dest, "ArrayRef<int64_t>":$position)>,
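
A minimal sketch (not part of this commit) of how the new position-less builders can be used from C++; the helper names, builder, and location below are illustrative only:

#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Extract the single element of a 0-D or all-unit-dim vector (vector<f32>,
// vector<1xf32>, vector<1x1xf32>, ...). The new builder fills in a zero
// index for every dimension, so no rank-based special casing is needed.
static Value extractOnlyElement(OpBuilder &b, Location loc, Value vec) {
  return b.create<vector::ExtractOp>(loc, vec);
}

// The symmetric vector.insert builder: the position defaults to all zeros.
static Value insertOnlyElement(OpBuilder &b, Location loc, Value scalar,
                               Value dest) {
  return b.create<vector::InsertOp>(loc, scalar, dest);
}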

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 18 additions & 14 deletions
@@ -560,11 +560,9 @@ struct ElideUnitDimsInMultiDimReduction
     } else {
       // This means we are reducing all the dimensions, and all reduction
       // dimensions are of size 1. So a simple extraction would do.
-      SmallVector<int64_t> zeroIdx(shape.size(), 0);
       if (mask)
-        mask = rewriter.create<vector::ExtractOp>(loc, mask, zeroIdx);
-      cast = rewriter.create<vector::ExtractOp>(loc, reductionOp.getSource(),
-                                                zeroIdx);
+        mask = rewriter.create<vector::ExtractOp>(loc, mask);
+      cast = rewriter.create<vector::ExtractOp>(loc, reductionOp.getSource());
     }

     Value result =
@@ -698,16 +696,9 @@ struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
       return failure();

     Location loc = reductionOp.getLoc();
-    Value result;
-    if (vectorType.getRank() == 0) {
-      if (mask)
-        mask = rewriter.create<ExtractElementOp>(loc, mask);
-      result = rewriter.create<ExtractElementOp>(loc, reductionOp.getVector());
-    } else {
-      if (mask)
-        mask = rewriter.create<ExtractOp>(loc, mask, 0);
-      result = rewriter.create<ExtractOp>(loc, reductionOp.getVector(), 0);
-    }
+    if (mask)
+      mask = rewriter.create<ExtractOp>(loc, mask);
+    Value result = rewriter.create<ExtractOp>(loc, reductionOp.getVector());

     if (Value acc = reductionOp.getAcc())
       result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
@@ -1294,6 +1285,12 @@ void ExtractOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
   setResultRanges(getResult(), argRanges.front());
 }

+void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
+                              Value source) {
+  auto vectorTy = cast<VectorType>(source.getType());
+  build(builder, result, source, SmallVector<int64_t>(vectorTy.getRank(), 0));
+}
+
 void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
                               Value source, int64_t position) {
   build(builder, result, source, ArrayRef<int64_t>{position});
@@ -2916,6 +2913,13 @@ void vector::InsertOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
   setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
 }

+void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
+                             Value source, Value dest) {
+  auto vectorTy = cast<VectorType>(dest.getType());
+  build(builder, result, source, dest,
+        SmallVector<int64_t>(vectorTy.getRank(), 0));
+}
+
 void vector::InsertOp::build(OpBuilder &builder, OperationState &result,
                              Value source, Value dest, int64_t position) {
   build(builder, result, source, dest, ArrayRef<int64_t>{position});
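
As a reading aid (also not from the commit), the new overloads delegate to the existing ArrayRef<int64_t> builders with one zero index per dimension. In the sketch below, `v` is assumed to be a vector<f32> and `w` a vector<1x1xf32>, so the last two calls build the same op:

#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "llvm/ADT/ArrayRef.h"

using namespace mlir;

static void sketchZeroPositionEquivalence(OpBuilder &b, Location loc, Value v,
                                          Value w) {
  // 0-D case: the position is empty, i.e. vector.extract %v[] from vector<f32>.
  Value fromRank0 = b.create<vector::ExtractOp>(loc, v);
  // Rank-2 unit-dim case: the implicit position is [0, 0] ...
  Value implicitPos = b.create<vector::ExtractOp>(loc, w);
  // ... which matches the explicit all-zero form.
  Value explicitPos =
      b.create<vector::ExtractOp>(loc, w, llvm::ArrayRef<int64_t>{0, 0});
  (void)fromRank0;
  (void)implicitPos;
  (void)explicitPos;
}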

mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp

Lines changed: 1 addition & 5 deletions
@@ -52,11 +52,7 @@ class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {

     // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
     if (srcRank <= 1 && dstRank == 1) {
-      Value ext;
-      if (srcRank == 0)
-        ext = rewriter.create<vector::ExtractElementOp>(loc, op.getSource());
-      else
-        ext = rewriter.create<vector::ExtractOp>(loc, op.getSource(), 0);
+      Value ext = rewriter.create<vector::ExtractOp>(loc, op.getSource());
       rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
       return success();
     }

mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp

Lines changed: 3 additions & 19 deletions
@@ -189,25 +189,9 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
         incIdx(resIdx, resultVectorType);
       }

-      Value extract;
-      if (srcRank == 0) {
-        // 0-D vector special case
-        assert(srcIdx.empty() && "Unexpected indices for 0-D vector");
-        extract = rewriter.create<vector::ExtractElementOp>(
-            loc, op.getSourceVectorType().getElementType(), op.getSource());
-      } else {
-        extract =
-            rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
-      }
-
-      if (resRank == 0) {
-        // 0-D vector special case
-        assert(resIdx.empty() && "Unexpected indices for 0-D vector");
-        result = rewriter.create<vector::InsertElementOp>(loc, extract, result);
-      } else {
-        result =
-            rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
-      }
+      Value extract =
+          rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
+      result = rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
     }
     rewriter.replaceOp(op, result);
     return success();
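
The hunk above still steps through source and result positions with the incIdx helper referenced in its context lines. As a rough illustration of that kind of helper (an assumption for exposition, not the actual upstream implementation), an odometer-style increment over a vector type's dimensions looks like this:

#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/SmallVector.h"

// Illustrative only: advance a multi-dimensional position over `type`,
// innermost dimension first, carrying into outer dimensions on wrap-around.
static void incrementPosition(llvm::SmallVectorImpl<int64_t> &idx,
                              mlir::VectorType type) {
  for (int64_t dim = static_cast<int64_t>(idx.size()) - 1; dim >= 0; --dim) {
    if (++idx[dim] < type.getDimSize(dim))
      return;       // no carry needed; done
    idx[dim] = 0;   // wrap this dimension and carry into the next outer one
  }
}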

mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp

Lines changed: 2 additions & 11 deletions
@@ -929,17 +929,8 @@ class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
     if (!xferOp.getPermutationMap().isMinorIdentity())
       return failure();
     // Only float and integer element types are supported.
-    Value scalar;
-    if (vecType.getRank() == 0) {
-      // vector.extract does not support vector<f32> etc., so use
-      // vector.extractelement instead.
-      scalar = rewriter.create<vector::ExtractElementOp>(xferOp.getLoc(),
-                                                         xferOp.getVector());
-    } else {
-      SmallVector<int64_t> pos(vecType.getRank(), 0);
-      scalar = rewriter.create<vector::ExtractOp>(xferOp.getLoc(),
-                                                  xferOp.getVector(), pos);
-    }
+    Value scalar =
+        rewriter.create<vector::ExtractOp>(xferOp.getLoc(), xferOp.getVector());
     // Construct a scalar store.
     if (isa<MemRefType>(xferOp.getSource().getType())) {
       rewriter.replaceOpWithNewOp<memref::StoreOp>(

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 1 addition & 1 deletion
@@ -187,7 +187,7 @@ func.func @broadcast_vec2d_from_vec0d(%arg0: vector<f32>) -> vector<3x2xf32> {
 // CHECK: %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<f32> to vector<1xf32>
 // CHECK: %[[T1:.*]] = ub.poison : vector<3x2xf32>
 // CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>>
-// CHECK: %[[T4:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK: %[[T4:.*]] = llvm.mlir.constant(0 : i64) : i64
 // CHECK: %[[T5:.*]] = llvm.extractelement %[[T0]][%[[T4]] : i64] : vector<1xf32>
 // CHECK: %[[T6Insert:.*]] = llvm.insertelement %[[T5]]
 // CHECK: %[[T6:.*]] = llvm.shufflevector %[[T6Insert]]

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 1 addition & 1 deletion
@@ -2658,7 +2658,7 @@ func.func @fold_extractelement_of_broadcast(%f: f32) -> f32 {

 // CHECK-LABEL: func.func @fold_0d_vector_reduction
 func.func @fold_0d_vector_reduction(%arg0: vector<f32>) -> f32 {
-  // CHECK-NEXT: %[[RES:.*]] = vector.extractelement %arg{{.*}}[] : vector<f32>
+  // CHECK-NEXT: %[[RES:.*]] = vector.extract %arg{{.*}}[] : f32 from vector<f32>
   // CHECK-NEXT: return %[[RES]] : f32
   %0 = vector.reduction <add>, %arg0 : vector<f32> into f32
   return %0 : f32

mlir/test/Dialect/Vector/scalar-vector-transfer-to-memref.mlir

Lines changed: 2 additions & 6 deletions
@@ -45,9 +45,7 @@ func.func @tensor_transfer_read_0d(%t: tensor<?x?x?xf32>, %idx: index) -> f32 {

 // CHECK-LABEL: func @transfer_write_0d(
 // CHECK-SAME: %[[m:.*]]: memref<?x?x?xf32>, %[[idx:.*]]: index, %[[f:.*]]: f32
-// CHECK: %[[bc:.*]] = vector.broadcast %[[f]] : f32 to vector<f32>
-// CHECK: %[[extract:.*]] = vector.extractelement %[[bc]][] : vector<f32>
-// CHECK: memref.store %[[extract]], %[[m]][%[[idx]], %[[idx]], %[[idx]]]
+// CHECK: memref.store %[[f]], %[[m]][%[[idx]], %[[idx]], %[[idx]]]
 func.func @transfer_write_0d(%m: memref<?x?x?xf32>, %idx: index, %f: f32) {
   %0 = vector.broadcast %f : f32 to vector<f32>
   vector.transfer_write %0, %m[%idx, %idx, %idx] : vector<f32>, memref<?x?x?xf32>
@@ -69,9 +67,7 @@ func.func @transfer_write_1d(%m: memref<?x?x?xf32>, %idx: index, %f: f32) {

 // CHECK-LABEL: func @tensor_transfer_write_0d(
 // CHECK-SAME: %[[t:.*]]: tensor<?x?x?xf32>, %[[idx:.*]]: index, %[[f:.*]]: f32
-// CHECK: %[[bc:.*]] = vector.broadcast %[[f]] : f32 to vector<f32>
-// CHECK: %[[extract:.*]] = vector.extractelement %[[bc]][] : vector<f32>
-// CHECK: %[[r:.*]] = tensor.insert %[[extract]] into %[[t]][%[[idx]], %[[idx]], %[[idx]]]
+// CHECK: %[[r:.*]] = tensor.insert %[[f]] into %[[t]][%[[idx]], %[[idx]], %[[idx]]]
 // CHECK: return %[[r]]
 func.func @tensor_transfer_write_0d(%t: tensor<?x?x?xf32>, %idx: index, %f: f32) -> tensor<?x?x?xf32> {
   %0 = vector.broadcast %f : f32 to vector<f32>

mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir

Lines changed: 2 additions & 2 deletions
@@ -117,7 +117,7 @@ func.func @shape_cast_1d3d(%arg0 : vector<6xf32>) -> vector<2x1x3xf32> {
 // CHECK-LABEL: func.func @shape_cast_0d1d(
 // CHECK-SAME: %[[ARG0:.*]]: vector<f32>) -> vector<1xf32> {
 // CHECK: %[[UB:.*]] = ub.poison : vector<1xf32>
-// CHECK: %[[EXTRACT0:.*]] = vector.extractelement %[[ARG0]][] : vector<f32>
+// CHECK: %[[EXTRACT0:.*]] = vector.extract %[[ARG0]][] : f32 from vector<f32>
 // CHECK: %[[RES:.*]] = vector.insert %[[EXTRACT0]], %[[UB]] [0] : f32 into vector<1xf32>
 // CHECK: return %[[RES]] : vector<1xf32>
 // CHECK: }
@@ -131,7 +131,7 @@ func.func @shape_cast_0d1d(%arg0 : vector<f32>) -> vector<1xf32> {
 // CHECK-SAME: %[[ARG0:.*]]: vector<1xf32>) -> vector<f32> {
 // CHECK: %[[UB:.*]] = ub.poison : vector<f32>
 // CHECK: %[[EXTRACT0:.*]] = vector.extract %[[ARG0]][0] : f32 from vector<1xf32>
-// CHECK: %[[RES:.*]] = vector.insertelement %[[EXTRACT0]], %[[UB]][] : vector<f32>
+// CHECK: %[[RES:.*]] = vector.insert %[[EXTRACT0]], %[[UB]] [] : f32 into vector<f32>
 // CHECK: return %[[RES]] : vector<f32>
 // CHECK: }