Skip to content

Commit 9db316c

Browse files
committed
[mlir][vector] Add tests for populateSinkVectorBroadcastPatterns (1/n)
Adds tests for scalable vectors in: * sink-vector-broadcast.mlir This test file excercises patterns grouped under `populateSinkVectorBroadcastPatterns`, which includes: * `ReorderElementwiseOpsOnBroadcast`, * `ReorderCastOpsOnBroadcast`. Right now there are only tests for the former. However, I've noticed that "vector-reduce-to-contract.mlir" contains tests for the latter and I've left a few TODOs to group these tests back together in one file. Additionally, added some helpful `notifyMatchFailure` messages in `ReorderElementwiseOpsOnBroadcast`.
1 parent dd094b2 commit 9db316c

File tree

3 files changed

+121
-26
lines changed

3 files changed

+121
-26
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -979,15 +979,18 @@ struct ReorderElementwiseOpsOnBroadcast final
979979
if (!llvm::isa<ShapedType>(op->getResults()[0].getType()))
980980
return failure();
981981
if (!OpTrait::hasElementwiseMappableTraits(op))
982+
return rewriter.notifyMatchFailure(
983+
op, "Op doesn't have ElementwiseMappableTraits");
984+
if (op->getNumOperands() == 0)
982985
return failure();
983-
if (op->getNumOperands() == 0 ||
984-
op->getResults()[0].getType() != op->getOperand(0).getType()) {
985-
return failure();
986-
}
987-
// Avoid operations that only accept vector types, since broadcast
988-
// source might be scalar types.
986+
if (op->getResults()[0].getType() != op->getOperand(0).getType())
987+
return rewriter.notifyMatchFailure(op,
988+
"result and operand type mismatch");
989989
if (isa<vector::FMAOp>(op)) {
990-
return failure();
990+
return rewriter.notifyMatchFailure(
991+
op,
992+
"Op only accepts vector types - not supported as broadcast source "
993+
"might be a scalar");
991994
}
992995

993996
// Get the type of the lhs operand

mlir/test/Dialect/Vector/sink-vector-broadcast.mlir

Lines changed: 101 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,32 @@
11
// RUN: mlir-opt %s -test-sink-vector-broadcast -split-input-file | FileCheck %s
22

3+
//-----------------------------------------------------------------------------
4+
// [Pattern: ReorderElementwiseOpsOnBroadcast]
5+
//-----------------------------------------------------------------------------
6+
37
// CHECK-LABEL: func.func @broadcast_scalar_with_bcast(
48
// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index) -> vector<1x4xindex> {
59
// CHECK: %[[ADD:.*]] = arith.addi %[[ARG_0]], %[[ARG_1]] : index
610
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x4xindex>
711
// CHECK: return %[[BCAST]] : vector<1x4xindex>
812

