Skip to content

Commit c3fe42b

Browse files
committed
review comments
1 parent b228d28 commit c3fe42b

File tree

2 files changed

+27
-24
lines changed

2 files changed

+27
-24
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -962,6 +962,8 @@ struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
962962
};
963963

964964
/// Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex:
965+
///
966+
/// Example:
965967
/// ```
966968
/// %a = vector.broadcast %arg1 : index to vector<1x4xindex>
967969
/// %b = vector.broadcast %arg2 : index to vector<1x4xindex>
@@ -1047,6 +1049,8 @@ struct ReorderElementwiseOpsOnBroadcast final
10471049
/// This may result in cleaner code when extracting a single value
10481050
/// from multi-element vector and also to help canonicalize 1-element vectors to
10491051
/// scalars.
1052+
///
1053+
/// Example:
10501054
/// ```
10511055
/// %0 = arith.addf %arg0, %arg1 : vector<4xf32>
10521056
/// %1 = vector.extract %0[1] : f32 from vector<4xf32>
@@ -1104,6 +1108,8 @@ class ExtractOpFromElementwise final
11041108
};
11051109

11061110
/// Pattern to rewrite vector.extract(vector.load) -> vector/memref.load.
1111+
///
1112+
/// Example:
11071113
/// ```
11081114
/// vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
11091115
/// vector.extract %0[1] : f32 from vector<4xf32>
@@ -1122,13 +1128,14 @@ class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
11221128
PatternRewriter &rewriter) const override {
11231129
auto loadOp = op.getVector().getDefiningOp<vector::LoadOp>();
11241130
if (!loadOp)
1125-
return rewriter.notifyMatchFailure(op, "not a load op");
1131+
return rewriter.notifyMatchFailure(op, "expected a load op");
11261132

1133+
// Checking for single use so we won't duplicate load ops.
11271134
if (!loadOp->hasOneUse())
11281135
return rewriter.notifyMatchFailure(op, "expected single op use");
11291136

1130-
VectorType memVecType = loadOp.getVectorType();
1131-
if (memVecType.isScalable())
1137+
VectorType loadVecType = loadOp.getVectorType();
1138+
if (loadVecType.isScalable())
11321139
return rewriter.notifyMatchFailure(op,
11331140
"scalable vectors are not supported");
11341141

@@ -1137,7 +1144,7 @@ class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
11371144
return rewriter.notifyMatchFailure(
11381145
op, "memrefs of vectors are not supported");
11391146

1140-
int64_t rankOffset = memType.getRank() - memVecType.getRank();
1147+
int64_t rankOffset = memType.getRank() - loadVecType.getRank();
11411148
if (rankOffset < 0)
11421149
return rewriter.notifyMatchFailure(op, "unsupported ranks combination");
11431150

@@ -1149,6 +1156,9 @@ class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
11491156
SmallVector<Value> indices = loadOp.getIndices();
11501157
SmallVector<OpFoldResult> extractPos = op.getMixedPosition();
11511158

1159+
// There may be memory stores between the load and the extract op, so we
1160+
// need to make sure that the new load op is inserted at the same place as
1161+
// the original load op.
11521162
OpBuilder::InsertionGuard g(rewriter);
11531163
rewriter.setInsertionPoint(loadOp);
11541164
Location loc = loadOp.getLoc();
@@ -1170,12 +1180,15 @@ class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
11701180
} else {
11711181
rewriter.replaceOpWithNewOp<memref::LoadOp>(op, base, indices);
11721182
}
1183+
// We checked for single use so we can safely erase the load op.
11731184
rewriter.eraseOp(loadOp);
11741185
return success();
11751186
}
11761187
};
11771188

