Skip to content

Commit 6ed05ed

Browse files
authored
[mlir][vector] linearize vector.insert_strided_slice (flatten to vector.shuffle) (#138725)
Extends the set of vector operations that we can linearize to include vector.insert_strided_slice. The new pattern reuses the ideas from vector.extract_strided_slice linearization.
1 parent 61c1f6d commit 6ed05ed

File tree

2 files changed

+287
-94
lines changed

2 files changed

+287
-94
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp

Lines changed: 208 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -109,17 +109,110 @@ struct LinearizeVectorizable final
109109
}
110110
};
111111

112-
/// This pattern converts the ExtractStridedSliceOp into a ShuffleOp that works
113-
/// on a linearized vector.
114-
/// Following,
112+
template <typename TOp>
113+
static bool stridesAllOne(TOp op) {
114+
static_assert(
115+
std::is_same_v<TOp, vector::ExtractStridedSliceOp> ||
116+
std::is_same_v<TOp, vector::InsertStridedSliceOp>,
117+
"expected vector.extract_strided_slice or vector.insert_strided_slice");
118+
ArrayAttr strides = op.getStrides();
119+
return llvm::all_of(
120+
strides, [](auto stride) { return isConstantIntValue(stride, 1); });
121+
}
122+
123+
/// Convert an array of attributes into a vector of integers, if possible.
124+
static FailureOr<SmallVector<int64_t>> intsFromArrayAttr(ArrayAttr attrs) {
125+
if (!attrs)
126+
return failure();
127+
SmallVector<int64_t> ints;
128+
ints.reserve(attrs.size());
129+
for (auto attr : attrs) {
130+
if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
131+
ints.push_back(intAttr.getInt());
132+
} else {
133+
return failure();
134+
}
135+
}
136+
return ints;
137+
}
138+
139+
/// Consider inserting a vector of shape `small` into a vector of shape `large`,
140+
/// at position `offsets`: this function enumerates all the indices in `large`
141+
/// that are written to. The enumeration is with row-major ordering.
142+
///
143+
/// Example: insert a 1x2 vector into a 4x5 vector at position (1,3). The 2
144+
/// positions written to are (1,3) and (1,4), which have linearized indices 8
145+
/// and 9. So [8,9] is returned.
146+
///
147+
/// The length of the returned vector is equal to the number of elements in
148+
/// the shape `small` (i.e. the product of dimensions of `small`).
149+
SmallVector<int64_t> static getStridedSliceInsertionIndices(
150+
ArrayRef<int64_t> small, ArrayRef<int64_t> large,
151+
ArrayRef<int64_t> offsets) {
152+
153+
// Example of alignment between `large`, `small` and `offsets`:
154+
// large = 4, 5, 6, 7, 8
155+
// small = 1, 6, 7, 8
156+
// offsets = 2, 3, 0
157+
//
158+
// `offsets` has implicit trailing 0s, `small` has implicit leading 1s.
159+
assert((large.size() >= small.size()) &&
160+
"rank of 'large' cannot be lower than rank of 'small'");
161+
assert((large.size() >= offsets.size()) &&
162+
"rank of 'large' cannot be lower than the number of offsets");
163+
unsigned delta = large.size() - small.size();
164+
unsigned nOffsets = offsets.size();
165+
auto getSmall = [&](int64_t i) -> int64_t {
166+
return i >= delta ? small[i - delta] : 1;
167+
};
168+
auto getOffset = [&](int64_t i) -> int64_t {
169+
return i < nOffsets ? offsets[i] : 0;
170+
};
171+
172+
// Using 2 vectors of indices, at each iteration populate the updated set of
173+
// indices based on the old set of indices, and the size of the small vector
174+
// in the current iteration.
175+
SmallVector<int64_t> indices{0};
176+
int64_t stride = 1;
177+
for (int i = large.size() - 1; i >= 0; --i) {
178+
int64_t currentSize = indices.size();
179+
int64_t smallSize = getSmall(i);
180+
int64_t nextSize = currentSize * smallSize;
181+
SmallVector<int64_t> nextIndices(nextSize);
182+
int64_t *base = nextIndices.begin();
183+
int64_t offset = getOffset(i) * stride;
184+
for (int j = 0; j < smallSize; ++j) {
185+
for (int k = 0; k < currentSize; ++k) {
186+
base[k] = indices[k] + offset;
187+
}
188+
offset += stride;
189+
base += currentSize;
190+
}
191+
stride *= large[i];
192+
indices = std::move(nextIndices);
193+
}
194+
return indices;
195+
}
196+
197+
/// This pattern converts a vector.extract_strided_slice operation into a
198+
/// vector.shuffle operation that has a rank-1 (linearized) operand and result.
199+
///
200+
/// For example, the following:
201+
///
202+
/// ```
115203
/// vector.extract_strided_slice %source
116204
/// { offsets = [..], strides = [..], sizes = [..] }
205+
/// ```
206+
///
117207
/// is converted to :
208+
/// ```
118209
/// %source_1d = vector.shape_cast %source
119-
/// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
120-
/// %out_nd = vector.shape_cast %out_1d
121-
/// `shuffle_indices_1d` is computed using the offsets and sizes of the
122-
/// extraction.
210+
/// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ]
211+
/// %out_nd = vector.shape_cast %out_1d
212+
/// ```
213+
///
214+
/// `shuffle_indices_1d` is computed using the offsets and sizes of the original
215+
/// vector.extract_strided_slice operation.
123216
struct LinearizeVectorExtractStridedSlice final
124217
: public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
125218
using OpConversionPattern::OpConversionPattern;
@@ -129,88 +222,116 @@ struct LinearizeVectorExtractStridedSlice final
129222
: OpConversionPattern(typeConverter, context, benefit) {}
130223

