Skip to content

Commit 85c5f4d

Browse files
GroverkssNoumanAmir657
authored andcommitted
[mlir][Vector] Fix vector.insert folder for scalar to 0-d inserts (llvm#113828)
The current vector.insert folder tries to replace a scalar with a 0-rank vector. This patch fixes this crash by not folding unless they types of the result and replacement are same.
1 parent 39d51e7 commit 85c5f4d

File tree

2 files changed

+41
-15
lines changed

2 files changed

+41
-15
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2951,11 +2951,11 @@ void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
29512951
InsertOpConstantFolder>(context);
29522952
}
29532953

2954-
// Eliminates insert operations that produce values identical to their source
2955-
// value. This happens when the source and destination vectors have identical
2956-
// sizes.
29572954
OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
2958-
if (getNumIndices() == 0)
2955+
// Fold "vector.insert %v, %dest [] : vector<2x2xf32> from vector<2x2xf32>" to
2956+
// %v. Note: Do not fold "vector.insert %v, %dest [] : f32 into vector<f32>"
2957+
// (type mismatch).
2958+
if (getNumIndices() == 0 && getSourceType() == getType())
29592959
return getSource();
29602960
return {};
29612961
}

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 37 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -800,6 +800,43 @@ func.func @fold_extract_shapecast_to_shapecast(%arg0 : vector<3x4xf32>) -> vecto
800800

801801
// -----
802802

803+
// CHECK-LABEL: func @extract_no_fold_scalar_to_0d(
804+
// CHECK-SAME: %[[v:.*]]: vector<f32>)
805+
// CHECK: %[[extract:.*]] = vector.extract %[[v]][] : f32 from vector<f32>
806+
// CHECK: return %[[extract]]
807+
func.func @extract_no_fold_scalar_to_0d(%v: vector<f32>) -> f32 {
808+
%0 = vector.extract %v[] : f32 from vector<f32>
809+
return %0 : f32
810+
}
811+
812+
// -----
813+
814+
// CHECK-LABEL: func @insert_fold_same_rank(
815+
// CHECK-SAME: %[[v:.*]]: vector<2x2xf32>)
816+
// CHECK: %[[CST:.+]] = arith.constant
817+
// CHECK-SAME: : vector<2x2xf32>
818+
// CHECK-NOT: vector.insert
819+
// CHECK: return %[[CST]]
820+
func.func @insert_fold_same_rank(%v: vector<2x2xf32>) -> vector<2x2xf32> {
821+
%cst = arith.constant dense<0.000000e+00> : vector<2x2xf32>
822+
%0 = vector.insert %cst, %v [] : vector<2x2xf32> into vector<2x2xf32>
823+
return %0 : vector<2x2xf32>
824+
}
825+
826+
// -----
827+
828+
// CHECK-LABEL: func @insert_no_fold_scalar_to_0d(
829+
// CHECK-SAME: %[[v:.*]]: vector<f32>)
830+
// CHECK: %[[extract:.*]] = vector.insert %{{.*}}, %[[v]] [] : f32 into vector<f32>
831+
// CHECK: return %[[extract]]
832+
func.func @insert_no_fold_scalar_to_0d(%v: vector<f32>) -> vector<f32> {
833+
%cst = arith.constant 0.000000e+00 : f32
834+
%0 = vector.insert %cst, %v [] : f32 into vector<f32>
835+
return %0 : vector<f32>
836+
}
837+
838+
// -----
839+
803840
// CHECK-LABEL: dont_fold_expand_collapse
804841
// CHECK: %[[A:.*]] = vector.shape_cast %{{.*}} : vector<1x1x64xf32> to vector<1x1x8x8xf32>
805842
// CHECK: %[[B:.*]] = vector.shape_cast %{{.*}} : vector<1x1x8x8xf32> to vector<8x8xf32>
@@ -2606,17 +2643,6 @@ func.func @rank_1_shuffle_to_interleave(%arg0: vector<6xi32>, %arg1: vector<6xi3
26062643

26072644
// -----
26082645

2609-
// CHECK-LABEL: func @extract_from_0d_regression(
2610-
// CHECK-SAME: %[[v:.*]]: vector<f32>)
2611-
// CHECK: %[[extract:.*]] = vector.extract %[[v]][] : f32 from vector<f32>
2612-
// CHECK: return %[[extract]]
2613-
func.func @extract_from_0d_regression(%v: vector<f32>) -> f32 {
2614-
%0 = vector.extract %v[] : f32 from vector<f32>
2615-
return %0 : f32
2616-
}
2617-
2618-
// -----
2619-
26202646
// CHECK-LABEL: func @extract_from_0d_splat_broadcast_regression(
26212647
// CHECK-SAME: %[[a:.*]]: f32, %[[b:.*]]: vector<f32>, %[[c:.*]]: vector<2xf32>)
26222648
func.func @extract_from_0d_splat_broadcast_regression(%a: f32, %b: vector<f32>, %c: vector<2xf32>) -> (f32, f32, f32, f32, f32, vector<6x7xf32>, vector<3xf32>) {

0 commit comments

Comments
 (0)