@@ -513,3 +513,192 @@ func.func @negative_extract_vec_fma(%arg0: vector<4xf32>, %arg1: vector<4xf32>,
513
513
%1 = vector.extract %0 [1 ] : f32 from vector <4 xf32 >
514
514
return %1 : f32
515
515
}
516
+
+//-----------------------------------------------------------------------------
+// [Pattern: ExtractOpFromLoad]
+//-----------------------------------------------------------------------------
+
+// Extracting element 0 of a loaded vector folds to a scalar memref.load at
+// the same indices.
+// CHECK-LABEL: @extract_load_scalar
+// CHECK-SAME:  (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
+func.func @extract_load_scalar(%arg0: memref<?xf32>, %arg1: index) -> f32 {
+  // CHECK: %[[RES:.*]] = memref.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>
+  // CHECK: return %[[RES]] : f32
+  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+  %1 = vector.extract %0[0] : f32 from vector<4xf32>
+  return %1 : f32
+}
+
+// A non-zero extract position becomes an nsw add on the load index.
+// CHECK-LABEL: @extract_load_scalar_non_zero_off
+// CHECK-SAME:  (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
+func.func @extract_load_scalar_non_zero_off(%arg0: memref<?xf32>, %arg1: index) -> f32 {
+  // CHECK: %[[C1:.*]] = arith.constant 1 : index
+  // CHECK: %[[OFF:.*]] = arith.addi %[[ARG1]], %[[C1]] overflow<nsw> : index
+  // CHECK: %[[RES:.*]] = memref.load %[[ARG0]][%[[OFF]]] : memref<?xf32>
+  // CHECK: return %[[RES]] : f32
+  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+  %1 = vector.extract %0[1] : f32 from vector<4xf32>
+  return %1 : f32
+}
+
+// A dynamic extract position is added to the load index the same way.
+// CHECK-LABEL: @extract_load_scalar_dyn_off
+// CHECK-SAME:  (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+func.func @extract_load_scalar_dyn_off(%arg0: memref<?xf32>, %arg1: index, %arg2: index) -> f32 {
+  // CHECK: %[[OFF:.*]] = arith.addi %[[ARG1]], %[[ARG2]] overflow<nsw> : index
+  // CHECK: %[[RES:.*]] = memref.load %[[ARG0]][%[[OFF]]] : memref<?xf32>
+  // CHECK: return %[[RES]] : f32
+  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+  %1 = vector.extract %0[%arg2] : f32 from vector<4xf32>
+  return %1 : f32
+}
+
+// Extracting an inner vector from a 2-D loaded vector narrows the load to the
+// inner vector type, offsetting the outer dimension's index.
+// CHECK-LABEL: @extract_load_vec
+// CHECK-SAME:  (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+func.func @extract_load_vec(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index) -> vector<4xf32> {
+  // CHECK: %[[C1:.*]] = arith.constant 1 : index
+  // CHECK: %[[OFF:.*]] = arith.addi %[[ARG1]], %[[C1]] overflow<nsw> : index
+  // CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[OFF]], %[[ARG2]]] : memref<?x?xf32>, vector<4xf32>
+  // CHECK: return %[[RES]] : vector<4xf32>
+  %0 = vector.load %arg0[%arg1, %arg2] : memref<?x?xf32>, vector<2x4xf32>
+  %1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
+  return %1 : vector<4xf32>
+}
+
+// On a rank-2 memref the extract offset applies to the innermost index only.
+// CHECK-LABEL: @extract_load_scalar_high_rank
+// CHECK-SAME:  (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
+func.func @extract_load_scalar_high_rank(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index) -> f32 {
+  // CHECK: %[[C1:.*]] = arith.constant 1 : index
+  // CHECK: %[[OFF:.*]] = arith.addi %[[ARG2]], %[[C1]] overflow<nsw> : index
+  // CHECK: %[[RES:.*]] = memref.load %[[ARG0]][%[[ARG1]], %[[OFF]]] : memref<?x?xf32>
+  // CHECK: return %[[RES]] : f32
+  %0 = vector.load %arg0[%arg1, %arg2] : memref<?x?xf32>, vector<4xf32>
+  %1 = vector.extract %0[1] : f32 from vector<4xf32>
+  return %1 : f32
+}
+
+// Rank-3 memref, 2-D loaded vector: the narrowed 1-D load offsets the middle
+// index (the dimension the outer vector dim maps to).
+// CHECK-LABEL: @extract_load_vec_high_rank
+// CHECK-SAME:  (%[[ARG0:.*]]: memref<?x?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
+func.func @extract_load_vec_high_rank(%arg0: memref<?x?x?xf32>, %arg1: index, %arg2: index, %arg3: index) -> vector<4xf32> {
+  // CHECK: %[[C1:.*]] = arith.constant 1 : index
+  // CHECK: %[[OFF:.*]] = arith.addi %[[ARG2]], %[[C1]] overflow<nsw> : index
+  // CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]], %[[OFF]], %[[ARG3]]] : memref<?x?x?xf32>, vector<4xf32>
+  // CHECK: return %[[RES]] : vector<4xf32>
+  %0 = vector.load %arg0[%arg1, %arg2, %arg3] : memref<?x?x?xf32>, vector<2x4xf32>
+  %1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
+  return %1 : vector<4xf32>
+}
+
+// Negative: the memref element type is itself a vector, so a scalar
+// memref.load cannot replace the vector.load — pattern must not fire.
+// CHECK-LABEL: @negative_load_scalar_from_vec_memref
+// CHECK-SAME:  (%[[ARG0:.*]]: memref<?xvector<4xf32>>, %[[ARG1:.*]]: index)
+func.func @negative_load_scalar_from_vec_memref(%arg0: memref<?xvector<4xf32>>, %arg1: index) -> f32 {
+  // CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xvector<4xf32>>, vector<4xf32>
+  // CHECK: %[[EXT:.*]] = vector.extract %[[RES]][0] : f32 from vector<4xf32>
+  // CHECK: return %[[EXT]] : f32
+  %0 = vector.load %arg0[%arg1] : memref<?xvector<4xf32>>, vector<4xf32>
+  %1 = vector.extract %0[0] : f32 from vector<4xf32>
+  return %1 : f32
+}
+
+// Negative: the loaded vector has a second use, so the load must be kept.
+// CHECK-LABEL: @negative_extract_load_no_single_use
+// CHECK-SAME:  (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
+func.func @negative_extract_load_no_single_use(%arg0: memref<?xf32>, %arg1: index) -> (f32, vector<4xf32>) {
+  // CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<4xf32>
+  // CHECK: %[[EXT:.*]] = vector.extract %[[RES]][0] : f32 from vector<4xf32>
+  // CHECK: return %[[EXT]], %[[RES]] : f32, vector<4xf32>
+  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+  %1 = vector.extract %0[0] : f32 from vector<4xf32>
+  return %1, %0 : f32, vector<4xf32>
+}
+
+// Negative: scalable vectors are not supported by the pattern.
+// CHECK-LABEL: @negative_load_scalable
+// CHECK-SAME:  (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
+func.func @negative_load_scalable(%arg0: memref<?xf32>, %arg1: index) -> f32 {
+  // CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<[1]xf32>
+  // CHECK: %[[EXT:.*]] = vector.extract %[[RES]][0] : f32 from vector<[1]xf32>
+  // CHECK: return %[[EXT]] : f32
+  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<[1]xf32>
+  %1 = vector.extract %0[0] : f32 from vector<[1]xf32>
+  return %1 : f32
+}
+
+// Negative: loading a 2-D vector from a rank-1 memref — the rank mismatch is
+// not handled, so the IR is left unchanged.
+// CHECK-LABEL: @negative_extract_load_unsupported_ranks
+// CHECK-SAME:  (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index)
+func.func @negative_extract_load_unsupported_ranks(%arg0: memref<?xf32>, %arg1: index) -> vector<4xf32> {
+  // CHECK: %[[RES:.*]] = vector.load %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<2x4xf32>
+  // CHECK: %[[EXT:.*]] = vector.extract %[[RES]][1] : vector<4xf32> from vector<2x4xf32>
+  // CHECK: return %[[EXT]] : vector<4xf32>
+  %0 = vector.load %arg0[%arg1] : memref<?xf32>, vector<2x4xf32>
+  %1 = vector.extract %0[1] : vector<4xf32> from vector<2x4xf32>
+  return %1 : vector<4xf32>
+}
+
+//-----------------------------------------------------------------------------
+// [Pattern: StoreFromSplat]
+//-----------------------------------------------------------------------------
+
+// Storing a single-element splat folds to a scalar memref.store.
+// CHECK-LABEL: @store_splat
+// CHECK-SAME:  (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
+func.func @store_splat(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) {
+  // CHECK: memref.store %[[ARG2]], %[[ARG0]][%[[ARG1]]] : memref<?xf32>
+  %0 = vector.splat %arg2 : vector<1xf32>
+  vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
+  return
+}
+
+// A scalar-to-vector<1> broadcast store folds the same way as a splat.
+// CHECK-LABEL: @store_broadcast
+// CHECK-SAME:  (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
+func.func @store_broadcast(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) {
+  // CHECK: memref.store %[[ARG2]], %[[ARG0]][%[[ARG1]]] : memref<?xf32>
+  %0 = vector.broadcast %arg2 : f32 to vector<1xf32>
+  vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
+  return
+}
+
+// Broadcasting vector<1xf32> to vector<1x1xf32> and storing collapses to a
+// 1-D vector.store of the source vector.
+// CHECK-LABEL: @store_broadcast_1d_2d
+// CHECK-SAME:  (%[[ARG0:.*]]: memref<?x?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: vector<1xf32>)
+func.func @store_broadcast_1d_2d(%arg0: memref<?x?xf32>, %arg1: index, %arg2: index, %arg3: vector<1xf32>) {
+  // CHECK: vector.store %[[ARG3]], %[[ARG0]][%[[ARG1]], %[[ARG2]]] : memref<?x?xf32>, vector<1xf32>
+  %0 = vector.broadcast %arg3 : vector<1xf32> to vector<1x1xf32>
+  vector.store %0, %arg0[%arg1, %arg2] : memref<?x?xf32>, vector<1x1xf32>
+  return
+}
+
+// Negative: scalable vector<[1]xf32> — the pattern must not fire.
+// CHECK-LABEL: @negative_store_scalable
+// CHECK-SAME:  (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
+func.func @negative_store_scalable(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) {
+  // CHECK: %[[RES:.*]] = vector.splat %[[ARG2]] : vector<[1]xf32>
+  // CHECK: vector.store %[[RES]], %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<[1]xf32>
+  %0 = vector.splat %arg2 : vector<[1]xf32>
+  vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<[1]xf32>
+  return
+}
+
+// Negative: memref of vectors — a scalar memref.store is not applicable.
+// CHECK-LABEL: @negative_store_vec_memref
+// CHECK-SAME:  (%[[ARG0:.*]]: memref<?xvector<1xf32>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
+func.func @negative_store_vec_memref(%arg0: memref<?xvector<1xf32>>, %arg1: index, %arg2: f32) {
+  // CHECK: %[[RES:.*]] = vector.splat %[[ARG2]] : vector<1xf32>
+  // CHECK: vector.store %[[RES]], %[[ARG0]][%[[ARG1]]] : memref<?xvector<1xf32>>, vector<1xf32>
+  %0 = vector.splat %arg2 : vector<1xf32>
+  vector.store %0, %arg0[%arg1] : memref<?xvector<1xf32>>, vector<1xf32>
+  return
+}
+
+// Negative: the stored vector has more than one element.
+// CHECK-LABEL: @negative_store_non_1
+// CHECK-SAME:  (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
+func.func @negative_store_non_1(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) {
+  // CHECK: %[[RES:.*]] = vector.splat %[[ARG2]] : vector<4xf32>
+  // CHECK: vector.store %[[RES]], %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<4xf32>
+  %0 = vector.splat %arg2 : vector<4xf32>
+  vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<4xf32>
+  return
+}
+
+// Negative: the splat result is also returned, so it cannot be folded away.
+// Note: the return CHECK reuses the %[[RES]] binding from the splat line
+// (re-capturing with %[[RES:.*]] would rebind instead of verifying).
+// CHECK-LABEL: @negative_store_no_single_use
+// CHECK-SAME:  (%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: f32)
+func.func @negative_store_no_single_use(%arg0: memref<?xf32>, %arg1: index, %arg2: f32) -> vector<1xf32> {
+  // CHECK: %[[RES:.*]] = vector.splat %[[ARG2]] : vector<1xf32>
+  // CHECK: vector.store %[[RES]], %[[ARG0]][%[[ARG1]]] : memref<?xf32>, vector<1xf32>
+  // CHECK: return %[[RES]] : vector<1xf32>
+  %0 = vector.splat %arg2 : vector<1xf32>
+  vector.store %0, %arg0[%arg1] : memref<?xf32>, vector<1xf32>
+  return %0 : vector<1xf32>
+}