@@ -74,6 +74,43 @@ func.func @gather_memref_1d_i32_index(%base: memref<?xf32>, %v: vector<2xi32>, %
74
74
return %0 : vector <2 x3 xf32 >
75
75
}
76
76
77
+ // CHECK-LABEL: @scalable_gather_memref_2d
78
+ // CHECK-SAME: %[[BASE:.*]]: memref<?x?xf32>,
79
+ // CHECK-SAME: %[[IDXVEC:.*]]: vector<2x[3]xindex>,
80
+ // CHECK-SAME: %[[MASK:.*]]: vector<2x[3]xi1>,
81
+ // CHECK-SAME: %[[PASS:.*]]: vector<2x[3]xf32>
82
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
83
+ // CHECK: %[[C1:.*]] = arith.constant 1 : index
84
+ // CHECK: %[[INIT:.*]] = arith.constant dense<0.000000e+00> : vector<2x[3]xf32>
85
+ // CHECK: %[[IDXVEC0:.*]] = vector.extract %[[IDXVEC]][0] : vector<[3]xindex> from vector<2x[3]xindex>
86
+ // CHECK: %[[MASK0:.*]] = vector.extract %[[MASK]][0] : vector<[3]xi1> from vector<2x[3]xi1>
87
+ // CHECK: %[[PASS0:.*]] = vector.extract %[[PASS]][0] : vector<[3]xf32> from vector<2x[3]xf32>
88
+ // CHECK: %[[GATHER0:.*]] = vector.gather %[[BASE]]{{\[}}%[[C0]], %[[C1]]] {{\[}}%[[IDXVEC0]]], %[[MASK0]], %[[PASS0]] : memref<?x?xf32>, vector<[3]xindex>, vector<[3]xi1>, vector<[3]xf32> into vector<[3]xf32>
89
+ // CHECK: %[[INS0:.*]] = vector.insert %[[GATHER0]], %[[INIT]] [0] : vector<[3]xf32> into vector<2x[3]xf32>
90
+ // CHECK: %[[IDXVEC1:.*]] = vector.extract %[[IDXVEC]][1] : vector<[3]xindex> from vector<2x[3]xindex>
91
+ // CHECK: %[[MASK1:.*]] = vector.extract %[[MASK]][1] : vector<[3]xi1> from vector<2x[3]xi1>
92
+ // CHECK: %[[PASS1:.*]] = vector.extract %[[PASS]][1] : vector<[3]xf32> from vector<2x[3]xf32>
93
+ // CHECK: %[[GATHER1:.*]] = vector.gather %[[BASE]]{{\[}}%[[C0]], %[[C1]]] {{\[}}%[[IDXVEC1]]], %[[MASK1]], %[[PASS1]] : memref<?x?xf32>, vector<[3]xindex>, vector<[3]xi1>, vector<[3]xf32> into vector<[3]xf32>
94
+ // CHECK: %[[INS1:.*]] = vector.insert %[[GATHER1]], %[[INS0]] [1] : vector<[3]xf32> into vector<2x[3]xf32>
95
+ // CHECK-NEXT: return %[[INS1]] : vector<2x[3]xf32>
96
+ // Test input: a single 2-D vector.gather whose leading dimension (2) is fixed
+ // and whose trailing dimension ([3]) is scalable. The expectations above
+ // require it to be rewritten as two gathers on vector<[3]xf32>, one per index
+ // of the leading dimension, each inserted into a 2x[3] result.
+ func.func @scalable_gather_memref_2d (%base: memref <?x?xf32 >, %v: vector <2 x[3 ]xindex >, %mask: vector <2 x[3 ]xi1 >, %pass_thru: vector <2 x[3 ]xf32 >) -> vector <2 x[3 ]xf32 > {
97
+ // %c0 and %c1 are the two base indices passed to vector.gather for %base.
+ %c0 = arith.constant 0 : index
98
+ %c1 = arith.constant 1 : index
99
+ %0 = vector.gather %base [%c0 , %c1 ][%v ], %mask , %pass_thru : memref <?x?xf32 >, vector <2 x[3 ]xindex >, vector <2 x[3 ]xi1 >, vector <2 x[3 ]xf32 > into vector <2 x[3 ]xf32 >
100
+ return %0 : vector <2 x[3 ]xf32 >
101
+ }
102
+
103
+ // CHECK-LABEL: @scalable_gather_cant_unroll
104
+ // CHECK-NOT: extract
105
+ // CHECK: vector.gather
106
+ // CHECK-NOT: extract
107
+ // Negative test input: here the leading dimension ([4]) is the scalable one
+ // and the trailing dimension (8) is fixed. The expectations above require the
+ // vector.gather to remain in the output with no extract ops before or after it.
+ // NOTE(review): presumably the leading dimension cannot be unrolled because
+ // its size is not a compile-time constant - confirm against the lowering pattern.
+ func.func @scalable_gather_cant_unroll (%base: memref <?x?xf32 >, %v: vector <[4 ]x8 xindex >, %mask: vector <[4 ]x8 xi1 >, %pass_thru: vector <[4 ]x8 xf32 >) -> vector <[4 ]x8 xf32 > {
108
+ // %c0 and %c1 are the two base indices passed to vector.gather for %base.
+ %c0 = arith.constant 0 : index
109
+ %c1 = arith.constant 1 : index
110
+ %0 = vector.gather %base [%c0 , %c1 ][%v ], %mask , %pass_thru : memref <?x?xf32 >, vector <[4 ]x8 xindex >, vector <[4 ]x8 xi1 >, vector <[4 ]x8 xf32 > into vector <[4 ]x8 xf32 >
111
+ return %0 : vector <[4 ]x8 xf32 >
112
+ }
113
+
77
114
// CHECK-LABEL: @gather_tensor_1d
78
115
// CHECK-SAME: ([[BASE:%.+]]: tensor<?xf32>, [[IDXVEC:%.+]]: vector<2xindex>, [[MASK:%.+]]: vector<2xi1>, [[PASS:%.+]]: vector<2xf32>)
79
116
// CHECK-DAG: [[M0:%.+]] = vector.extract [[MASK]][0] : i1 from vector<2xi1>
0 commit comments