9-
func.func @broadcast_scalar_with_bcast( %arg1: index, %arg2: index) -> vector<1x4xindex> {
13+
func.func @broadcast_scalar_with_bcast(%arg1: index, %arg2: index) -> vector<1x4xindex> {
1014
%0 = vector.broadcast %arg1 : index to vector<1x4xindex>
1115
%1 = vector.broadcast %arg2 : index to vector<1x4xindex>
12-
%2 = arith.addi %0, %1 : vector<1x4xindex>
13-
return %2 : vector<1x4xindex>
16+
%2 = arith.addi %0, %1 : vector<1x4xindex> return %2 : vector<1x4xindex>
17+
}
18+
19+
// CHECK-LABEL: func.func @broadcast_scalar_with_bcast_scalable(
20+
// CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index) -> vector<1x[4]xindex> {
21+
// CHECK: %[[ADD:.*]] = arith.addi %[[ARG_0]], %[[ARG_1]] : index
22+
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x[4]xindex>
23+
// CHECK: return %[[BCAST]] : vector<1x[4]xindex>
24+
25+
func.func @broadcast_scalar_with_bcast_scalable(%arg1: index, %arg2: index) -> vector<1x[4]xindex> {
26+
%0 = vector.broadcast %arg1 : index to vector<1x[4]xindex>
27+
%1 = vector.broadcast %arg2 : index to vector<1x[4]xindex>
28+
%2 = arith.addi %0, %1 : vector<1x[4]xindex>
29+
return %2 : vector<1x[4]xindex>
1430
}
1531

1632
// -----
@@ -21,13 +37,26 @@ func.func @broadcast_scalar_with_bcast( %arg1: index, %arg2: index) -> vector<1x
2137
// CHECK: %[[ADD:.*]] = arith.addi %[[ARG1]], %[[ARG2]] : index
2238
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x4xindex>
2339
// CHECK: return %[[BCAST]] : vector<1x4xindex>
24-
func.func @broadcast_scalar_with_bcast_and_splat( %arg1: index, %arg2: index) -> vector<1x4xindex> {
40+
func.func @broadcast_scalar_with_bcast_and_splat(%arg1: index, %arg2: index) -> vector<1x4xindex> {
2541
%0 = vector.splat %arg1 : vector<1x4xindex>
2642
%1 = vector.broadcast %arg2 : index to vector<1x4xindex>
2743
%2 = arith.addi %0, %1 : vector<1x4xindex>
2844
return %2 : vector<1x4xindex>
2945
}
3046

47+
// CHECK-LABEL: func.func @broadcast_scalar_with_bcast_and_splat_scalable(
48+
// CHECK-SAME: %[[ARG1:.*]]: index,
49+
// CHECK-SAME: %[[ARG2:.*]]: index) -> vector<1x[4]xindex> {
50+
// CHECK: %[[ADD:.*]] = arith.addi %[[ARG1]], %[[ARG2]] : index
51+
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADD]] : index to vector<1x[4]xindex>
52+
// CHECK: return %[[BCAST]] : vector<1x[4]xindex>
53+
func.func @broadcast_scalar_with_bcast_and_splat_scalable(%arg1: index, %arg2: index) -> vector<1x[4]xindex> {
54+
%0 = vector.splat %arg1 : vector<1x[4]xindex>
55+
%1 = vector.broadcast %arg2 : index to vector<1x[4]xindex>
56+
%2 = arith.addi %0, %1 : vector<1x[4]xindex>
57+
return %2 : vector<1x[4]xindex>
58+
}
59+
3160
// -----
3261

3362
// CHECK-LABEL: func.func @broadcast_vector(
@@ -37,13 +66,27 @@ func.func @broadcast_scalar_with_bcast_and_splat( %arg1: index, %arg2: index) ->
3766
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADDF]] : vector<4xf32> to vector<3x4xf32>
3867
// CHECK: return %[[BCAST]] : vector<3x4xf32>
3968

40-
func.func @broadcast_vector( %arg1: vector<4xf32>, %arg2: vector<4xf32>) -> vector<3x4xf32> {
69+
func.func @broadcast_vector(%arg1: vector<4xf32>, %arg2: vector<4xf32>) -> vector<3x4xf32> {
4170
%arg1_bcast = vector.broadcast %arg1 : vector<4xf32> to vector<3x4xf32>
4271
%arg2_bcast = vector.broadcast %arg2 : vector<4xf32> to vector<3x4xf32>
4372
%2 = arith.addf %arg1_bcast, %arg2_bcast : vector<3x4xf32>
4473
return %2 : vector<3x4xf32>
4574
}
4675

76+
// CHECK-LABEL: func.func @broadcast_vector_scalable(
77+
// CHECK-SAME: %[[ARG_0:.*]]: vector<[4]xf32>,
78+
// CHECK-SAME: %[[ARG_1:.*]]: vector<[4]xf32>) -> vector<3x[4]xf32> {
79+
// CHECK: %[[ADDF:.*]] = arith.addf %[[ARG_0]], %[[ARG_1]] : vector<[4]xf32>
80+
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ADDF]] : vector<[4]xf32> to vector<3x[4]xf32>
81+
// CHECK: return %[[BCAST]] : vector<3x[4]xf32>
82+
83+
func.func @broadcast_vector_scalable(%arg1: vector<[4]xf32>, %arg2: vector<[4]xf32>) -> vector<3x[4]xf32> {
84+
%arg1_bcast = vector.broadcast %arg1 : vector<[4]xf32> to vector<3x[4]xf32>
85+
%arg2_bcast = vector.broadcast %arg2 : vector<[4]xf32> to vector<3x[4]xf32>
86+
%2 = arith.addf %arg1_bcast, %arg2_bcast : vector<3x[4]xf32>
87+
return %2 : vector<3x[4]xf32>
88+
}
89+
4790
// -----
4891

