@@ -438,6 +438,17 @@ func.func @extract_elementwise_scalar(%arg0: vector<4xf32>, %arg1: vector<4xf32>
438
438
return %1 : f32
439
439
}
440
440
441
+ // CHECK-LABEL: @extract_elementwise_arg_res_different_types
442
+ // CHECK-SAME: (%[[ARG0:.*]]: vector<4xindex>)
443
+ func.func @extract_elementwise_arg_res_different_types (%arg0: vector <4 xindex >) -> i64 {
444
+ // CHECK: %[[EXT:.*]] = vector.extract %[[ARG0]][1] : index from vector<4xindex>
445
+ // CHECK: %[[RES:.*]] = arith.index_cast %[[EXT]] : index to i64
446
+ // CHECK: return %[[RES]] : i64
447
+ %0 = arith.index_cast %arg0: vector <4 xindex > to vector <4 xi64 >
448
+ %1 = vector.extract %0 [1 ] : i64 from vector <4 xi64 >
449
+ return %1 : i64
450
+ }
451
+
441
452
// CHECK-LABEL: @extract_elementwise_vec
442
453
// CHECK-SAME: (%[[ARG0:.*]]: vector<2x4xf32>, %[[ARG1:.*]]: vector<2x4xf32>)
443
454
func.func @extract_elementwise_vec (%arg0: vector <2 x4 xf32 >, %arg1: vector <2 x4 xf32 >) -> vector <4 xf32 > {
@@ -461,3 +472,27 @@ func.func @extract_elementwise_no_single_use(%arg0: vector<4xf32>, %arg1: vector
461
472
%1 = vector.extract %0 [1 ] : f32 from vector <4 xf32 >
462
473
return %1 , %0 : f32 , vector <4 xf32 >
463
474
}
475
+
476
+ // CHECK-LABEL: @extract_elementwise_not_one_res
477
+ // CHECK-SAME: (%[[ARG0:.*]]: vector<4xi32>, %[[ARG1:.*]]: vector<4xi32>)
478
+ func.func @extract_elementwise_not_one_res (%arg0: vector <4 xi32 >, %arg1: vector <4 xi32 >) -> i32 {
479
+ // Do not propagate extract, as elementwise has more than 1 result.
480
+ // CHECK: %[[LOW:.*]], %[[HIGH:.*]] = arith.mulsi_extended %[[ARG0]], %[[ARG1]] : vector<4xi32>
481
+ // CHECK: %[[EXT:.*]] = vector.extract %[[LOW]][1] : i32 from vector<4xi32>
482
+ // CHECK: return %[[EXT]] : i32
483
+ %low , %hi = arith.mulsi_extended %arg0 , %arg1 : vector <4 xi32 >
484
+ %1 = vector.extract %low [1 ] : i32 from vector <4 xi32 >
485
+ return %1 : i32
486
+ }
487
+
488
+ // CHECK-LABEL: @extract_not_elementwise
489
+ // CHECK-SAME: (%[[ARG0:.*]]: vector<4xi64>)
490
+ func.func @extract_not_elementwise (%arg0: vector <4 xi64 >) -> i64 {
491
+ // `test.increment` is not an elemewise op.
492
+ // CHECK: %[[INC:.*]] = test.increment %[[ARG0]] : vector<4xi64>
493
+ // CHECK: %[[RES:.*]] = vector.extract %[[INC]][1] : i64 from vector<4xi64>
494
+ // CHECK: return %[[RES]] : i64
495
+ %0 = test.increment %arg0: vector <4 xi64 >
496
+ %1 = vector.extract %0 [1 ] : i64 from vector <4 xi64 >
497
+ return %1 : i64
498
+ }
0 commit comments