@@ -228,6 +228,16 @@ func.func @broadcast_vector_extsi(%a : vector<4xi8>) -> vector<2x4xi32> {
228
228
229
229
// -----
230
230
231
+ func.func @broadcast_vector_extsi_scalable (%a : vector <[4 ]xi8 >) -> vector <2 x[4 ]xi32 > {
232
+ // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<[4]xi8> to vector<[4]xi32>
233
+ // CHECK: vector.broadcast %[[EXT:.+]] : vector<[4]xi32> to vector<2x[4]xi32>
234
+ %b = vector.broadcast %a : vector <[4 ]xi8 > to vector <2 x[4 ]xi8 >
235
+ %r = arith.extsi %b : vector <2 x[4 ]xi8 > to vector <2 x[4 ]xi32 >
236
+ return %r : vector <2 x[4 ]xi32 >
237
+ }
238
+
239
+ // -----
240
+
231
241
func.func @broadcast_scalar_extsi (%a : i8 ) -> vector <2 x4 xi32 > {
232
242
// CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : i8 to i32
233
243
// CHECK: vector.broadcast %[[EXT]] : i32 to vector<2x4xi32>
@@ -236,6 +246,16 @@ func.func @broadcast_scalar_extsi(%a : i8) -> vector<2x4xi32> {
236
246
return %r : vector <2 x4 xi32 >
237
247
}
238
248
249
+ // -----
250
+
251
+ func.func @broadcast_scalar_extsi_scalable (%a : i8 ) -> vector <2 x[4 ]xi32 > {
252
+ // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : i8 to i32
253
+ // CHECK: vector.broadcast %[[EXT]] : i32 to vector<2x[4]xi32>
254
+ %b = vector.broadcast %a : i8 to vector <2 x[4 ]xi8 >
255
+ %r = arith.extsi %b : vector <2 x[4 ]xi8 > to vector <2 x[4 ]xi32 >
256
+ return %r : vector <2 x[4 ]xi32 >
257
+ }
258
+
239
259
//===----------------------------------------------------------------------===//
240
260
// [Pattern: ReorderElementwiseOpsOnTranspose]
241
261
//===----------------------------------------------------------------------===//
@@ -250,6 +270,16 @@ func.func @transpose_extsi(%a : vector<4x2xi8>) -> vector<2x4xi32> {
250
270
251
271
// -----
252
272
273
+ func.func @transpose_extsi_scalable (%a : vector <[4 ]x2 xi8 >) -> vector <2 x[4 ]xi32 > {
274
+ // CHECK: %[[EXT:.+]] = arith.extsi %{{.+}} : vector<[4]x2xi8> to vector<[4]x2xi32>
275
+ // CHECK: vector.transpose %[[EXT]], [1, 0] : vector<[4]x2xi32> to vector<2x[4]xi32>
276
+ %b = vector.transpose %a , [1 , 0 ]: vector <[4 ]x2 xi8 > to vector <2 x[4 ]xi8 >
277
+ %r = arith.extsi %b : vector <2 x[4 ]xi8 > to vector <2 x[4 ]xi32 >
278
+ return %r : vector <2 x[4 ]xi32 >
279
+ }
280
+
281
+ // -----
282
+
253
283
// CHECK-LABEL: func @transpose_elementwise_same_type
254
284
// CHECK-SAME: (%[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
255
285
// CHECK: %[[ADD:.+]] = arith.addf %[[A]], %[[B]] : vector<4x2xf32>
@@ -265,6 +295,21 @@ func.func @transpose_elementwise_same_type(%a : vector<4x2xf32>, %b : vector<4x2
265
295
266
296
// -----
267
297
298
+ // CHECK-LABEL: func @transpose_elementwise_same_type_scalable
299
+ // CHECK-SAME: (%[[A:.+]]: vector<[4]x2xf32>, %[[B:.+]]: vector<[4]x2xf32>)
300
+ // CHECK: %[[ADD:.+]] = arith.addf %[[A]], %[[B]] : vector<[4]x2xf32>
301
+ // CHECK: %[[T:.+]] = vector.transpose %[[ADD]], [1, 0]
302
+ // CHECK: return %[[T]]
303
+
304
+ func.func @transpose_elementwise_same_type_scalable (%a : vector <[4 ]x2 xf32 >, %b : vector <[4 ]x2 xf32 >) -> vector <2 x[4 ]xf32 > {
305
+ %at = vector.transpose %a , [1 , 0 ]: vector <[4 ]x2 xf32 > to vector <2 x[4 ]xf32 >
306
+ %bt = vector.transpose %b , [1 , 0 ]: vector <[4 ]x2 xf32 > to vector <2 x[4 ]xf32 >
307
+ %r = arith.addf %at , %bt : vector <2 x[4 ]xf32 >
308
+ return %r : vector <2 x[4 ]xf32 >
309
+ }
310
+
311
+ // -----
312
+
268
313
// CHECK-LABEL: func @transpose_elementwise_diff_operand_types
269
314
// CHECK-SAME: (%[[COND:.+]]: vector<4x2xi1>, %[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
270
315
// CHECK: %[[S:.+]] = arith.select %[[COND]], %[[A]], %[[B]] : vector<4x2xi1>, vector<4x2xf32>
@@ -280,6 +325,21 @@ func.func @transpose_elementwise_diff_operand_types(%cond: vector<4x2xi1>, %a :
280
325
281
326
// -----
282
327
328
+ // CHECK-LABEL: func @transpose_elementwise_diff_operand_types_scalable
329
+ // CHECK-SAME: (%[[COND:.+]]: vector<[4]x2xi1>, %[[A:.+]]: vector<[4]x2xf32>, %[[B:.+]]: vector<[4]x2xf32>)
330
+ // CHECK: %[[S:.+]] = arith.select %[[COND]], %[[A]], %[[B]] : vector<[4]x2xi1>, vector<[4]x2xf32>
331
+ // CHECK: %[[T:.+]] = vector.transpose %[[S]], [1, 0] : vector<[4]x2xf32> to vector<2x[4]xf32>
332
+ // CHECK: return %[[T]]
333
+ func.func @transpose_elementwise_diff_operand_types_scalable (%cond: vector <[4 ]x2 xi1 >, %a : vector <[4 ]x2 xf32 >, %b : vector <[4 ]x2 xf32 >) -> vector <2 x[4 ]xf32 > {
334
+ %condt = vector.transpose %cond , [1 , 0 ]: vector <[4 ]x2 xi1 > to vector <2 x[4 ]xi1 >
335
+ %at = vector.transpose %a , [1 , 0 ]: vector <[4 ]x2 xf32 > to vector <2 x[4 ]xf32 >
336
+ %bt = vector.transpose %b , [1 , 0 ]: vector <[4 ]x2 xf32 > to vector <2 x[4 ]xf32 >
337
+ %r = arith.select %condt , %at , %bt : vector <2 x[4 ]xi1 >, vector <2 x[4 ]xf32 >
338
+ return %r : vector <2 x[4 ]xf32 >
339
+ }
340
+
341
+ // -----
342
+
283
343
// CHECK-LABEL: func @transpose_elementwise_diff_operand_result_type
284
344
// CHECK-SAME: (%[[A:.+]]: vector<4x2xf32>, %[[B:.+]]: vector<4x2xf32>)
285
345
// CHECK: %[[CMP:.+]] = arith.cmpf olt, %[[A]], %[[B]] : vector<4x2xf32>
@@ -294,6 +354,20 @@ func.func @transpose_elementwise_diff_operand_result_type(%a : vector<4x2xf32>,
294
354
295
355
// -----
296
356
357
+ // CHECK-LABEL: func @transpose_elementwise_diff_operand_result_type_scalable
358
+ // CHECK-SAME: (%[[A:.+]]: vector<[4]x2xf32>, %[[B:.+]]: vector<[4]x2xf32>)
359
+ // CHECK: %[[CMP:.+]] = arith.cmpf olt, %[[A]], %[[B]] : vector<[4]x2xf32>
360
+ // CHECK: %[[T:.+]] = vector.transpose %[[CMP]], [1, 0] : vector<[4]x2xi1> to vector<2x[4]xi1>
361
+ // CHECK: return %[[T]]
362
+ func.func @transpose_elementwise_diff_operand_result_type_scalable (%a : vector <[4 ]x2 xf32 >, %b : vector <[4 ]x2 xf32 >) -> vector <2 x[4 ]xi1 > {
363
+ %at = vector.transpose %a , [1 , 0 ]: vector <[4 ]x2 xf32 > to vector <2 x[4 ]xf32 >
364
+ %bt = vector.transpose %b , [1 , 0 ]: vector <[4 ]x2 xf32 > to vector <2 x[4 ]xf32 >
365
+ %r = arith.cmpf olt , %at , %bt : vector <2 x[4 ]xf32 >
366
+ return %r : vector <2 x[4 ]xi1 >
367
+ }
368
+
369
+ // -----
370
+
297
371
// CHECK-LABEL: func @transpose_elementwise_splat_constant
298
372
// CHECK-SAME: (%[[A:.+]]: vector<4x6x3x2xf32>)
299
373
// CHECK: %[[B:.+]] = arith.constant dense<5.000000e+00> : vector<4x6x3x2xf32>
@@ -310,6 +384,22 @@ func.func @transpose_elementwise_splat_constant(%a : vector<4x6x3x2xf32>) -> vec
310
384
311
385
// -----
312
386
387
+ // CHECK-LABEL: func @transpose_elementwise_splat_constant_scalable
388
+ // CHECK-SAME: (%[[A:.+]]: vector<[4]x6x3x2xf32>)
389
+ // CHECK: %[[B:.+]] = arith.constant dense<5.000000e+00> : vector<[4]x6x3x2xf32>
390
+ // CHECK: %[[ADD:.+]] = arith.addf %[[A]], %[[B]] : vector<[4]x6x3x2xf32>
391
+ // CHECK: %[[T:.+]] = vector.transpose %[[ADD]], [1, 0, 3, 2] : vector<[4]x6x3x2xf32> to vector<6x[4]x2x3xf32>
392
+ // CHECK: return %[[T:.+]] : vector<6x[4]x2x3xf32>
393
+
394
+ func.func @transpose_elementwise_splat_constant_scalable (%a : vector <[4 ]x6 x3 x2 xf32 >) -> vector <6 x[4 ]x2 x3 xf32 > {
395
+ %b = arith.constant dense <5.0 > : vector <6 x[4 ]x2 x3 xf32 >
396
+ %at = vector.transpose %a , [1 , 0 , 3 , 2 ]: vector <[4 ]x6 x3 x2 xf32 > to vector <6 x[4 ]x2 x3 xf32 >
397
+ %r = arith.addf %at , %b : vector <6 x[4 ]x2 x3 xf32 >
398
+ return %r : vector <6 x[4 ]x2 x3 xf32 >
399
+ }
400
+
401
+ // -----
402
+
313
403
// CHECK-LABEL: func @transpose_elementwise_diff_map
314
404
// CHECK: vector.transpose
315
405
// CHECK: vector.transpose
@@ -320,3 +410,16 @@ func.func @transpose_elementwise_diff_map(%a : vector<4x6x3x2xf32>, %b: vector<6
320
410
%r = arith.addf %at , %bt : vector <6 x4 x2 x3 xf32 >
321
411
return %r : vector <6 x4 x2 x3 xf32 >
322
412
}
413
+
414
+ // -----
415
+
416
+ // CHECK-LABEL: func @transpose_elementwise_diff_map_scalable
417
+ // CHECK: vector.transpose
418
+ // CHECK: vector.transpose
419
+ // CHECK: arith.addf
420
+ func.func @transpose_elementwise_diff_map_scalable (%a : vector <[4 ]x6 x3 x2 xf32 >, %b: vector <6 x2 x[4 ]x3 xf32 >) -> vector <6 x[4 ]x2 x3 xf32 > {
421
+ %at = vector.transpose %a , [1 , 0 , 3 , 2 ]: vector <[4 ]x6 x3 x2 xf32 > to vector <6 x[4 ]x2 x3 xf32 >
422
+ %bt = vector.transpose %b , [0 , 2 , 1 , 3 ]: vector <6 x2 x[4 ]x3 xf32 > to vector <6 x[4 ]x2 x3 xf32 >
423
+ %r = arith.addf %at , %bt : vector <6 x[4 ]x2 x3 xf32 >
424
+ return %r : vector <6 x[4 ]x2 x3 xf32 >
425
+ }
0 commit comments