
Commit bdc86b6

[mlir][vector] Sink extract/splat into load/store ops
```
%0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
%1 = vector.extract %0[1] : f32 from vector<4xf32>
```

Gets converted to:

```
%c1 = arith.constant 1 : index
%0 = arith.addi %arg1, %c1 overflow<nsw> : index
%1 = memref.load %arg0[%0] : memref<?xf32>
```

```
%0 = vector.splat %arg2 : vector<1xf32>
vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
```

Gets converted to:

```
memref.store %arg2, %arg0[%arg1] : memref<?xf32>
```
1 parent 09680dc commit bdc86b6
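
For downstream users, the new entry point composes with the existing sinking patterns. Below is a minimal sketch (not part of this commit; the `sinkVectorOps` wrapper is a hypothetical helper) of populating both pattern sets and running them through MLIR's greedy rewrite driver:

```cpp
// Illustrative only: how a downstream transform might apply the new patterns.
// `sinkVectorOps` is a hypothetical helper, not something this commit adds.
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

static LogicalResult sinkVectorOps(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  // Existing broadcast/transpose/extract sinking patterns.
  vector::populateSinkVectorOpsPatterns(patterns);
  // New in this commit: fold vector.extract(vector.load) and
  // vector.store(splat) into plain memref/vector load/store ops.
  vector::populateSinkVectorMemOpsPatterns(patterns);
  // Older MLIR releases spell this applyPatternsAndFoldGreedily.
  return applyPatternsGreedily(root, std::move(patterns));
}
```

The same patterns are also reachable from the transform dialect via the new `transform.apply_patterns.vector.sink_mem_ops` op, as the updated `vector-sink-transform.mlir` test below shows.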

File tree

7 files changed: +350 −2 lines changed

mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td

Lines changed: 22 additions & 2 deletions
````diff
@@ -469,8 +469,28 @@ def ApplySinkVectorPatternsOp : Op<Transform_Dialect,
     %0 = arith.addf %a, %b : vector<4x2xf32>
     %r = vector.transpose %0, [1, 0] : vector<2x4xf32>
     ```
-    At the moment, these patterns are limited to vector.broadcast and
-    vector.transpose.
+    At the moment, these patterns are limited to vector.broadcast,
+    vector.transpose and vector.extract.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
+def ApplySinkVectorMemPatternsOp : Op<Transform_Dialect,
+    "apply_patterns.vector.sink_mem_ops",
+    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Patterns that remove redundant Vector Ops by merging them with load/store
+    ops
+    ```
+    vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+    vector.extract %0[1] : f32 from vector<4xf32>
+    ```
+    Gets converted to:
+    ```
+    %c1 = arith.constant 1 : index
+    %0 = arith.addi %arg1, %c1 overflow<nsw> : index
+    %1 = memref.load %arg0[%0] : memref<?xf32>
   }];
 
   let assemblyFormat = "attr-dict";
````

mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h

Lines changed: 5 additions & 0 deletions
```diff
@@ -161,6 +161,11 @@ void populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
 void populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
                                    PatternBenefit benefit = 1);
 
+/// Patterns that remove redundant Vector Ops by re-ordering them with
+/// memory Ops:
+void populateSinkVectorMemOpsPatterns(RewritePatternSet &patterns,
+                                      PatternBenefit benefit = 1);
+
 /// Patterns that fold chained vector reductions. These patterns assume that
 /// elementwise operations (e.g., `arith.addf` with vector operands) are
 /// cheaper than vector reduction.
```

mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp

Lines changed: 5 additions & 0 deletions
```diff
@@ -212,6 +212,11 @@ void transform::ApplySinkVectorPatternsOp::populatePatterns(
   vector::populateSinkVectorOpsPatterns(patterns);
 }
 
+void transform::ApplySinkVectorMemPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  vector::populateSinkVectorMemOpsPatterns(patterns);
+}
+
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//
```

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 127 additions & 0 deletions
````diff
@@ -1043,6 +1043,127 @@ class ExtractOpFromElementwise final
   }
 };
 
+/// Pattern to rewrite vector.extract(vector.load) -> vector/memref.load.
+/// ```
+/// vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+/// vector.extract %0[1] : f32 from vector<4xf32>
+/// ```
+/// Gets converted to:
+/// ```
+/// %c1 = arith.constant 1 : index
+/// %0 = arith.addi %arg1, %c1 overflow<nsw> : index
+/// %1 = memref.load %arg0[%0] : memref<?xf32>
+/// ```
+class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ExtractOp op,
+                                PatternRewriter &rewriter) const override {
+    auto loadOp = op.getVector().getDefiningOp<vector::LoadOp>();
+    if (!loadOp)
+      return rewriter.notifyMatchFailure(op, "not a load op");
+
+    if (!loadOp->hasOneUse())
+      return rewriter.notifyMatchFailure(op, "expected single op use");
+
+    VectorType memVecType = loadOp.getVectorType();
+    if (memVecType.isScalable())
+      return rewriter.notifyMatchFailure(op,
+                                         "scalable vectors are not supported");
+
+    MemRefType memType = loadOp.getMemRefType();
+    if (isa<VectorType>(memType.getElementType()))
+      return rewriter.notifyMatchFailure(
+          op, "memrefs of vectors are not supported");
+
+    int64_t rankOffset = memType.getRank() - memVecType.getRank();
+    if (rankOffset < 0)
+      return rewriter.notifyMatchFailure(op, "unsupported ranks combination");
+
+    auto resVecType = dyn_cast<VectorType>(op.getResult().getType());
+    int64_t finalRank = 0;
+    if (resVecType)
+      finalRank = resVecType.getRank();
+
+    SmallVector<Value> indices = loadOp.getIndices();
+    SmallVector<OpFoldResult> extractPos = op.getMixedPosition();
+
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(loadOp);
+    Location loc = loadOp.getLoc();
+    for (auto i : llvm::seq<int64_t>(rankOffset, indices.size() - finalRank)) {
+      OpFoldResult pos = extractPos[i - rankOffset];
+      if (isConstantIntValue(pos, 0))
+        continue;
+
+      Value offset = getValueOrCreateConstantIndexOp(rewriter, loc, pos);
+
+      auto ovf = arith::IntegerOverflowFlags::nsw;
+      indices[i] = rewriter.create<arith::AddIOp>(loc, indices[i], offset, ovf);
+    }
+
+    Value base = loadOp.getBase();
+    if (resVecType) {
+      rewriter.replaceOpWithNewOp<vector::LoadOp>(op, resVecType, base,
+                                                  indices);
+    } else {
+      rewriter.replaceOpWithNewOp<memref::LoadOp>(op, base, indices);
+    }
+    rewriter.eraseOp(loadOp);
+    return success();
+  }
+};
+
+/// Pattern to rewrite vector.store(vector.splat) -> vector/memref.store.
+/// ```
+/// %0 = vector.splat %arg2 : vector<1xf32>
+/// vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
+/// ```
+/// Gets converted to:
+/// ```
+/// memref.store %arg2, %arg0[%arg1] : memref<?xf32>
+/// ```
+class StoreFromSplat final : public OpRewritePattern<vector::StoreOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::StoreOp op,
+                                PatternRewriter &rewriter) const override {
+    VectorType vecType = op.getVectorType();
+    if (vecType.isScalable())
+      return rewriter.notifyMatchFailure(op,
+                                         "scalable vectors are not supported");
+
+    if (isa<VectorType>(op.getMemRefType().getElementType()))
+      return rewriter.notifyMatchFailure(
+          op, "memrefs of vectors are not supported");
+
+    if (vecType.getNumElements() != 1)
+      return rewriter.notifyMatchFailure(
+          op, "only 1-element, vectors are supported");
+
+    Operation *splat = op.getValueToStore().getDefiningOp();
+    if (!isa_and_present<vector::BroadcastOp, vector::SplatOp>(splat))
+      return rewriter.notifyMatchFailure(op, "not a splat");
+
+    if (!splat->hasOneUse())
+      return rewriter.notifyMatchFailure(op, "expected single op use");
+
+    Value source = splat->getOperand(0);
+    Value base = op.getBase();
+    ValueRange indices = op.getIndices();
+
+    if (isa<VectorType>(source.getType())) {
+      rewriter.replaceOpWithNewOp<vector::StoreOp>(op, source, base, indices);
+    } else {
+      rewriter.replaceOpWithNewOp<memref::StoreOp>(op, source, base, indices);
+    }
+    rewriter.eraseOp(splat);
+    return success();
+  }
+};
+
 // Helper that returns a vector comparison that constructs a mask:
 //     mask = [0,1,..,n-1] + [o,o,..,o] < [b,b,..,b]
 //
@@ -2109,6 +2230,12 @@ void mlir::vector::populateSinkVectorOpsPatterns(RewritePatternSet &patterns,
                        patterns.getContext(), benefit);
 }
 
+void mlir::vector::populateSinkVectorMemOpsPatterns(RewritePatternSet &patterns,
+                                                    PatternBenefit benefit) {
+  patterns.add<ExtractOpFromLoad, StoreFromSplat>(patterns.getContext(),
+                                                  benefit);
+}
+
 void mlir::vector::populateChainedVectorReductionFoldingPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<ChainedReduction>(patterns.getContext(), benefit);
````
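
The index rewriting that `ExtractOpFromLoad` performs reduces to adding each non-zero extract position to the corresponding load index, leaving the leading `rankOffset` indices and the trailing `finalRank` dimensions untouched. A conceptual, standalone sketch of that arithmetic for static positions (the helper name is hypothetical and not part of this patch):

```cpp
// Conceptual sketch only (not from this patch): the load indices that
// ExtractOpFromLoad materializes for vector.extract(vector.load), restricted
// to static extract positions for brevity.
#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int64_t>
foldExtractIntoLoadIndices(std::vector<int64_t> loadIndices,       // load's indices
                           const std::vector<int64_t> &extractPos, // extract positions
                           int64_t vecRank,                        // loaded vector rank
                           int64_t resultRank) {                   // extracted result rank
  int64_t memrefRank = static_cast<int64_t>(loadIndices.size());
  int64_t rankOffset = memrefRank - vecRank;
  assert(rankOffset >= 0 && "unsupported ranks combination");
  // Only the dimensions covered by the extract positions get an offset; the
  // trailing `resultRank` dimensions stay vectorized and keep their indices.
  for (int64_t i = rankOffset; i < memrefRank - resultRank; ++i) {
    int64_t pos = extractPos[i - rankOffset];
    if (pos == 0)
      continue;            // the pattern skips zero positions entirely
    loadIndices[i] += pos; // emitted as arith.addi ... overflow<nsw>
  }
  return loadIndices;
}
```

For `@extract_load_vec_high_rank` in the new tests (memref rank 3, loaded vector rank 2, result vector rank 1, extract position `[1]`), this yields indices `(%arg1, %arg2 + 1, %arg3)`, matching the CHECK lines.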

mlir/test/Dialect/Vector/vector-sink-transform.mlir

Lines changed: 1 addition & 0 deletions
```diff
@@ -7,6 +7,7 @@ module attributes {transform.with_named_sequence} {
     %func = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op
     transform.apply_patterns to %func {
       transform.apply_patterns.vector.sink_ops
+      transform.apply_patterns.vector.sink_mem_ops
     } : !transform.any_op
     transform.yield
   }
```

mlir/test/Dialect/Vector/vector-sink.mlir

Lines changed: 189 additions & 0 deletions
```diff
@@ -513,3 +513,192 @@ func.func @negative_extract_vec_fma(%arg0: vector<4xf32>, %arg1: vector<4xf32>,
   %1 = vector.extract %0[1] : f32 from vector<4xf32>
   return %1 : f32
 }
+
+//-----------------------------------------------------------------------------
+// [Pattern: ExtractOpFromLoad]
+//-----------------------------------------------------------------------------
+
+// CHECK-LABEL: @extract_load_scalar
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
+func.func @extract_load_scalar(%arg0: memref<?xf32>, %arg1: index) -> f32 {
+  // CHECK: %[[RES:.*]] = memref.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>
+  // CHECK: return %[[RES]] : f32
+  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+  %1 = vector.extract %0[0] : f32 from vector<4xf32>
+  return %1 : f32
+}
+
+// CHECK-LABEL: @extract_load_scalar_non_zero_off
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
+func.func @extract_load_scalar_non_zero_off(%arg0: memref<?xf32>, %arg1: index) -> f32 {
+  // CHECK: %[[C1:.*]] = arith.constant 1 : index
+  // CHECK: %[[OFF:.*]] = arith.addi %[[ARG1]], %[[C1]] overflow<nsw> : index
+  // CHECK: %[[RES:.*]] = memref.load %[[ARG0]][%[[OFF]]] : memref<?xf32>
+  // CHECK: return %[[RES]] : f32
+  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+  %1 = vector.extract %0[1] : f32 from vector<4xf32>
+  return %1 : f32
+}
+
+// CHECK-LABEL: @extract_load_scalar_dyn_off
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+func.func @extract_load_scalar_dyn_off(%arg0: memref<?xf32>, %arg1: index, %arg2: index) -> f32 {
+  // CHECK: %[[OFF:.*]] = arith.addi %[[ARG1]], %[[ARG2]] overflow<nsw> : index
+  // CHECK: %[[RES:.*]] = memref.load %[[ARG0]][%[[OFF]]] : memref<?xf32>
+  // CHECK: return %[[RES]] : f32
+  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+  %1 = vector.extract %0[%arg2] : f32 from vector<4xf32>
+  return %1 : f32
+}
+
+// CHECK-LABEL: @extract_load_vec
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+func.func @extract_load_vec(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index) -> vector<4xf32> {
+  // CHECK: %[[C1:.*]] = arith.constant 1 : index
+  // CHECK: %[[OFF:.*]] = arith.addi %[[ARG1]], %[[C1]] overflow<nsw> : index
+  // CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[OFF]], %[[ARG2]]] : memref<?x?xf32>, vector<4xf32>
+  // CHECK: return %[[RES]] : vector<4xf32>
+  %0 = vector.load %arg0[%arg1, %arg2] : memref<?x?xf32>, vector<2x4xf32>
+  %1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
+  return %1 : vector<4xf32>
+}
+
+// CHECK-LABEL: @extract_load_scalar_high_rank
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+func.func @extract_load_scalar_high_rank(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index) -> f32 {
+  // CHECK: %[[C1:.*]] = arith.constant 1 : index
+  // CHECK: %[[OFF:.*]] = arith.addi %[[ARG2]], %[[C1]] overflow<nsw> : index
+  // CHECK: %[[RES:.*]] = memref.load %[[ARG0]][%[[ARG1]], %[[OFF]]] : memref<?x?xf32>
+  // CHECK: return %[[RES]] : f32
+  %0 = vector.load %arg0[%arg1, %arg2] : memref<?x?xf32>, vector<4xf32>
+  %1 = vector.extract %0[1] : f32 from vector<4xf32>
+  return %1 : f32
+}
+
+// CHECK-LABEL: @extract_load_vec_high_rank
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
+func.func @extract_load_vec_high_rank(%arg0: memref<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index) -> vector<4xf32> {
+  // CHECK: %[[C1:.*]] = arith.constant 1 : index
+  // CHECK: %[[OFF:.*]] = arith.addi %[[ARG2]], %[[C1]] overflow<nsw> : index
+  // CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]], %[[OFF]], %[[ARG3]]] : memref<?x?x?xf32>, vector<4xf32>
+  // CHECK: return %[[RES]] : vector<4xf32>
+  %0 = vector.load %arg0[%arg1, %arg2, %arg3] : memref<?x?x?xf32>, vector<2x4xf32>
+  %1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
+  return %1 : vector<4xf32>
+}
+
+// CHECK-LABEL: @negative_load_scalar_from_vec_memref
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?xvector<4xf32>>, %[[ARG1:.*]]: index)
+func.func @negative_load_scalar_from_vec_memref(%arg0: memref<?xvector<4xf32>>, %arg1: index) -> f32 {
+  // CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xvector<4xf32>>, vector<4xf32>
+  // CHECK: %[[EXT:.*]] = vector.extract %[[RES]][0] : f32 from vector<4xf32>
+  // CHECK: return %[[EXT]] : f32
+  %0 = vector.load %arg0[%arg1] : memref<?xvector<4xf32>>, vector<4xf32>
+  %1 = vector.extract %0[0] : f32 from vector<4xf32>
+  return %1 : f32
+}
+
+// CHECK-LABEL: @negative_extract_load_no_single_use
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
+func.func @negative_extract_load_no_single_use(%arg0: memref<?xf32>, %arg1: index) -> (f32, vector<4xf32>) {
+  // CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<4xf32>
+  // CHECK: %[[EXT:.*]] = vector.extract %[[RES]][0] : f32 from vector<4xf32>
+  // CHECK: return %[[EXT]], %[[RES]] : f32, vector<4xf32>
+  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+  %1 = vector.extract %0[0] : f32 from vector<4xf32>
+  return %1, %0 : f32, vector<4xf32>
+}
+
+// CHECK-LABEL: @negative_load_scalable
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
+func.func @negative_load_scalable(%arg0: memref<?xf32>, %arg1: index) -> f32 {
+  // CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<[1]xf32>
+  // CHECK: %[[EXT:.*]] = vector.extract %[[RES]][0] : f32 from vector<[1]xf32>
+  // CHECK: return %[[EXT]] : f32
+  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<[1]xf32>
+  %1 = vector.extract %0[0] : f32 from vector<[1]xf32>
+  return %1 : f32
+}
+
+// CHECK-LABEL: @negative_extract_load_unsupported_ranks
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
+func.func @negative_extract_load_unsupported_ranks(%arg0: memref<?xf32>, %arg1: index) -> vector<4xf32> {
+  // CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<2x4xf32>
+  // CHECK: %[[EXT:.*]] = vector.extract %[[RES]][1] : vector<4xf32> from vector<2x4xf32>
+  // CHECK: return %[[EXT]] : vector<4xf32>
+  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<2x4xf32>
+  %1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
+  return %1 : vector<4xf32>
+}
+
+//-----------------------------------------------------------------------------
+// [Pattern: StoreFromSplat]
+//-----------------------------------------------------------------------------
+
+// CHECK-LABEL: @store_splat
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
+func.func @store_splat(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) {
+  // CHECK: memref.store %[[ARG2]], %[[ARG0]][%[[ARG1]]] : memref<?xf32>
+  %0 = vector.splat %arg2 : vector<1xf32>
+  vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
+  return
+}
+
+// CHECK-LABEL: @store_broadcast
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
+func.func @store_broadcast(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) {
+  // CHECK: memref.store %[[ARG2]], %[[ARG0]][%[[ARG1]]] : memref<?xf32>
+  %0 = vector.broadcast %arg2 : f32 to vector<1xf32>
+  vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
+  return
+}
+
+// CHECK-LABEL: @store_broadcast_1d_2d
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: vector<1xf32>)
+func.func @store_broadcast_1d_2d(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index, %arg3: vector<1xf32>) {
+  // CHECK: vector.store %[[ARG3]], %[[ARG0]][%[[ARG1]], %[[ARG2]]] : memref<?x?xf32>, vector<1xf32>
+  %0 = vector.broadcast %arg3 : vector<1xf32> to vector<1x1xf32>
+  vector.store %0, %arg0[%arg1, %arg2] : memref<?x?xf32>, vector<1x1xf32>
+  return
+}
+
+// CHECK-LABEL: @negative_store_scalable
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
+func.func @negative_store_scalable(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) {
+  // CHECK: %[[RES:.*]] = vector.splat %[[ARG2]] : vector<[1]xf32>
+  // CHECK: vector.store %[[RES]], %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<[1]xf32>
+  %0 = vector.splat %arg2 : vector<[1]xf32>
+  vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<[1]xf32>
+  return
+}
+
+// CHECK-LABEL: @negative_store_vec_memref
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?xvector<1xf32>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
+func.func @negative_store_vec_memref(%arg0: memref<?xvector<1xf32>>, %arg1: index, %arg2: f32) {
+  // CHECK: %[[RES:.*]] = vector.splat %[[ARG2]] : vector<1xf32>
+  // CHECK: vector.store %[[RES]], %[[ARG0]][%[[ARG1]]] : memref<?xvector<1xf32>>, vector<1xf32>
+  %0 = vector.splat %arg2 : vector<1xf32>
+  vector.store %0, %arg0[%arg1] : memref<?xvector<1xf32>>, vector<1xf32>
+  return
+}
+
+// CHECK-LABEL: @negative_store_non_1
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
+func.func @negative_store_non_1(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) {
+  // CHECK: %[[RES:.*]] = vector.splat %[[ARG2]] : vector<4xf32>
+  // CHECK: vector.store %[[RES]], %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<4xf32>
+  %0 = vector.splat %arg2 : vector<4xf32>
+  vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+  return
+}
+
+// CHECK-LABEL: @negative_store_no_single_use
+// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
+func.func @negative_store_no_single_use(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) -> vector<1xf32> {
+  // CHECK: %[[RES:.*]] = vector.splat %[[ARG2]] : vector<1xf32>
+  // CHECK: vector.store %[[RES]], %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<1xf32>
+  // CHECK: return %[[RES:.*]] : vector<1xf32>
+  %0 = vector.splat %arg2 : vector<1xf32>
+  vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
+  return %0 : vector<1xf32>
+}
```
