@@ -595,3 +595,60 @@ module attributes {transform.with_named_sequence} {
595
595
transform.yield
596
596
}
597
597
}
598
+
599
+
600
+ // -----
601
+ // Vectorizer input: each element of the 1x1x4 result is read from the 30x1 constant %cst at [linalg.index 0, %c0].
602
+ func.func @vectorize_scalar_broadcast_column_tensor(%in: tensor<1x1x4xi32>) -> tensor<1x1x4xi32> {
603
+ %c4 = arith.constant 4 : index // not referenced by any op below; the CHECK lines still expect it in the output
604
+ %c0 = arith.constant 0 : index
605
+ %cst = arith.constant dense<[[0], [0], [1], [1], [2], [2], [3], [3], [4], [4], [5], [5], [6], [6], [7], [7], [8], [8], [9], [9], [10], [10], [11], [11], [12], [12], [13], [13], [14], [14]]> : tensor<30x1xi32>
606
+ // Identity indexing map over three parallel loops; the only operand is the output tensor %in.
607
+ %out = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} outs(%in : tensor<1x1x4xi32>) {
608
+ ^bb0(%out: i32):
609
+ %8 = linalg.index 0 : index // unused duplicate of %idx_0; NOTE(review): presumably kept because the CHECK lines expect two vector.step sequences - confirm before removing
610
+ %idx_0 = linalg.index 0 : index
611
+ %extracted = tensor.extract %cst[%idx_0, %c0] : tensor<30x1xi32>
612
+ linalg.yield %extracted : i32
613
+ } -> tensor<1x1x4xi32>
614
+
615
+ return %out : tensor<1x1x4xi32>
616
+ }
617
+ // CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1) -> (0, 0, 0)>
618
+ // CHECK-LABEL: func.func @vectorize_scalar_broadcast_column_tensor(
619
+ // CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x4xi32>) -> tensor<1x1x4xi32> {
620
+ // CHECK: %[[VAL_1:.*]] = arith.constant 4 : index
621
+ // CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
622
+ // CHECK: %[[VAL_3:.*]] = arith.constant dense<{{\[\[}}0], [0], [1], [1], [2], [2], [3], [3], [4], [4], [5], [5], [6], [6], [7], [7], [8], [8], [9], [9], [10], [10], [11], [11], [12], [12], [13], [13], [14], [14]]> : tensor<30x1xi32>
623
+ // CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
624
+ // CHECK: %[[VAL_5:.*]] = arith.constant 1 : index
625
+ // CHECK: %[[VAL_6:.*]] = arith.constant 4 : index
626
+ // CHECK: %[[VAL_7:.*]] = arith.constant 0 : index
627
+ // CHECK: %[[VAL_8:.*]] = arith.constant 0 : i32
628
+ // CHECK: %[[VAL_9:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_7]], %[[VAL_7]], %[[VAL_7]]], %[[VAL_8]] : tensor<1x1x4xi32>, vector<1x1x4xi32>
629
+ // CHECK: %[[VAL_10:.*]] = vector.step : vector<1xindex>
630
+ // CHECK: %[[VAL_11:.*]] = vector.broadcast %[[VAL_10]] : vector<1xindex> to vector<4x1x1xindex>
631
+ // CHECK: %[[VAL_12:.*]] = vector.transpose %[[VAL_11]], [2, 1, 0] : vector<4x1x1xindex> to vector<1x1x4xindex>
632
+ // CHECK: %[[VAL_13:.*]] = vector.step : vector<1xindex>
633
+ // CHECK: %[[VAL_14:.*]] = vector.broadcast %[[VAL_13]] : vector<1xindex> to vector<4x1x1xindex>
634
+ // CHECK: %[[VAL_15:.*]] = vector.transpose %[[VAL_14]], [2, 1, 0] : vector<4x1x1xindex> to vector<1x1x4xindex>
635
+ // CHECK: %[[VAL_16:.*]] = arith.constant dense<true> : vector<1x1x4xi1>
636
+ // CHECK: %[[VAL_17:.*]] = arith.constant dense<0> : vector<1x1x4xi32>
637
+ // CHECK: %[[VAL_18:.*]] = arith.constant 0 : index
638
+ // CHECK: %[[VAL_19:.*]] = arith.constant 0 : i32
639
+ // CHECK: %[[VAL_20:.*]] = vector.shape_cast %[[VAL_15]] : vector<1x1x4xindex> to vector<4xindex>
640
+ // CHECK: %[[VAL_21:.*]] = vector.extractelement %[[VAL_20]]{{\[}}%[[VAL_19]] : i32] : vector<4xindex>
641
+ // CHECK: %[[VAL_22:.*]] = arith.constant 0 : i32
642
+ // CHECK: %[[VAL_23:.*]] = vector.transfer_read %[[VAL_3]]{{\[}}%[[VAL_21]], %[[VAL_2]]], %[[VAL_22]] {in_bounds = [true, true, true], permutation_map = #[[$ATTR_1]]} : tensor<30x1xi32>, vector<1x1x4xi32>
643
+ // CHECK: %[[VAL_24:.*]] = arith.constant 0 : index
644
+ // CHECK: %[[VAL_25:.*]] = vector.transfer_write %[[VAL_23]], %[[VAL_0]]{{\[}}%[[VAL_24]], %[[VAL_24]], %[[VAL_24]]] : vector<1x1x4xi32>, tensor<1x1x4xi32>
645
+ // CHECK: return %[[VAL_25]] : tensor<1x1x4xi32>
646
+ // CHECK: }
647
+ // Transform script: match linalg.generic ops in the payload and vectorize them with vector sizes [1, 1, 4] and vectorize_nd_extract set.
648
+ module attributes {transform.with_named_sequence} {
649
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
650
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
651
+ transform.structured.vectorize %0 vector_sizes [1, 1, 4] {vectorize_nd_extract} : !transform.any_op
652
+ transform.yield
653
+ }
654
+ }
0 commit comments