Skip to content

Commit e9504c5

Browse files
authored
[mlir][vector] Add tests for populateSinkVectorOpsPatterns (2/N) (#122338)
Adds tests for scalable vectors in: * "vector-sink.mlir". This test file exercises patterns included in `populateSinkVectorOpsPatterns`: * `ReorderElementwiseOpsOnBroadcast`, * `ReorderCastOpsOnBroadcast`, * `ReorderElementwiseOpsOnTranspose`. This PR focuses on adding tests for the latter two patterns (`ReorderCastOpsOnBroadcast` and `ReorderElementwiseOpsOnTranspose`). Tests for `ReorderElementwiseOpsOnBroadcast` were added in #102286. Please note that in PR #102856, I renamed: * `populateSinkVectorBroadcastPatterns`, to * `populateSinkVectorOpsPatterns`.
1 parent c82a6a0 commit e9504c5

File tree

1 file changed

+103
-0
lines changed

1 file changed

+103
-0
lines changed

mlir/test/Dialect/Vector/vector-sink.mlir

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,16 @@ func.func @broadcast_vector_extsi(%a : vector<4xi8>) -> vector<2x4xi32> {
228228

229229
// -----
230230

231+
func.func @broadcast_vector_extsi_scalable(%a : vector<[4]xi8>) -> vector<2x[4]xi32> {
232+
// CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<[4]xi8> to vector<[4]xi32>
233+
// CHECK: vector.broadcast %[[EXT:.+]] : vector<[4]xi32> to vector<2x[4]xi32>
234+
%b = vector.broadcast %a : vector<[4]xi8> to vector<2x[4]xi8>
235+
%r = arith.extsi %b : vector<2x[4]xi8> to vector<2x[4]xi32>
236+
return %r : vector<2x[4]xi32>
237+
}
238+
239+
// -----
240+
231241
func.func @broadcast_scalar_extsi(%a : i8) -> vector<2x4xi32> {
232242
// CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : i8 to i32
233243
// CHECK: vector.broadcast %[[EXT]] : i32 to vector<2x4xi32>
@@ -236,6 +246,16 @@ func.func @broadcast_scalar_extsi(%a : i8) -> vector<2x4xi32> {
236246
return %r : vector<2x4xi32>
237247
}
238248

249+
// -----
250+
251+
func.func @broadcast_scalar_extsi_scalable(%a : i8) -> vector<2x[4]xi32> {
252+
// CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : i8 to i32
253+
// CHECK: vector.broadcast %[[EXT]] : i32 to vector<2x[4]xi32>
254+
%b = vector.broadcast %a : i8 to vector<2x[4]xi8>
255+
%r = arith.extsi %b : vector<2x[4]xi8> to vector<2x[4]xi32>
256+
return %r : vector<2x[4]xi32>
257+
}
258+
239259
//===----------------------------------------------------------------------===//
240260
// [Pattern: ReorderElementwiseOpsOnTranspose]
241261
//===----------------------------------------------------------------------===//
@@ -250,6 +270,16 @@ func.func @transpose_extsi(%a : vector<4x2xi8>) -> vector<2x4xi32> {
250270

251271
// -----
252272

273+
func.func @transpose_extsi_scalable(%a : vector<[4]x2xi8>) -> vector<2x[4]xi32> {
274+
// CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<[4]x2xi8> to vector<[4]x2xi32>
275+
// CHECK: vector.transpose %[[EXT]], [1, 0] : vector<[4]x2xi32> to vector<2x[4]xi32>
276+
%b = vector.transpose %a, [1, 0]: vector<[4]x2xi8> to vector<2x[4]xi8>
277+
%r = arith.extsi %b : vector<2x[4]xi8> to vector<2x[4]xi32>
278+
return %r : vector<2x[4]xi32>
279+
}
280+
281+
// -----
282+
253283
// CHECK-LABEL: func @transpose_elementwise_same_type
254284
// CHECK-SAME: (%[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
255285
// CHECK: %[[ADD:.+]] = arith.addf %[[A]], %[[B]] : vector<4x2xf32>
@@ -265,6 +295,21 @@ func.func @transpose_elementwise_same_type(%a : vector<4x2xf32>, %b : vector<4x2
265295

266296
// -----
267297

298+
// CHECK-LABEL: func @transpose_elementwise_same_type_scalable
299+
// CHECK-SAME: (%[[A:.+]]: vector<[4]x2xf32>, %[[B:.+]]: vector<[4]x2xf32>)
300+
// CHECK: %[[ADD:.+]] = arith.addf %[[A]], %[[B]] : vector<[4]x2xf32>
301+
// CHECK: %[[T:.+]] = vector.transpose %[[ADD]], [1, 0]
302+
// CHECK: return %[[T]]
303+
304+
func.func @transpose_elementwise_same_type_scalable(%a : vector<[4]x2xf32>, %b : vector<[4]x2xf32>) -> vector<2x[4]xf32> {
305+
%at = vector.transpose %a, [1, 0]: vector<[4]x2xf32> to vector<2x[4]xf32>
306+
%bt = vector.transpose %b, [1, 0]: vector<[4]x2xf32> to vector<2x[4]xf32>
307+
%r = arith.addf %at, %bt : vector<2x[4]xf32>
308+
return %r : vector<2x[4]xf32>
309+
}
310+
311+
// -----
312+
268313
// CHECK-LABEL: func @transpose_elementwise_diff_operand_types
269314
// CHECK-SAME: (%[[COND:.+]]: vector<4x2xi1>, %[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
270315
// CHECK: %[[S:.+]] = arith.select %[[COND]], %[[A]], %[[B]] : vector<4x2xi1>, vector<4x2xf32>
@@ -280,6 +325,21 @@ func.func @transpose_elementwise_diff_operand_types(%cond: vector<4x2xi1>, %a :
280325

281326
// -----
282327

328+
// CHECK-LABEL: func @transpose_elementwise_diff_operand_types_scalable
329+
// CHECK-SAME: (%[[COND:.+]]: vector<[4]x2xi1>, %[[A:.+]]: vector<[4]x2xf32>, %[[B:.+]]: vector<[4]x2xf32>)
330+
// CHECK: %[[S:.+]] = arith.select %[[COND]], %[[A]], %[[B]] : vector<[4]x2xi1>, vector<[4]x2xf32>
331+
// CHECK: %[[T:.+]] = vector.transpose %[[S]], [1, 0] : vector<[4]x2xf32> to vector<2x[4]xf32>
332+
// CHECK: return %[[T]]
333+
func.func @transpose_elementwise_diff_operand_types_scalable(%cond: vector<[4]x2xi1>, %a : vector<[4]x2xf32>, %b : vector<[4]x2xf32>) -> vector<2x[4]xf32> {
334+
%condt = vector.transpose %cond, [1, 0]: vector<[4]x2xi1> to vector<2x[4]xi1>
335+
%at = vector.transpose %a, [1, 0]: vector<[4]x2xf32> to vector<2x[4]xf32>
336+
%bt = vector.transpose %b, [1, 0]: vector<[4]x2xf32> to vector<2x[4]xf32>
337+
%r = arith.select %condt, %at, %bt : vector<2x[4]xi1>, vector<2x[4]xf32>
338+
return %r : vector<2x[4]xf32>
339+
}
340+
341+
// -----
342+
283343
// CHECK-LABEL: func @transpose_elementwise_diff_operand_result_type
284344
// CHECK-SAME: (%[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
285345
// CHECK: %[[CMP:.+]] = arith.cmpf olt, %[[A]], %[[B]] : vector<4x2xf32>
@@ -294,6 +354,20 @@ func.func @transpose_elementwise_diff_operand_result_type(%a : vector<4x2xf32>,
294354

295355
// -----
296356

357+
// CHECK-LABEL: func @transpose_elementwise_diff_operand_result_type_scalable
358+
// CHECK-SAME: (%[[A:.+]]: vector<[4]x2xf32>, %[[B:.+]]: vector<[4]x2xf32>)
359+
// CHECK: %[[CMP:.+]] = arith.cmpf olt, %[[A]], %[[B]] : vector<[4]x2xf32>
360+
// CHECK: %[[T:.+]] = vector.transpose %[[CMP]], [1, 0] : vector<[4]x2xi1> to vector<2x[4]xi1>
361+
// CHECK: return %[[T]]
362+
func.func @transpose_elementwise_diff_operand_result_type_scalable(%a : vector<[4]x2xf32>, %b : vector<[4]x2xf32>) -> vector<2x[4]xi1> {
363+
%at = vector.transpose %a, [1, 0]: vector<[4]x2xf32> to vector<2x[4]xf32>
364+
%bt = vector.transpose %b, [1, 0]: vector<[4]x2xf32> to vector<2x[4]xf32>
365+
%r = arith.cmpf olt, %at, %bt : vector<2x[4]xf32>
366+
return %r : vector<2x[4]xi1>
367+
}
368+
369+
// -----
370+
297371
// CHECK-LABEL: func @transpose_elementwise_splat_constant
298372
// CHECK-SAME: (%[[A:.+]]: vector<4x6x3x2xf32>)
299373
// CHECK: %[[B:.+]] = arith.constant dense<5.000000e+00> : vector<4x6x3x2xf32>
@@ -310,6 +384,22 @@ func.func @transpose_elementwise_splat_constant(%a : vector<4x6x3x2xf32>) -> vec
310384

311385
// -----
312386

387+
// CHECK-LABEL: func @transpose_elementwise_splat_constant_scalable
388+
// CHECK-SAME: (%[[A:.+]]: vector<[4]x6x3x2xf32>)
389+
// CHECK: %[[B:.+]] = arith.constant dense<5.000000e+00> : vector<[4]x6x3x2xf32>
390+
// CHECK: %[[ADD:.+]] = arith.addf %[[A]], %[[B]] : vector<[4]x6x3x2xf32>
391+
// CHECK: %[[T:.+]] = vector.transpose %[[ADD]], [1, 0, 3, 2] : vector<[4]x6x3x2xf32> to vector<6x[4]x2x3xf32>
392+
// CHECK: return %[[T:.+]] : vector<6x[4]x2x3xf32>
393+
394+
func.func @transpose_elementwise_splat_constant_scalable(%a : vector<[4]x6x3x2xf32>) -> vector<6x[4]x2x3xf32> {
395+
%b = arith.constant dense<5.0> : vector<6x[4]x2x3xf32>
396+
%at = vector.transpose %a, [1, 0, 3, 2]: vector<[4]x6x3x2xf32> to vector<6x[4]x2x3xf32>
397+
%r = arith.addf %at, %b : vector<6x[4]x2x3xf32>
398+
return %r : vector<6x[4]x2x3xf32>
399+
}
400+
401+
// -----
402+
313403
// CHECK-LABEL: func @transpose_elementwise_diff_map
314404
// CHECK: vector.transpose
315405
// CHECK: vector.transpose
@@ -320,3 +410,16 @@ func.func @transpose_elementwise_diff_map(%a : vector<4x6x3x2xf32>, %b: vector<6
320410
%r = arith.addf %at, %bt : vector<6x4x2x3xf32>
321411
return %r : vector<6x4x2x3xf32>
322412
}
413+
414+
// -----
415+
416+
// CHECK-LABEL: func @transpose_elementwise_diff_map_scalable
417+
// CHECK: vector.transpose
418+
// CHECK: vector.transpose
419+
// CHECK: arith.addf
420+
func.func @transpose_elementwise_diff_map_scalable(%a : vector<[4]x6x3x2xf32>, %b: vector<6x2x[4]x3xf32>) -> vector<6x[4]x2x3xf32> {
421+
%at = vector.transpose %a, [1, 0, 3, 2]: vector<[4]x6x3x2xf32> to vector<6x[4]x2x3xf32>
422+
%bt = vector.transpose %b, [0, 2, 1, 3]: vector<6x2x[4]x3xf32> to vector<6x[4]x2x3xf32>
423+
%r = arith.addf %at, %bt : vector<6x[4]x2x3xf32>
424+
return %r : vector<6x[4]x2x3xf32>
425+
}

0 commit comments

Comments
 (0)