4992
// CHECK-LABEL: func.func @broadcast_scalar_and_vec(
@@ -53,13 +96,27 @@ func.func @broadcast_vector( %arg1: vector<4xf32>, %arg2: vector<4xf32>) -> vect
5396
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ARG2]] : vector<4xindex> to vector<1x4xindex>
5497
// CHECK: %[[ADD:.*]] = arith.addi %[[SPLAT]], %[[BCAST]] : vector<1x4xindex>
5598
// CHECK: return %[[ADD]] : vector<1x4xindex>
56-
func.func @broadcast_scalar_and_vec( %arg1: index, %arg2: vector<4xindex>) -> vector<1x4xindex> {
99+
func.func @broadcast_scalar_and_vec(%arg1: index, %arg2: vector<4xindex>) -> vector<1x4xindex> {
57100
%0 = vector.splat %arg1 : vector<1x4xindex>
58101
%1 = vector.broadcast %arg2 : vector<4xindex> to vector<1x4xindex>
59102
%2 = arith.addi %0, %1 : vector<1x4xindex>
60103
return %2 : vector<1x4xindex>
61104
}
62105

106+
// CHECK-LABEL: func.func @broadcast_scalar_and_vec_scalable(
107+
// CHECK-SAME: %[[ARG1:.*]]: index,
108+
// CHECK-SAME: %[[ARG2:.*]]: vector<[4]xindex>) -> vector<1x[4]xindex> {
109+
// CHECK: %[[SPLAT:.*]] = vector.splat %[[ARG1]] : vector<1x[4]xindex>
110+
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ARG2]] : vector<[4]xindex> to vector<1x[4]xindex>
111+
// CHECK: %[[ADD:.*]] = arith.addi %[[SPLAT]], %[[BCAST]] : vector<1x[4]xindex>
112+
// CHECK: return %[[ADD]] : vector<1x[4]xindex>
113+
func.func @broadcast_scalar_and_vec_scalable(%arg1: index, %arg2: vector<[4]xindex>) -> vector<1x[4]xindex> {
114+
%0 = vector.splat %arg1 : vector<1x[4]xindex>
115+
%1 = vector.broadcast %arg2 : vector<[4]xindex> to vector<1x[4]xindex>
116+
%2 = arith.addi %0, %1 : vector<1x[4]xindex>
117+
return %2 : vector<1x[4]xindex>
118+
}
119+
63120
// -----
64121

65122
// CHECK-LABEL: func.func @broadcast_vector_and_scalar(
@@ -69,12 +126,25 @@ func.func @broadcast_scalar_and_vec( %arg1: index, %arg2: vector<4xindex>) -> ve
69126
// CHECK: %[[ADD:.*]] = arith.addi %[[BCAST]], %[[ARG_1]] : vector<4xi32>
70127
// CHECK: return %[[ADD]] : vector<4xi32>
71128

