@@ -151,3 +151,57 @@ func.func @gather_tensor_1d_none_set(%base: tensor<?xf32>, %v: vector<2xindex>,
   %0 = vector.gather %base[%c0][%v], %mask, %pass_thru : tensor<?xf32>, vector<2xindex>, vector<2xi1>, vector<2xf32> into vector<2xf32>
   return %0 : vector<2xf32>
 }
+
+// Check that vector.gather of a strided memref is replaced with a
+// vector.gather with indices encoding the original strides. Note that the
+// other patterns then lower the rewritten gather to scf.if + vector.load.
+#map = affine_map<()[s0] -> (s0 * 4096)>
+#map1 = affine_map<()[s0] -> (s0 * -4096 + 518400, 4096)>
+func.func @strided_gather(%M_in: memref<100x3xf32>, %M_out: memref<518400xf32>, %idxs: vector<4xindex>, %x: index, %y: index) {
+  %c0 = arith.constant 0 : index
+  %x_1 = affine.apply #map()[%x]
+  // Strided MemRef
+  %subview = memref.subview %M_in[0, 0] [100, 1] [1, 1] : memref<100x3xf32> to memref<100xf32, strided<[3]>>
+  %cst_0 = arith.constant dense<true> : vector<4xi1>
+  %cst = arith.constant dense<0.000000e+00> : vector<4xf32>
+  // Gather of a strided MemRef
+  %7 = vector.gather %subview[%c0] [%idxs], %cst_0, %cst : memref<100xf32, strided<[3]>>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>
+  %subview_1 = memref.subview %M_out[%x_1] [%y] [1] : memref<518400xf32> to memref<?xf32, strided<[1], offset: ?>>
+  vector.store %7, %subview_1[%c0] : memref<?xf32, strided<[1], offset: ?>>, vector<4xf32>
+  return
+}
+// CHECK-LABEL: func.func @strided_gather(
+// CHECK-SAME: %[[M_in:.*]]: memref<100x3xf32>,
+// CHECK-SAME: %[[M_out:.*]]: memref<518400xf32>,
+// CHECK-SAME: %[[IDXS:.*]]: vector<4xindex>,
+// CHECK-SAME: %[[VAL_4:.*]]: index,
+// CHECK-SAME: %[[VAL_5:.*]]: index) {
+// CHECK: %[[CST_3:.*]] = arith.constant dense<3> : vector<4xindex>
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<4xi1>
+
+// CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[M_in]] {{\[\[}}0, 1]] : memref<100x3xf32> into memref<300xf32>
+// CHECK: %[[NEW_IDXS:.*]] = arith.muli %[[IDXS]], %[[CST_3]] : vector<4xindex>
+
+// CHECK: %[[MASK_0:.*]] = vector.extract %[[MASK]][0] : i1 from vector<4xi1>
+// CHECK: %[[IDX_0:.*]] = vector.extract %[[NEW_IDXS]][0] : index from vector<4xindex>
+// CHECK: scf.if %[[MASK_0]] -> (vector<4xf32>)
+// CHECK: %[[M_0:.*]] = vector.load %[[COLLAPSED]]{{\[}}%[[IDX_0]]] : memref<300xf32>, vector<1xf32>
+// CHECK: %[[V_0:.*]] = vector.extract %[[M_0]][0] : f32 from vector<1xf32>
+
+// CHECK: %[[MASK_1:.*]] = vector.extract %[[MASK]][1] : i1 from vector<4xi1>
+// CHECK: %[[IDX_1:.*]] = vector.extract %[[NEW_IDXS]][1] : index from vector<4xindex>
+// CHECK: scf.if %[[MASK_1]] -> (vector<4xf32>)
+// CHECK: %[[M_1:.*]] = vector.load %[[COLLAPSED]]{{\[}}%[[IDX_1]]] : memref<300xf32>, vector<1xf32>
+// CHECK: %[[V_1:.*]] = vector.extract %[[M_1]][0] : f32 from vector<1xf32>
+
+// CHECK: %[[MASK_2:.*]] = vector.extract %[[MASK]][2] : i1 from vector<4xi1>
+// CHECK: %[[IDX_2:.*]] = vector.extract %[[NEW_IDXS]][2] : index from vector<4xindex>
+// CHECK: scf.if %[[MASK_2]] -> (vector<4xf32>)
+// CHECK: %[[M_2:.*]] = vector.load %[[COLLAPSED]][%[[IDX_2]]] : memref<300xf32>, vector<1xf32>
+// CHECK: %[[V_2:.*]] = vector.extract %[[M_2]][0] : f32 from vector<1xf32>
+
+// CHECK: %[[MASK_3:.*]] = vector.extract %[[MASK]][3] : i1 from vector<4xi1>
+// CHECK: %[[IDX_3:.*]] = vector.extract %[[NEW_IDXS]][3] : index from vector<4xindex>
+// CHECK: scf.if %[[MASK_3]] -> (vector<4xf32>)
+// CHECK: %[[M_3:.*]] = vector.load %[[COLLAPSED]]{{\[}}%[[IDX_3]]] : memref<300xf32>, vector<1xf32>
+// CHECK: %[[V_3:.*]] = vector.extract %[[M_3]][0] : f32 from vector<1xf32>
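Note on the rewrite exercised above: it has two stages. The gather from the stride-3 subview is first replaced by a gather from a collapsed, contiguous base whose indices are pre-multiplied by the original stride; the remaining lowering patterns then expand that gather into per-lane scf.if + vector.load, which is what the CHECK lines verify. Below is a minimal before/after sketch of the stride-encoding stage only, reusing the values from the test. The intermediate gather never appears in the CHECK output, and the SSA names (%g0, %collapsed, %stride, %new_idxs, %g1) are illustrative, not taken from the pass output.

// Before: gather from the stride-3 view produced by memref.subview.
%g0 = vector.gather %subview[%c0] [%idxs], %cst_0, %cst
    : memref<100xf32, strided<[3]>>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>

// After: gather from the collapsed contiguous base, with indices scaled by the original stride (3).
%collapsed = memref.collapse_shape %M_in [[0, 1]] : memref<100x3xf32> into memref<300xf32>
%stride = arith.constant dense<3> : vector<4xindex>
%new_idxs = arith.muli %idxs, %stride : vector<4xindex>
%g1 = vector.gather %collapsed[%c0] [%new_idxs], %cst_0, %cst
    : memref<300xf32>, vector<4xindex>, vector<4xi1>, vector<4xf32> into vector<4xf32>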