@@ -109,90 +109,6 @@ struct LinearizeVectorizable final
109
109
}
110
110
};
111
111
112
- template <typename TOp>
113
- static bool stridesAllOne (TOp op) {
114
- static_assert (
115
- std::is_same_v<TOp, vector::ExtractStridedSliceOp> ||
116
- std::is_same_v<TOp, vector::InsertStridedSliceOp>,
117
- " expected vector.extract_strided_slice or vector.insert_strided_slice" );
118
- ArrayAttr strides = op.getStrides ();
119
- return llvm::all_of (strides, isOneInteger);
120
- }
121
-
122
- // / Convert an array of attributes into a vector of integers, if possible.
123
- static FailureOr<SmallVector<int64_t >> intsFromArrayAttr (ArrayAttr attrs) {
124
- if (!attrs)
125
- return failure ();
126
- SmallVector<int64_t > ints;
127
- ints.reserve (attrs.size ());
128
- for (auto attr : attrs) {
129
- if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
130
- ints.push_back (intAttr.getInt ());
131
- } else {
132
- return failure ();
133
- }
134
- }
135
- return ints;
136
- }
137
-
138
- // / Consider inserting a vector of shape `small` into a vector of shape `large`,
139
- // / at position `offsets`: this function enumeratates all the indices in `large`
140
- // / that are written to. The enumeration is with row-major ordering.
141
- // /
142
- // / Example: insert a 1x2 vector into a 4x5 vector at position (1,3). The 2
143
- // / positions written to are (1,3) and (1,4), which have linearized indices 8
144
- // / and 9. So [8,9] is returned.
145
- // /
146
- // / The length of the returned vector is equal to the number of elements in
147
- // / the shape `small` (i.e. the product of dimensions of `small`).
148
- SmallVector<int64_t > static getStridedSliceInsertionIndices (
149
- ArrayRef<int64_t > small, ArrayRef<int64_t > large,
150
- ArrayRef<int64_t > offsets) {
151
-
152
- // Example of alignment between, `large`, `small` and `offsets`:
153
- // large = 4, 5, 6, 7, 8
154
- // small = 1, 6, 7, 8
155
- // offsets = 2, 3, 0
156
- //
157
- // `offsets` has implicit trailing 0s, `small` has implicit leading 1s.
158
- assert ((large.size () >= small.size ()) &&
159
- " rank of 'large' cannot be lower than rank of 'small'" );
160
- assert ((large.size () >= offsets.size ()) &&
161
- " rank of 'large' cannot be lower than the number of offsets" );
162
- unsigned delta = large.size () - small.size ();
163
- unsigned nOffsets = offsets.size ();
164
- auto getSmall = [&](int64_t i) -> int64_t {
165
- return i >= delta ? small[i - delta] : 1 ;
166
- };
167
- auto getOffset = [&](int64_t i) -> int64_t {
168
- return i < nOffsets ? offsets[i] : 0 ;
169
- };
170
-
171
- // Using 2 vectors of indices, at each iteration populate the updated set of
172
- // indices based on the old set of indices, and the size of the small vector
173
- // in the current iteration.
174
- SmallVector<int64_t > indices{0 };
175
- int64_t stride = 1 ;
176
- for (int i = large.size () - 1 ; i >= 0 ; --i) {
177
- int64_t currentSize = indices.size ();
178
- int64_t smallSize = getSmall (i);
179
- int64_t nextSize = currentSize * smallSize;
180
- SmallVector<int64_t > nextIndices (nextSize);
181
- int64_t *base = nextIndices.begin ();
182
- int64_t offset = getOffset (i) * stride;
183
- for (int j = 0 ; j < smallSize; ++j) {
184
- for (int k = 0 ; k < currentSize; ++k) {
185
- base[k] = indices[k] + offset;
186
- }
187
- offset += stride;
188
- base += currentSize;
189
- }
190
- stride *= large[i];
191
- indices = std::move (nextIndices);
192
- }
193
- return indices;
194
- }
195
-
196
112
// / This pattern converts a vector.extract_strided_slice operation into a
197
113
// / vector.shuffle operation that has a rank-1 (linearized) operand and result.
198
114
// /
@@ -231,30 +147,23 @@ struct LinearizeVectorExtractStridedSlice final
231
147
232
148
// Expect a legalization failure if the strides are not all 1 (if ever the
233
149
// verifier for extract_strided_slice allows non-1 strides).
234
- if (! stridesAllOne ( extractStridedSliceOp)) {
150
+ if (extractStridedSliceOp. hasNonUnitStrides ( )) {
235
151
return rewriter.notifyMatchFailure (
236
152
extractStridedSliceOp,
237
153
" extract_strided_slice with strides != 1 not supported" );
238
154
}
239
155
240
- FailureOr<SmallVector<int64_t >> offsets =
241
- intsFromArrayAttr ( extractStridedSliceOp.getOffsets () );
242
- if (failed (offsets )) {
156
+ FailureOr<SmallVector<int64_t >> indices =
157
+ extractStridedSliceOp.getLinearIndices ( );
158
+ if (failed (indices )) {
243
159
return rewriter.notifyMatchFailure (extractStridedSliceOp,
244
- " failed to get integer offsets " );
160
+ " failed to get indices " );
245
161
}
246
162
247
- ArrayRef<int64_t > inputShape =
248
- extractStridedSliceOp.getSourceVectorType ().getShape ();
249
-
250
- ArrayRef<int64_t > outputShape = extractStridedSliceOp.getType ().getShape ();
251
-
252
- SmallVector<int64_t > indices = getStridedSliceInsertionIndices (
253
- outputShape, inputShape, offsets.value ());
254
-
255
163
Value srcVector = adaptor.getVector ();
256
- rewriter.replaceOpWithNewOp <vector::ShuffleOp>(
257
- extractStridedSliceOp, flatOutputType, srcVector, srcVector, indices);
164
+ rewriter.replaceOpWithNewOp <vector::ShuffleOp>(extractStridedSliceOp,
165
+ flatOutputType, srcVector,
166
+ srcVector, indices.value ());
258
167
return success ();
259
168
}
260
169
};
@@ -298,31 +207,24 @@ struct LinearizeVectorInsertStridedSlice final
298
207
299
208
// Expect a legalization failure if the strides are not all 1 (if ever the
300
209
// verifier for insert_strided_slice allows non-1 strides).
301
- if (! stridesAllOne ( insertStridedSliceOp)) {
210
+ if (insertStridedSliceOp. hasNonUnitStrides ( )) {
302
211
return rewriter.notifyMatchFailure (
303
212
insertStridedSliceOp,
304
213
" insert_strided_slice with strides != 1 not supported" );
305
214
}
306
215
307
- VectorType inputType = insertStridedSliceOp.getValueToStore ().getType ();
308
- ArrayRef<int64_t > inputShape = inputType.getShape ();
309
-
310
216
VectorType outputType = insertStridedSliceOp.getType ();
311
- ArrayRef<int64_t > outputShape = outputType.getShape ();
312
217
int64_t nOutputElements = outputType.getNumElements ();
313
218
314
- FailureOr<SmallVector<int64_t >> offsets =
315
- intsFromArrayAttr ( insertStridedSliceOp.getOffsets () );
316
- if (failed (offsets)) {
219
+ FailureOr<SmallVector<int64_t >> sliceIndices =
220
+ insertStridedSliceOp.getLinearIndices ( );
221
+ if (failed (sliceIndices))
317
222
return rewriter.notifyMatchFailure (insertStridedSliceOp,
318
- " failed to get integer offsets" );
319
- }
320
- SmallVector<int64_t > sliceIndices = getStridedSliceInsertionIndices (
321
- inputShape, outputShape, offsets.value ());
223
+ " failed to get indices" );
322
224
323
225
SmallVector<int64_t > indices (nOutputElements);
324
226
std::iota (indices.begin (), indices.end (), 0 );
325
- for (auto [index, sliceIndex] : llvm::enumerate (sliceIndices)) {
227
+ for (auto [index, sliceIndex] : llvm::enumerate (sliceIndices. value () )) {
326
228
indices[sliceIndex] = index + nOutputElements;
327
229
}
328
230
0 commit comments