72-
func.func @broadcast_vector_and_scalar( %arg1: i32, %arg2: vector<4xi32>) -> vector<4xi32> {
129+
func.func @broadcast_vector_and_scalar(%arg1: i32, %arg2: vector<4xi32>) -> vector<4xi32> {
73130
%arg1_bcast = vector.broadcast %arg1 : i32 to vector<4xi32>
74131
%2 = arith.addi %arg1_bcast, %arg2 : vector<4xi32>
75132
return %2 : vector<4xi32>
76133
}
77134

135+
// CHECK-LABEL: func.func @broadcast_vector_and_scalar_scalable(
136+
// CHECK-SAME: %[[ARG_0:.*]]: i32,
137+
// CHECK-SAME: %[[ARG_1:.*]]: vector<[4]xi32>) -> vector<[4]xi32> {
138+
// CHECK: %[[BCAST:.*]] = vector.broadcast %[[ARG_0]] : i32 to vector<[4]xi32>
139+
// CHECK: %[[ADD:.*]] = arith.addi %[[BCAST]], %[[ARG_1]] : vector<[4]xi32>
140+
// CHECK: return %[[ADD]] : vector<[4]xi32>
141+
142+
func.func @broadcast_vector_and_scalar_scalable(%arg1: i32, %arg2: vector<[4]xi32>) -> vector<[4]xi32> {
143+
%arg1_bcast = vector.broadcast %arg1 : i32 to vector<[4]xi32>
144+
%2 = arith.addi %arg1_bcast, %arg2 : vector<[4]xi32>
145+
return %2 : vector<[4]xi32>
146+
}
147+
78148
// -----
79149

80150
#matmat_accesses = [
@@ -87,40 +157,52 @@ func.func @broadcast_vector_and_scalar( %arg1: i32, %arg2: vector<4xi32>) -> vec
87157
iterator_types = ["parallel", "parallel", "reduction"]
88158
}
89159

