@@ -565,3 +565,112 @@ module attributes {transform.with_named_sequence} {
565
565
transform.yield
566
566
}
567
567
}
568
+
569
+ // -----
570
+
571
+ // Test hoisting of vector.extract/vector.broadcast pairs
572
+
573
+ // CHECK-LABEL: func.func @hoist_vector_broadcasts
574
+ // CHECK-SAME: (%{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[VEC:.+]]: vector<3x4xf32>) -> vector<3x4xf32> {
575
+ // CHECK: %[[EXTRACT:.+]] = vector.extract %[[VEC]][0] : vector<4xf32> from vector<3x4xf32>
576
+ // CHECK-NEXT: %[[LOOP:.+]] = scf.for {{.*}} {
577
+ // CHECK-NEXT: %[[USE:.+]] = "some_use"({{.*}}) : (vector<4xf32>) -> vector<4xf32>
578
+ // CHECK-NEXT: scf.yield %[[USE]] : vector<4xf32>
579
+ // CHECK-NEXT: }
580
+ // CHECK-NEXT: %[[BCAST:.+]] = vector.broadcast %[[LOOP]] : vector<4xf32> to vector<3x4xf32>
581
+ // CHECK-NEXT: return %[[BCAST]] : vector<3x4xf32>
582
+
583
+ func.func @hoist_vector_broadcasts (%lb : index , %ub : index , %step : index , %vec : vector <3 x4 xf32 >) -> vector <3 x4 xf32 > {
584
+ %bcast_vec = scf.for %arg0 = %lb to %ub step %step iter_args (%iarg = %vec ) -> vector <3 x4 xf32 > {
585
+ %extract = vector.extract %iarg [0 ] : vector <4 xf32 > from vector <3 x4 xf32 >
586
+ %use = " some_use" (%extract ) : (vector <4 xf32 >) -> vector <4 xf32 >
587
+ %broadcast = vector.broadcast %use : vector <4 xf32 > to vector <3 x4 xf32 >
588
+ scf.yield %broadcast : vector <3 x4 xf32 >
589
+ }
590
+ return %bcast_vec : vector <3 x4 xf32 >
591
+ }
592
+
593
+ module attributes {transform.with_named_sequence } {
594
+ transform.named_sequence @__transform_main (%arg1: !transform.any_op {transform.readonly }) {
595
+ %0 = transform.structured.match ops {[" func.func" ]} in %arg1
596
+ : (!transform.any_op ) -> !transform.any_op
597
+ transform.structured.hoist_redundant_vector_broadcasts %0
598
+ : (!transform.any_op ) -> !transform.any_op
599
+ transform.yield
600
+ }
601
+ }
602
+
603
+ // -----
604
+
605
+ // Test hoisting of vector.extract/vector.broadcast pairs with dynamic position
606
+
607
+ // CHECK-LABEL: func.func @hoist_vector_broadcasts
608
+ // CHECK-SAME: (%{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[VEC:.+]]: vector<3x4xf32>, %[[POS:.+]]: index) -> vector<3x4xf32> {
609
+ // CHECK: %[[EXTRACT:.+]] = vector.extract %[[VEC]][%[[POS]]] : vector<4xf32> from vector<3x4xf32>
610
+ // CHECK-NEXT: %[[LOOP:.+]] = scf.for {{.*}} {
611
+ // CHECK-NEXT: %[[USE:.+]] = "some_use"({{.*}}) : (vector<4xf32>) -> vector<4xf32>
612
+ // CHECK-NEXT: scf.yield %[[USE]] : vector<4xf32>
613
+ // CHECK-NEXT: }
614
+ // CHECK-NEXT: %[[BCAST:.+]] = vector.broadcast %[[LOOP]] : vector<4xf32> to vector<3x4xf32>
615
+ // CHECK-NEXT: return %[[BCAST]] : vector<3x4xf32>
616
+
617
+ func.func @hoist_vector_broadcasts_dynamic (%lb : index , %ub : index , %step : index , %vec : vector <3 x4 xf32 >, %pos: index ) -> vector <3 x4 xf32 > {
618
+ %bcast_vec = scf.for %arg0 = %lb to %ub step %step iter_args (%iarg = %vec ) -> vector <3 x4 xf32 > {
619
+ %extract = vector.extract %iarg [%pos ] : vector <4 xf32 > from vector <3 x4 xf32 >
620
+ %use = " some_use" (%extract ) : (vector <4 xf32 >) -> vector <4 xf32 >
621
+ %broadcast = vector.broadcast %use : vector <4 xf32 > to vector <3 x4 xf32 >
622
+ scf.yield %broadcast : vector <3 x4 xf32 >
623
+ }
624
+ return %bcast_vec : vector <3 x4 xf32 >
625
+ }
626
+
627
+ module attributes {transform.with_named_sequence } {
628
+ transform.named_sequence @__transform_main (%arg1: !transform.any_op {transform.readonly }) {
629
+ %0 = transform.structured.match ops {[" func.func" ]} in %arg1
630
+ : (!transform.any_op ) -> !transform.any_op
631
+ transform.structured.hoist_redundant_vector_broadcasts %0
632
+ : (!transform.any_op ) -> !transform.any_op
633
+ transform.yield
634
+ }
635
+ }
636
+
637
+ // -----
638
+
639
+ // Test hoisting of vector.extract/vector.broadcast pairs with multiple iter_args
640
+
641
+ // CHECK-LABEL: func.func @hoist_vector_broadcasts_multiple
642
+ // CHECK-SAME: (%{{.+}}: index, %{{.+}}: index, %{{.+}}: index, %[[VEC1:.+]]: vector<3x4xf32>,
643
+ // CHECK-SAME: %[[VEC2:.+]]: vector<3x5xf32>) -> (vector<3x4xf32>, vector<3x5xf32>) {
644
+ // CHECK-DAG: %[[EXTRACT1:.+]] = vector.extract %[[VEC1]][0] : vector<4xf32> from vector<3x4xf32>
645
+ // CHECK-DAG: %[[EXTRACT2:.+]] = vector.extract %[[VEC2]][1] : vector<5xf32> from vector<3x5xf32>
646
+ // CHECK-NEXT: %[[LOOP:.+]]:2 = scf.for {{.*}} {
647
+ // CHECK-DAG: %[[USE1:.+]] = "some_use1"({{.*}}) : (vector<4xf32>) -> vector<4xf32>
648
+ // CHECK-DAG: %[[USE2:.+]] = "some_use2"({{.*}}) : (vector<5xf32>) -> vector<5xf32>
649
+ // CHECK-NEXT: scf.yield %[[USE1]], %[[USE2]] : vector<4xf32>, vector<5xf32>
650
+ // CHECK-NEXT: }
651
+ // CHECK-DAG: %[[BCAST1:.+]] = vector.broadcast %[[LOOP]]#0 : vector<4xf32> to vector<3x4xf32>
652
+ // CHECK-DAG: %[[BCAST2:.+]] = vector.broadcast %[[LOOP]]#1 : vector<5xf32> to vector<3x5xf32>
653
+ // CHECK-NEXT: return %[[BCAST1]], %[[BCAST2]] : vector<3x4xf32>, vector<3x5xf32>
654
+
655
+ func.func @hoist_vector_broadcasts_multiple (%lb : index , %ub : index , %step : index , %vec1 : vector <3 x4 xf32 >, %vec2 : vector <3 x5 xf32 >) -> (vector <3 x4 xf32 >, vector <3 x5 xf32 >) {
656
+ %bcast_vec:2 = scf.for %arg0 = %lb to %ub step %step iter_args (%iarg = %vec1 , %iarg2 = %vec2 ) -> (vector <3 x4 xf32 >, vector <3 x5 xf32 >) {
657
+ %extract1 = vector.extract %iarg [0 ] : vector <4 xf32 > from vector <3 x4 xf32 >
658
+ %extract2 = vector.extract %iarg2 [1 ] : vector <5 xf32 > from vector <3 x5 xf32 >
659
+ %use1 = " some_use1" (%extract1 ) : (vector <4 xf32 >) -> vector <4 xf32 >
660
+ %use2 = " some_use2" (%extract2 ) : (vector <5 xf32 >) -> vector <5 xf32 >
661
+ %broadcast1 = vector.broadcast %use1 : vector <4 xf32 > to vector <3 x4 xf32 >
662
+ %broadcast2 = vector.broadcast %use2 : vector <5 xf32 > to vector <3 x5 xf32 >
663
+ scf.yield %broadcast1 , %broadcast2 : vector <3 x4 xf32 >,vector <3 x5 xf32 >
664
+ }
665
+ return %bcast_vec#0 , %bcast_vec#1 : vector <3 x4 xf32 >, vector <3 x5 xf32 >
666
+ }
667
+
668
+ module attributes {transform.with_named_sequence } {
669
+ transform.named_sequence @__transform_main (%arg1: !transform.any_op {transform.readonly }) {
670
+ %0 = transform.structured.match ops {[" func.func" ]} in %arg1
671
+ : (!transform.any_op ) -> !transform.any_op
672
+ transform.structured.hoist_redundant_vector_broadcasts %0
673
+ : (!transform.any_op ) -> !transform.any_op
674
+ transform.yield
675
+ }
676
+ }
0 commit comments