-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][vector] linearize vector.insert_strided_slice (flatten to vector.shuffle) #138725
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 4 commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
d708a63
linearize with shuffle
newling c578d7b
name improvement
newling cd894d8
scalable blacklist
newling bcc4391
address some of reviewer's comments
newling bb7a04d
simplify std::swap logic as per review comment
newling 34f3842
notification failure rather than assert
newling e59ca0d
Merge branch 'main' into linearize_insert_strided_slice
newling File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -109,17 +109,108 @@ struct LinearizeVectorizable final | |
} | ||
}; | ||
|
||
/// This pattern converts the ExtractStridedSliceOp into a ShuffleOp that works | ||
/// on a linearized vector. | ||
/// Following, | ||
template <typename TOp> | ||
static bool stridesAllOne(TOp op) { | ||
static_assert( | ||
std::is_same_v<TOp, vector::ExtractStridedSliceOp> || | ||
std::is_same_v<TOp, vector::InsertStridedSliceOp>, | ||
"expected vector.extract_strided_slice or vector.insert_strided_slice"); | ||
ArrayAttr strides = op.getStrides(); | ||
return llvm::all_of( | ||
strides, [](auto stride) { return isConstantIntValue(stride, 1); }); | ||
} | ||
|
||
/// Convert an array of attributes into a vector of integers, if possible. | ||
static FailureOr<SmallVector<int64_t>> intsFromArrayAttr(ArrayAttr attrs) { | ||
if (!attrs) | ||
return failure(); | ||
SmallVector<int64_t> ints; | ||
ints.reserve(attrs.size()); | ||
for (auto attr : attrs) { | ||
if (auto intAttr = dyn_cast<IntegerAttr>(attr)) { | ||
ints.push_back(intAttr.getInt()); | ||
} else { | ||
return failure(); | ||
} | ||
} | ||
return ints; | ||
} | ||
|
||
/// Consider inserting a vector of shape `small` into a vector of shape `large`, | ||
/// at position `offsets`: this function enumeratates all the indices in `large` | ||
/// that are written to. The enumeration is with row-major ordering. | ||
/// | ||
/// Example: insert a 1x2 vector into a 4x5 vector at position (1,3). The 2 | ||
/// positions written to are (1,3) and (1,4), which have linearized indices 8 | ||
/// and 9. So [8,9] is returned. | ||
/// | ||
/// The length of the returned vector is equal to the number of elements in | ||
/// the shape `small` (i.e. the product of dimensions of `small`). | ||
SmallVector<int64_t> static getStridedSliceInsertionIndices( | ||
ArrayRef<int64_t> small, ArrayRef<int64_t> large, | ||
ArrayRef<int64_t> offsets) { | ||
|
||
// Example of alignment between, `large`, `small` and `offsets`: | ||
// large = 4, 5, 6, 7, 8 | ||
// small = 1, 6, 7, 8 | ||
// offsets = 2, 3, 0 | ||
// | ||
// `offsets` has implicit trailing 0s, `small` has implicit leading 1s. | ||
assert((large.size() >= small.size()) && | ||
"rank of 'large' cannot be lower than rank of 'small'"); | ||
assert((large.size() >= offsets.size()) && | ||
"rank of 'large' cannot be lower than the number of offsets"); | ||
unsigned delta = large.size() - small.size(); | ||
unsigned nOffsets = offsets.size(); | ||
auto getSmall = [&](int64_t i) { return i >= delta ? small[i - delta] : 1; }; | ||
auto getOffset = [&](int64_t i) { return i < nOffsets ? offsets[i] : 0; }; | ||
|
||
// Using 2 vectors of indices, at each iteration populate the updated set of | ||
// indices based on the old set of indices, and the size of the small vector | ||
// in the current iteration. | ||
SmallVector<int64_t> indices{0}; | ||
SmallVector<int64_t> nextIndices; | ||
int64_t stride = 1; | ||
for (int i = large.size() - 1; i >= 0; --i) { | ||
auto currentSize = indices.size(); | ||
auto smallSize = getSmall(i); | ||
auto nextSize = currentSize * smallSize; | ||
nextIndices.resize(nextSize); | ||
int64_t *base = nextIndices.begin(); | ||
int64_t offset = getOffset(i) * stride; | ||
for (int j = 0; j < smallSize; ++j) { | ||
for (uint64_t k = 0; k < currentSize; ++k) { | ||
base[k] = indices[k] + offset; | ||
} | ||
offset += stride; | ||
base += currentSize; | ||
} | ||
stride *= large[i]; | ||
std::swap(indices, nextIndices); | ||
nextIndices.clear(); | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I didn't have time to review this function. |
||
return indices; | ||
} | ||
|
||
/// This pattern converts a vector.extract_strided_slice operation into a | ||
/// vector.shuffle operation that has a rank-1 (linearized) operand and result. | ||
/// | ||
/// For example, the following: | ||
/// | ||
/// ``` | ||
/// vector.extract_strided_slice %source | ||
/// { offsets = [..], strides = [..], sizes = [..] } | ||
/// ``` | ||
/// | ||
/// is converted to : | ||
/// ``` | ||
/// %source_1d = vector.shape_cast %source | ||
/// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ] | ||
/// %out_nd = vector.shape_cast %out_1d | ||
/// `shuffle_indices_1d` is computed using the offsets and sizes of the | ||
/// extraction. | ||
/// %out_1d = vector.shuffle %source_1d, %source_1d [ shuffle_indices_1d ] | ||
/// %out_nd = vector.shape_cast %out_1d | ||
/// ``` | ||
/// | ||
/// `shuffle_indices_1d` is computed using the offsets and sizes of the original | ||
/// vector.extract_strided_slice operation. | ||
struct LinearizeVectorExtractStridedSlice final | ||
: public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> { | ||
using OpConversionPattern::OpConversionPattern; | ||
|
@@ -129,88 +220,107 @@ struct LinearizeVectorExtractStridedSlice final | |
: OpConversionPattern(typeConverter, context, benefit) {} | ||
|
||
LogicalResult | ||
matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor, | ||
matchAndRewrite(vector::ExtractStridedSliceOp extractStridedSliceOp, | ||
OpAdaptor adaptor, | ||
ConversionPatternRewriter &rewriter) const override { | ||
VectorType dstType = | ||
getTypeConverter()->convertType<VectorType>(extractOp.getType()); | ||
assert(dstType && "vector type destination expected."); | ||
if (extractOp.getVector().getType().isScalable() || dstType.isScalable()) | ||
return rewriter.notifyMatchFailure(extractOp, | ||
"scalable vectors are not supported."); | ||
|
||
ArrayAttr offsets = extractOp.getOffsets(); | ||
ArrayAttr sizes = extractOp.getSizes(); | ||
ArrayAttr strides = extractOp.getStrides(); | ||
if (!isConstantIntValue(strides[0], 1)) | ||
return rewriter.notifyMatchFailure( | ||
extractOp, "Strided slice with stride != 1 is not supported."); | ||
Value srcVector = adaptor.getVector(); | ||
// If kD offsets are specified for nD source vector (n > k), the granularity | ||
// of the extraction is greater than 1. In this case last (n-k) dimensions | ||
// form the extraction granularity. | ||
// Example : | ||
// vector.extract_strided_slice %src { | ||
// offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : | ||
// vector<4x8x8xf32> to vector<2x2x8xf32> | ||
// Here, extraction granularity is 8. | ||
int64_t extractGranularitySize = 1; | ||
int64_t nD = extractOp.getSourceVectorType().getRank(); | ||
int64_t kD = (int64_t)offsets.size(); | ||
int64_t k = kD; | ||
while (k < nD) { | ||
extractGranularitySize *= extractOp.getSourceVectorType().getShape()[k]; | ||
++k; | ||
} | ||
// Get total number of extracted slices. | ||
int64_t nExtractedSlices = 1; | ||
for (Attribute size : sizes) { | ||
nExtractedSlices *= cast<IntegerAttr>(size).getInt(); | ||
} | ||
// Compute the strides of the source vector considering first k dimensions. | ||
llvm::SmallVector<int64_t, 4> sourceStrides(kD, extractGranularitySize); | ||
for (int i = kD - 2; i >= 0; --i) { | ||
sourceStrides[i] = sourceStrides[i + 1] * | ||
extractOp.getSourceVectorType().getShape()[i + 1]; | ||
VectorType flatOutputType = getTypeConverter()->convertType<VectorType>( | ||
extractStridedSliceOp.getType()); | ||
assert(flatOutputType && "vector type expected"); | ||
|
||
assert(stridesAllOne(extractStridedSliceOp) && | ||
"has extract_strided_slice's verifier not checked strides are 1?"); | ||
|
||
FailureOr<SmallVector<int64_t>> offsets = | ||
intsFromArrayAttr(extractStridedSliceOp.getOffsets()); | ||
if (failed(offsets)) { | ||
return rewriter.notifyMatchFailure(extractStridedSliceOp, | ||
"failed to get integer offsets"); | ||
} | ||
// Final shuffle indices has nExtractedSlices * extractGranularitySize | ||
// elements. | ||
llvm::SmallVector<int64_t, 4> indices(nExtractedSlices * | ||
extractGranularitySize); | ||
// Compute the strides of the extracted kD vector. | ||
llvm::SmallVector<int64_t, 4> extractedStrides(kD, 1); | ||
// Compute extractedStrides. | ||
for (int i = kD - 2; i >= 0; --i) { | ||
extractedStrides[i] = | ||
extractedStrides[i + 1] * cast<IntegerAttr>(sizes[i + 1]).getInt(); | ||
|
||
ArrayRef<int64_t> inputShape = | ||
extractStridedSliceOp.getSourceVectorType().getShape(); | ||
|
||
ArrayRef<int64_t> outputShape = extractStridedSliceOp.getType().getShape(); | ||
|
||
SmallVector<int64_t> indices = getStridedSliceInsertionIndices( | ||
outputShape, inputShape, offsets.value()); | ||
|
||
Value srcVector = adaptor.getVector(); | ||
rewriter.replaceOpWithNewOp<vector::ShuffleOp>( | ||
extractStridedSliceOp, flatOutputType, srcVector, srcVector, indices); | ||
return success(); | ||
} | ||
}; | ||
|
||
/// This pattern converts a vector.insert_strided_slice operation into a | ||
/// vector.shuffle operation that has rank-1 (linearized) operands and result. | ||
/// | ||
/// For example, the following: | ||
/// ``` | ||
/// %0 = vector.insert_strided_slice %to_store, %into | ||
/// {offsets = [1, 0, 0, 0], strides = [1, 1]} | ||
/// : vector<2x2xi8> into vector<2x1x3x2xi8> | ||
/// ``` | ||
/// | ||
/// is converted to | ||
/// ``` | ||
/// %to_store_1d | ||
/// = vector.shape_cast %to_store : vector<2x2xi8> to vector<4xi8> | ||
/// %into_1d = vector.shape_cast %into : vector<2x1x3x2xi8> to vector<12xi8> | ||
/// %out_1d = vector.shuffle %into_1d, %to_store_1d [ shuffle_indices_1d ] | ||
/// %out_nd = vector.shape_cast %out_1d : vector<12xi8> to vector<2x1x3x2xi8> | ||
/// ``` | ||
/// | ||
/// where shuffle_indices_1d in this case is | ||
/// [0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 10, 11]. | ||
/// ^^^^^^^^^^^^^^ | ||
/// to_store_1d | ||
/// | ||
struct LinearizeVectorInsertStridedSlice final | ||
: public mlir::OpConversionPattern<mlir::vector::InsertStridedSliceOp> { | ||
using OpConversionPattern::OpConversionPattern; | ||
LinearizeVectorInsertStridedSlice(const TypeConverter &typeConverter, | ||
MLIRContext *context, | ||
PatternBenefit benefit = 1) | ||
: OpConversionPattern(typeConverter, context, benefit) {} | ||
|
||
LogicalResult | ||
matchAndRewrite(vector::InsertStridedSliceOp insertStridedSliceOp, | ||
OpAdaptor adaptor, | ||
ConversionPatternRewriter &rewriter) const override { | ||
|
||
// See InsertStridedSliceOp's verify method. | ||
assert(stridesAllOne(insertStridedSliceOp) && | ||
"has insert_strided_slice's verifier not checked strides are 1?"); | ||
|
||
VectorType inputType = insertStridedSliceOp.getValueToStore().getType(); | ||
ArrayRef<int64_t> inputShape = inputType.getShape(); | ||
|
||
VectorType outputType = insertStridedSliceOp.getType(); | ||
ArrayRef<int64_t> outputShape = outputType.getShape(); | ||
int64_t nOutputElements = outputType.getNumElements(); | ||
|
||
FailureOr<SmallVector<int64_t>> offsets = | ||
intsFromArrayAttr(insertStridedSliceOp.getOffsets()); | ||
if (failed(offsets)) { | ||
return rewriter.notifyMatchFailure(insertStridedSliceOp, | ||
"failed to get integer offsets"); | ||
} | ||
// Iterate over all extracted slices from 0 to nExtractedSlices - 1 | ||
// and compute the multi-dimensional index and the corresponding linearized | ||
// index within the source vector. | ||
for (int64_t i = 0; i < nExtractedSlices; ++i) { | ||
int64_t index = i; | ||
// Compute the corresponding multi-dimensional index. | ||
llvm::SmallVector<int64_t, 4> multiDimIndex(kD, 0); | ||
for (int64_t j = 0; j < kD; ++j) { | ||
multiDimIndex[j] = (index / extractedStrides[j]); | ||
index -= multiDimIndex[j] * extractedStrides[j]; | ||
} | ||
// Compute the corresponding linearized index in the source vector | ||
// i.e. shift the multiDimIndex by the offsets. | ||
int64_t linearizedIndex = 0; | ||
for (int64_t j = 0; j < kD; ++j) { | ||
linearizedIndex += | ||
(cast<IntegerAttr>(offsets[j]).getInt() + multiDimIndex[j]) * | ||
sourceStrides[j]; | ||
} | ||
// Fill the indices array form linearizedIndex to linearizedIndex + | ||
// extractGranularitySize. | ||
for (int64_t j = 0; j < extractGranularitySize; ++j) { | ||
indices[i * extractGranularitySize + j] = linearizedIndex + j; | ||
} | ||
SmallVector<int64_t> sliceIndices = getStridedSliceInsertionIndices( | ||
inputShape, outputShape, offsets.value()); | ||
|
||
SmallVector<int64_t> indices(nOutputElements); | ||
std::iota(indices.begin(), indices.end(), 0); | ||
for (auto [index, sliceIndex] : llvm::enumerate(sliceIndices)) { | ||
indices[sliceIndex] = index + nOutputElements; | ||
} | ||
// Perform a shuffle to extract the kD vector. | ||
rewriter.replaceOpWithNewOp<vector::ShuffleOp>( | ||
extractOp, dstType, srcVector, srcVector, indices); | ||
|
||
Value flatToStore = adaptor.getValueToStore(); | ||
Value flatDest = adaptor.getDest(); | ||
rewriter.replaceOpWithNewOp<vector::ShuffleOp>(insertStridedSliceOp, | ||
flatDest.getType(), flatDest, | ||
flatToStore, indices); | ||
return success(); | ||
} | ||
}; | ||
|
@@ -296,7 +406,7 @@ struct LinearizeVectorExtract final | |
// Skip if result is not a vector type | ||
if (!isa<VectorType>(extractOp.getType())) | ||
return rewriter.notifyMatchFailure(extractOp, | ||
"scalar extract is not supported."); | ||
"scalar extract not supported"); | ||
Type dstTy = getTypeConverter()->convertType(extractOp.getType()); | ||
assert(dstTy && "expected 1-D vector type"); | ||
|
||
|
@@ -453,8 +563,8 @@ struct LinearizeVectorSplat final | |
static bool isNotLinearizableBecauseScalable(Operation *op) { | ||
|
||
bool unsupported = | ||
isa<vector::ExtractStridedSliceOp, vector::ExtractOp, vector::InsertOp>( | ||
op); | ||
isa<vector::ExtractStridedSliceOp, vector::InsertStridedSliceOp, | ||
vector::ExtractOp, vector::InsertOp>(op); | ||
if (!unsupported) | ||
return false; | ||
|
||
|
@@ -539,6 +649,7 @@ void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns( | |
const TypeConverter &typeConverter, const ConversionTarget &target, | ||
RewritePatternSet &patterns) { | ||
patterns.add<LinearizeVectorShuffle, LinearizeVectorExtract, | ||
LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>( | ||
typeConverter, patterns.getContext()); | ||
LinearizeVectorInsert, LinearizeVectorExtractStridedSlice, | ||
LinearizeVectorInsertStridedSlice>(typeConverter, | ||
patterns.getContext()); | ||
} |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
wonder if it's better to
resize
+clear
or just movenextIndices
decl inside the loop... The latter at least would make clearer thatnextIndices
doesn't have valid information at the beginning of the iteration... It wasn't obvious to me with theresize
andclear
approach...There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense, done