90-
// CHECK-LABEL: func.func @broadcast_not_elementwise() -> vector<2x2xf32> {
91-
// CHECK-DAG: %[[VAL_0:.*]] = arith.constant dense<1.000000e+00> : vector<2x2xf32>
92-
// CHECK-DAG: %[[VAL_1:.*]] = arith.constant dense<2.000000e+00> : vector<2x2xf32>
93-
// CHECK-DAG: %[[VAL_2:.*]] = arith.constant dense<3.000000e+00> : vector<2x2xf32>
94-
// CHECK: %[[VAL_3:.*]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
95-
func.func @broadcast_not_elementwise() -> vector<2x2xf32> {
160+
// CHECK-LABEL: func.func @negative_not_elementwise
161+
// CHECK-DAG: %[[F1:.*]] = arith.constant dense<1.000000e+00> : vector<2x2xf32>
162+
// CHECK-DAG: %[[F2:.*]] = arith.constant dense<2.000000e+00> : vector<2x2xf32>
163+
// CHECK-DAG: %[[F3:.*]] = arith.constant dense<3.000000e+00> : vector<2x2xf32>
164+
// CHECK: %[[RES:.*]] = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[F1]], %[[F2]], %[[F3]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
165+
func.func @negative_not_elementwise() -> vector<2x2xf32> {
96166
%f1 = arith.constant 1.0: f32
97167
%f2 = arith.constant 2.0: f32
98168
%f3 = arith.constant 3.0: f32
99169

100170
%A = vector.broadcast %f1 : f32 to vector<2x2xf32>
101171
%B = vector.broadcast %f2 : f32 to vector<2x2xf32>
102172
%C = vector.broadcast %f3 : f32 to vector<2x2xf32>
103-
%mm1 = vector.contract #matmat_trait %A, %B, %C
173+
%res = vector.contract #matmat_trait %A, %B, %C
104174
: vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
105175

106-
return %mm1 : vector<2x2xf32>
176+
return %res : vector<2x2xf32>
107177
}
108178

109-
// CHECK-LABEL: func.func @dont_sink_cmp(
179+
// -----
180+
181+
// The source and the result for arith.cmp have different types - not supported
182+
183+
// CHECK-LABEL: func.func @negative_source_and_result_mismatch
110184
// CHECK: %[[BROADCAST:.+]] = vector.broadcast
111185
// CHECK: %[[RETURN:.+]] = arith.cmpf uno, %[[BROADCAST]], %[[BROADCAST]]
112186
// CHECK: return %[[RETURN]]
113-
func.func @dont_sink_cmp(%arg0 : f32, %arg1 : vector<1xf32>) -> vector<1xi1> {
187+
func.func @negative_source_and_result_mismatch(%arg0 : f32, %arg1 : vector<1xf32>) -> vector<1xi1> {
114188
%0 = vector.broadcast %arg0 : f32 to vector<1xf32>
115189
%1 = arith.cmpf uno, %0, %0 : vector<1xf32>
116190
return %1 : vector<1xi1>
117191
}
118192

119-
// CHECK-LABEL: func.func @dont_sink_fma(
193+
// -----
194+
195+
// vector.fma only supports vectors - currently it's not possible to replace this with e.g.:
196+
// %scalar_res = vector.fma %scalar_1, %scalar2
197+
// %vec_res = vector.broadcast %scalar_res
198+
//
199+
// TODO: It should be possible to support this case
200+
201+
// CHECK-LABEL: func.func @negative_op_only_supports_vectors
120202
// CHECK: %[[BROADCAST:.+]] = vector.broadcast
121203
// CHECK: %[[RESULT:.+]] = vector.fma %[[BROADCAST]]
122204
// CHECK: return %[[RESULT]]
123-
func.func @dont_sink_fma(%arg0 : f32) -> vector<1xf32> {
205+
func.func @negative_op_only_supports_vectors(%arg0 : f32) -> vector<1xf32> {
124206
%0 = vector.broadcast %arg0 : f32 to vector<1xf32>
125207
%1 = vector.fma %0, %0, %0 : vector<1xf32>
126208
return %1 : vector<1xf32>

mlir/test/Dialect/Vector/vector-reduce-to-contract.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,8 +246,12 @@ func.func @contract_broadcast_would_have_no_reduction_dim_pair(%arg0 : vector<1x
246246

247247

248248
//===----------------------------------------------------------------------===//
249+
// [Pattern: ReorderCastOpsOnBroadcast]
250+
//
249251
// Reorder casting ops and vector ops. The casting ops have almost identical
250252
// pattern, so only arith.extsi op is tested.
253+
//
254+
// TODO: Potential duplication with sink-vector-broadcast.mlir
251255
//===----------------------------------------------------------------------===//
252256

253257
// -----
@@ -272,6 +276,11 @@ func.func @broadcast_scalar_extsi(%a : i8) -> vector<2x4xi32> {
272276

273277
// -----
274278

279+
//===----------------------------------------------------------------------===//
280+
// [Pattern: ReorderElementwiseOpsOnTranspose]
281+
//
282+
// TODO: Potential duplication with sink-vector-broadcast.mlir
283+
//===----------------------------------------------------------------------===//
275284
func.func @transpose_extsi(%a : vector<4x2xi8>) -> vector<2x4xi32> {
276285
// CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<4x2xi8> to vector<4x2xi32>
277286
// CHECK: vector.transpose %[[EXT]], [1, 0] : vector<4x2xi32> to vector<2x4xi32>
@@ -282,6 +291,7 @@ func.func @transpose_extsi(%a : vector<4x2xi8>) -> vector<2x4xi32> {
282291

283292
//===----------------------------------------------------------------------===//
284293
// Reorder elementwise ops and vector ops.
294+
// TODO: Potential duplication with sink-vector-broadcast.mlir
285295
//===----------------------------------------------------------------------===//
286296

287297
// -----

0 commit comments

Comments
 (0)