131224
LogicalResult
132-
matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
225+
matchAndRewrite(vector::ExtractStridedSliceOp extractStridedSliceOp,
226+
OpAdaptor adaptor,
133227
ConversionPatternRewriter &rewriter) const override {
134-
VectorType dstType =
135-
getTypeConverter()->convertType<VectorType>(extractOp.getType());
136-
assert(dstType && "vector type destination expected.");
137-
if (extractOp.getVector().getType().isScalable() || dstType.isScalable())
138-
return rewriter.notifyMatchFailure(extractOp,
139-
"scalable vectors are not supported.");
140228

141-
ArrayAttr offsets = extractOp.getOffsets();
142-
ArrayAttr sizes = extractOp.getSizes();
143-
ArrayAttr strides = extractOp.getStrides();
144-
if (!isConstantIntValue(strides[0], 1))
229+
VectorType flatOutputType = getTypeConverter()->convertType<VectorType>(
230+
extractStridedSliceOp.getType());
231+
assert(flatOutputType && "vector type expected");
232+
233+
// Expect a legalization failure if the strides are not all 1 (if ever the
234+
// verifier for extract_strided_slice allows non-1 strides).
235+
if (!stridesAllOne(extractStridedSliceOp)) {
145236
return rewriter.notifyMatchFailure(
146-
extractOp, "Strided slice with stride != 1 is not supported.");
147-
Value srcVector = adaptor.getVector();
148-
// If kD offsets are specified for nD source vector (n > k), the granularity
149-
// of the extraction is greater than 1. In this case last (n-k) dimensions
150-
// form the extraction granularity.
151-
// Example :
152-
// vector.extract_strided_slice %src {
153-
// offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} :
154-
// vector<4x8x8xf32> to vector<2x2x8xf32>
155-
// Here, extraction granularity is 8.
156-
int64_t extractGranularitySize = 1;
157-
int64_t nD = extractOp.getSourceVectorType().getRank();
158-
int64_t kD = (int64_t)offsets.size();
159-
int64_t k = kD;
160-
while (k < nD) {
161-
extractGranularitySize *= extractOp.getSourceVectorType().getShape()[k];
162-
++k;
237+
extractStridedSliceOp,
238+
"extract_strided_slice with strides != 1 not supported");
163239
}
164-
// Get total number of extracted slices.
165-
int64_t nExtractedSlices = 1;
166-
for (Attribute size : sizes) {
167-
nExtractedSlices *= cast<IntegerAttr>(size).getInt();
240+
241+
FailureOr<SmallVector<int64_t>> offsets =
242+
intsFromArrayAttr(extractStridedSliceOp.getOffsets());
243+
if (failed(offsets)) {
244+
return rewriter.notifyMatchFailure(extractStridedSliceOp,
245+
"failed to get integer offsets");
168246
}
169-
// Compute the strides of the source vector considering first k dimensions.
170-
llvm::SmallVector<int64_t, 4> sourceStrides(kD, extractGranularitySize);
171-
for (int i = kD - 2; i >= 0; --i) {
172-
sourceStrides[i] = sourceStrides[i + 1] *
173-
extractOp.getSourceVectorType().getShape()[i + 1];
247+
248+
ArrayRef<int64_t> inputShape =
249+
extractStridedSliceOp.getSourceVectorType().getShape();
250+
251+
ArrayRef<int64_t> outputShape = extractStridedSliceOp.getType().getShape();
252+
253+
SmallVector<int64_t> indices = getStridedSliceInsertionIndices(
254+
outputShape, inputShape, offsets.value());
255+
256+
Value srcVector = adaptor.getVector();
257+
rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
258+
extractStridedSliceOp, flatOutputType, srcVector, srcVector, indices);
259+
return success();
260+
}
261+
};
262+
263+
/// This pattern converts a vector.insert_strided_slice operation into a
264+
/// vector.shuffle operation that has rank-1 (linearized) operands and result.
265+
///
266+
/// For example, the following:
267+
/// ```
268+
/// %0 = vector.insert_strided_slice %to_store, %into
269+
/// {offsets = [1, 0, 0, 0], strides = [1, 1]}
270+
/// : vector<2x2xi8> into vector<2x1x3x2xi8>
271+
/// ```
272+
///
273+
/// is converted to
274+
/// ```
275+
/// %to_store_1d
276+
/// = vector.shape_cast %to_store : vector<2x2xi8> to vector<4xi8>
277+
/// %into_1d = vector.shape_cast %into : vector<2x1x3x2xi8> to vector<12xi8>
278+
/// %out_1d = vector.shuffle %into_1d, %to_store_1d [ shuffle_indices_1d ]
279+
/// %out_nd = vector.shape_cast %out_1d : vector<12xi8> to vector<2x1x3x2xi8>
280+
/// ```
281+
///
282+
/// where shuffle_indices_1d in this case is
283+
/// [0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 10, 11].
284+
/// ^^^^^^^^^^^^^^
285+
/// to_store_1d
286+
///
287+
struct LinearizeVectorInsertStridedSlice final
288+
: public mlir::OpConversionPattern<mlir::vector::InsertStridedSliceOp> {
289+
using OpConversionPattern::OpConversionPattern;
290+
LinearizeVectorInsertStridedSlice(const TypeConverter &typeConverter,
291+
MLIRContext *context,
292+
PatternBenefit benefit = 1)
293+
: OpConversionPattern(typeConverter, context, benefit) {}
294+
295+
LogicalResult
296+
matchAndRewrite(vector::InsertStridedSliceOp insertStridedSliceOp,
297+
OpAdaptor adaptor,
298+
ConversionPatternRewriter &rewriter) const override {
299+
300+
// Expect a legalization failure if the strides are not all 1 (if ever the
301+
// verifier for insert_strided_slice allows non-1 strides).
302+
if (!stridesAllOne(insertStridedSliceOp)) {
303+
return rewriter.notifyMatchFailure(
304+
insertStridedSliceOp,
305+
"insert_strided_slice with strides != 1 not supported");
174306
}
175-
// Final shuffle indices has nExtractedSlices * extractGranularitySize
176-
// elements.
177-
llvm::SmallVector<int64_t, 4> indices(nExtractedSlices *
178-
extractGranularitySize);
179-
// Compute the strides of the extracted kD vector.
180-
llvm::SmallVector<int64_t, 4> extractedStrides(kD, 1);
181-
// Compute extractedStrides.
182-
for (int i = kD - 2; i >= 0; --i) {
183-
extractedStrides[i] =
184-
extractedStrides[i + 1] * cast<IntegerAttr>(sizes[i + 1]).getInt();
307+
308+
VectorType inputType = insertStridedSliceOp.getValueToStore().getType();
309+
ArrayRef<int64_t> inputShape = inputType.getShape();
310+
311+
VectorType outputType = insertStridedSliceOp.getType();
312+
ArrayRef<int64_t> outputShape = outputType.getShape();
313+
int64_t nOutputElements = outputType.getNumElements();
314+
315+
FailureOr<SmallVector<int64_t>> offsets =
316+
intsFromArrayAttr(insertStridedSliceOp.getOffsets());
317+
if (failed(offsets)) {
318+
return rewriter.notifyMatchFailure(insertStridedSliceOp,
319+
"failed to get integer offsets");
185320
}
186-
// Iterate over all extracted slices from 0 to nExtractedSlices - 1
187-
// and compute the multi-dimensional index and the corresponding linearized
188-
// index within the source vector.
189-
for (int64_t i = 0; i < nExtractedSlices; ++i) {
190-
int64_t index = i;
191-
// Compute the corresponding multi-dimensional index.
192-
llvm::SmallVector<int64_t, 4> multiDimIndex(kD, 0);
193-
for (int64_t j = 0; j < kD; ++j) {
194-
multiDimIndex[j] = (index / extractedStrides[j]);
195-
index -= multiDimIndex[j] * extractedStrides[j];
196-
}
197-
// Compute the corresponding linearized index in the source vector
198-
// i.e. shift the multiDimIndex by the offsets.
199-
int64_t linearizedIndex = 0;
200-
for (int64_t j = 0; j < kD; ++j) {
201-
linearizedIndex +=
202-
(cast<IntegerAttr>(offsets[j]).getInt() + multiDimIndex[j]) *
203-
sourceStrides[j];
204-
}
205-
// Fill the indices array form linearizedIndex to linearizedIndex +
206-
// extractGranularitySize.
207-
for (int64_t j = 0; j < extractGranularitySize; ++j) {
208-
indices[i * extractGranularitySize + j] = linearizedIndex + j;
209-
}
321+
SmallVector<int64_t> sliceIndices = getStridedSliceInsertionIndices(
322+
inputShape, outputShape, offsets.value());
323+
324+
SmallVector<int64_t> indices(nOutputElements);
325+
std::iota(indices.begin(), indices.end(), 0);
326+
for (auto [index, sliceIndex] : llvm::enumerate(sliceIndices)) {
327+
indices[sliceIndex] = index + nOutputElements;
210328
}
211-
// Perform a shuffle to extract the kD vector.
212-
rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
213-
extractOp, dstType, srcVector, srcVector, indices);
329+
330+
Value flatToStore = adaptor.getValueToStore();
331+
Value flatDest = adaptor.getDest();
332+
rewriter.replaceOpWithNewOp<vector::ShuffleOp>(insertStridedSliceOp,
333+
flatDest.getType(), flatDest,
334+
flatToStore, indices);
214335
return success();
215336
}
216337
};
@@ -296,7 +417,7 @@ struct LinearizeVectorExtract final
296417
// Skip if result is not a vector type
297418
if (!isa<VectorType>(extractOp.getType()))
298419
return rewriter.notifyMatchFailure(extractOp,
299-
"scalar extract is not supported.");
420+
"scalar extract not supported");
300421
Type dstTy = getTypeConverter()->convertType(extractOp.getType());
301422
assert(dstTy && "expected 1-D vector type");
302423

@@ -453,8 +574,8 @@ struct LinearizeVectorSplat final
453574
static bool isNotLinearizableBecauseScalable(Operation *op) {
454575

455576
bool unsupported =
456-
isa<vector::ExtractStridedSliceOp, vector::ExtractOp, vector::InsertOp>(
457-
op);
577+
isa<vector::ExtractStridedSliceOp, vector::InsertStridedSliceOp,
578+
vector::ExtractOp, vector::InsertOp>(op);
458579
if (!unsupported)
459580
return false;
460581

@@ -539,6 +660,7 @@ void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
539660
const TypeConverter &typeConverter, const ConversionTarget &target,
540661
RewritePatternSet &patterns) {
541662
patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract,
542-
LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
543-
typeConverter, patterns.getContext());
663+
LinearizeVectorInsert, LinearizeVectorExtractStridedSlice,
664+
LinearizeVectorInsertStridedSlice>(typeConverter,
665+
patterns.getContext());
544666
}

0 commit comments

Comments
 (0)