11781189
/// Pattern to rewrite vector.store(vector.splat) -> vector/memref.store.
1190+
///
1191+
/// Example:
11791192
/// ```
11801193
/// %0 = vector.splat %arg2 : vector<1xf32>
11811194
/// vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
@@ -1205,8 +1218,9 @@ class StoreFromSplat final : public OpRewritePattern<vector::StoreOp> {
12051218

12061219
Operation *splat = op.getValueToStore().getDefiningOp();
12071220
if (!isa_and_present<vector::BroadcastOp, vector::SplatOp>(splat))
1208-
return rewriter.notifyMatchFailure(op, "not a splat");
1221+
return rewriter.notifyMatchFailure(op, "neither a splat nor a broadcast");
12091222

1223+
// Checking for single use so we can remove splat.
12101224
if (!splat->hasOneUse())
12111225
return rewriter.notifyMatchFailure(op, "expected single op use");
12121226

mlir/test/Dialect/Vector/vector-sink.mlir

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -587,9 +587,9 @@ func.func @extract_load_vec_high_rank(%arg0: memref<?x?x?xf32>, %arg1: index, %a
587587
return %1 : vector<4xf32>
588588
}
589589

590-
// CHECK-LABEL: @negative_load_scalar_from_vec_memref
590+
// CHECK-LABEL: @negative_extract_load_scalar_from_vec_memref
591591
// CHECK-SAME: (%[[ARG0:.*]]: memref<?xvector<4xf32>>, %[[ARG1:.*]]: index)
592-
func.func @negative_load_scalar_from_vec_memref(%arg0: memref<?xvector<4xf32>>, %arg1: index) -> f32 {
592+
func.func @negative_extract_load_scalar_from_vec_memref(%arg0: memref<?xvector<4xf32>>, %arg1: index) -> f32 {
593593
// CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xvector<4xf32>>, vector<4xf32>
594594
// CHECK: %[[EXT:.*]] = vector.extract %[[RES]][0] : f32 from vector<4xf32>
595595
// CHECK: return %[[EXT]] : f32
@@ -609,9 +609,9 @@ func.func @negative_extract_load_no_single_use(%arg0: memref<?xf32>, %arg1: inde
609609
return %1, %0 : f32, vector<4xf32>
610610
}
611611

612-
// CHECK-LABEL: @negative_load_scalable
612+
// CHECK-LABEL: @negative_extract_load_scalable
613613
// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
614-
func.func @negative_load_scalable(%arg0: memref<?xf32>, %arg1: index) -> f32 {
614+
func.func @negative_extract_load_scalable(%arg0: memref<?xf32>, %arg1: index) -> f32 {
615615
// CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<[1]xf32>
616616
// CHECK: %[[EXT:.*]] = vector.extract %[[RES]][0] : f32 from vector<[1]xf32>
617617
// CHECK: return %[[EXT]] : f32
@@ -620,17 +620,6 @@ func.func @negative_load_scalable(%arg0: memref<?xf32>, %arg1: index) -> f32 {
620620
return %1 : f32
621621
}
622622

623-
// CHECK-LABEL: @negative_extract_load_unsupported_ranks
624-
// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
625-
func.func @negative_extract_load_unsupported_ranks(%arg0: memref<?xf32>, %arg1: index) -> vector<4xf32> {
626-
// CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<2x4xf32>
627-
// CHECK: %[[EXT:.*]] = vector.extract %[[RES]][1] : vector<4xf32> from vector<2x4xf32>
628-
// CHECK: return %[[EXT]] : vector<4xf32>
629-
%0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<2x4xf32>
630-
%1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
631-
return %1 : vector<4xf32>
632-
}
633-
634623
//-----------------------------------------------------------------------------
635624
// [Pattern: StoreFromSplat]
636625
//-----------------------------------------------------------------------------
@@ -653,9 +642,9 @@ func.func @store_broadcast(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) {
653642
return
654643
}
655644

656-
// CHECK-LABEL: @store_broadcast_1d_2d
645+
// CHECK-LABEL: @store_broadcast_1d_to_2d
657646
// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: vector<1xf32>)
658-
func.func @store_broadcast_1d_2d(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index, %arg3: vector<1xf32>) {
647+
func.func @store_broadcast_1d_to_2d(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index, %arg3: vector<1xf32>) {
659648
// CHECK: vector.store %[[ARG3]], %[[ARG0]][%[[ARG1]], %[[ARG2]]] : memref<?x?xf32>, vector<1xf32>
660649
%0 = vector.broadcast %arg3 : vector<1xf32> to vector<1x1xf32>
661650
vector.store %0, %arg0[%arg1, %arg2] : memref<?x?xf32>, vector<1x1xf32>
@@ -682,9 +671,9 @@ func.func @negative_store_vec_memref(%arg0: memref<?xvector<1xf32>>, %arg1: inde
682671
return
683672
}
684673

685-
// CHECK-LABEL: @negative_store_non_1
674+
// CHECK-LABEL: @negative_store_more_than_one_element
686675
// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
687-
func.func @negative_store_non_1(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) {
676+
func.func @negative_store_more_than_one_element(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) {
688677
// CHECK: %[[RES:.*]] = vector.splat %[[ARG2]] : vector<4xf32>
689678
// CHECK: vector.store %[[RES]], %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<4xf32>
690679
%0 = vector.splat %arg2 : vector<4xf32>

0 commit comments

Comments
 (0)