-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][vector] Move extract_strided_slice canonicalization to folding #135676
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3717,6 +3717,58 @@ OpFoldResult ExtractStridedSliceOp::fold(FoldAdaptor adaptor) { | |
return getVector(); | ||
if (succeeded(foldExtractStridedOpFromInsertChain(*this))) | ||
return getResult(); | ||
|
||
// All subsequent successful folds require a constant input. | ||
Attribute foldInput = adaptor.getVector(); | ||
if (!foldInput) | ||
return {}; | ||
|
||
// ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp. | ||
if (auto splat = llvm::dyn_cast<SplatElementsAttr>(foldInput)) | ||
DenseElementsAttr::get(getType(), splat.getSplatValue<Attribute>()); | ||
|
||
// ExtractStridedSliceOp(non-splat ConstantOp) -> ConstantOp. | ||
if (auto dense = llvm::dyn_cast<DenseElementsAttr>(foldInput)) { | ||
// TODO: Handle non-unit strides when they become available. | ||
if (hasNonUnitStrides()) | ||
return {}; | ||
|
||
VectorType sourceVecTy = getSourceVectorType(); | ||
ArrayRef<int64_t> sourceShape = sourceVecTy.getShape(); | ||
SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape); | ||
|
||
VectorType sliceVecTy = getType(); | ||
ArrayRef<int64_t> sliceShape = sliceVecTy.getShape(); | ||
int64_t rank = sliceVecTy.getRank(); | ||
|
||
// Expand offsets and sizes to match the vector rank. | ||
SmallVector<int64_t, 4> offsets(rank, 0); | ||
copy(getI64SubArray(getOffsets()), offsets.begin()); | ||
|
||
SmallVector<int64_t, 4> sizes(sourceShape); | ||
copy(getI64SubArray(getSizes()), sizes.begin()); | ||
|
||
// Calculate the slice elements by enumerating all slice positions and | ||
// linearizing them. The enumeration order is lexicographic which yields a | ||
// sequence of monotonically increasing linearized position indices. | ||
const auto denseValuesBegin = dense.value_begin<Attribute>(); | ||
SmallVector<Attribute> sliceValues; | ||
sliceValues.reserve(sliceVecTy.getNumElements()); | ||
SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end()); | ||
do { | ||
int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides); | ||
assert(linearizedPosition < sourceVecTy.getNumElements() && | ||
"Invalid index"); | ||
sliceValues.push_back(*(denseValuesBegin + linearizedPosition)); | ||
} while ( | ||
succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets))); | ||
|
||
assert(static_cast<int64_t>(sliceValues.size()) == | ||
sliceVecTy.getNumElements() && | ||
"Invalid number of slice elements"); | ||
return DenseElementsAttr::get(sliceVecTy, sliceValues); | ||
} | ||
|
||
return {}; | ||
} | ||
|
||
|
@@ -3781,98 +3833,6 @@ class StridedSliceConstantMaskFolder final | |
} | ||
}; | ||
|
||
// Pattern to rewrite a ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp. | ||
class StridedSliceSplatConstantFolder final | ||
: public OpRewritePattern<ExtractStridedSliceOp> { | ||
public: | ||
using OpRewritePattern::OpRewritePattern; | ||
|
||
LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp, | ||
PatternRewriter &rewriter) const override { | ||
// Return if 'ExtractStridedSliceOp' operand is not defined by a splat | ||
// ConstantOp. | ||
Value sourceVector = extractStridedSliceOp.getVector(); | ||
Attribute vectorCst; | ||
if (!matchPattern(sourceVector, m_Constant(&vectorCst))) | ||
return failure(); | ||
|
||
auto splat = llvm::dyn_cast<SplatElementsAttr>(vectorCst); | ||
if (!splat) | ||
return failure(); | ||
|
||
auto newAttr = SplatElementsAttr::get(extractStridedSliceOp.getType(), | ||
splat.getSplatValue<Attribute>()); | ||
rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp, | ||
newAttr); | ||
return success(); | ||
} | ||
}; | ||
|
||
// Pattern to rewrite a ExtractStridedSliceOp(non-splat ConstantOp) -> | ||
// ConstantOp. | ||
class StridedSliceNonSplatConstantFolder final | ||
: public OpRewritePattern<ExtractStridedSliceOp> { | ||
public: | ||
using OpRewritePattern::OpRewritePattern; | ||
|
||
LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp, | ||
PatternRewriter &rewriter) const override { | ||
// Return if 'ExtractStridedSliceOp' operand is not defined by a non-splat | ||
// ConstantOp. | ||
Value sourceVector = extractStridedSliceOp.getVector(); | ||
Attribute vectorCst; | ||
if (!matchPattern(sourceVector, m_Constant(&vectorCst))) | ||
return failure(); | ||
|
||
// The splat case is handled by `StridedSliceSplatConstantFolder`. | ||
auto dense = llvm::dyn_cast<DenseElementsAttr>(vectorCst); | ||
if (!dense || dense.isSplat()) | ||
return failure(); | ||
|
||
// TODO: Handle non-unit strides when they become available. | ||
if (extractStridedSliceOp.hasNonUnitStrides()) | ||
return failure(); | ||
|
||
auto sourceVecTy = llvm::cast<VectorType>(sourceVector.getType()); | ||
ArrayRef<int64_t> sourceShape = sourceVecTy.getShape(); | ||
SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape); | ||
|
||
VectorType sliceVecTy = extractStridedSliceOp.getType(); | ||
ArrayRef<int64_t> sliceShape = sliceVecTy.getShape(); | ||
int64_t sliceRank = sliceVecTy.getRank(); | ||
|
||
// Expand offsets and sizes to match the vector rank. | ||
SmallVector<int64_t, 4> offsets(sliceRank, 0); | ||
copy(getI64SubArray(extractStridedSliceOp.getOffsets()), offsets.begin()); | ||
|
||
SmallVector<int64_t, 4> sizes(sourceShape); | ||
copy(getI64SubArray(extractStridedSliceOp.getSizes()), sizes.begin()); | ||
|
||
// Calculate the slice elements by enumerating all slice positions and | ||
// linearizing them. The enumeration order is lexicographic which yields a | ||
// sequence of monotonically increasing linearized position indices. | ||
auto denseValuesBegin = dense.value_begin<Attribute>(); | ||
SmallVector<Attribute> sliceValues; | ||
sliceValues.reserve(sliceVecTy.getNumElements()); | ||
SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end()); | ||
do { | ||
int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides); | ||
assert(linearizedPosition < sourceVecTy.getNumElements() && | ||
"Invalid index"); | ||
sliceValues.push_back(*(denseValuesBegin + linearizedPosition)); | ||
} while ( | ||
succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets))); | ||
|
||
assert(static_cast<int64_t>(sliceValues.size()) == | ||
sliceVecTy.getNumElements() && | ||
"Invalid number of slice elements"); | ||
auto newAttr = DenseElementsAttr::get(sliceVecTy, sliceValues); | ||
rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp, | ||
newAttr); | ||
return success(); | ||
} | ||
}; | ||
|
||
// Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to | ||
// BroadcastOp(ExtractStrideSliceOp). | ||
class StridedSliceBroadcast final | ||
|
@@ -4016,8 +3976,7 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns( | |
RewritePatternSet &results, MLIRContext *context) { | ||
// Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) -> | ||
// ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp. | ||
results.add<StridedSliceConstantMaskFolder, StridedSliceSplatConstantFolder, | ||
StridedSliceNonSplatConstantFolder, StridedSliceBroadcast, | ||
results.add<StridedSliceConstantMaskFolder, StridedSliceBroadcast, | ||
StridedSliceSplat, ContiguousExtractStridedSliceToExtract>( | ||
context); | ||
} | ||
|
@@ -5657,10 +5616,8 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) { | |
|
||
// shape_cast(constant) -> constant | ||
if (auto splatAttr = | ||
llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource())) { | ||
return DenseElementsAttr::get(resultType, | ||
splatAttr.getSplatValue<Attribute>()); | ||
} | ||
llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource())) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: drop There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. grep -r " dyn_cast" lib/Dialect/Vector/IR/VectorOps.cpp | wc -l |
||
return splatAttr.reshape(getType()); | ||
|
||
// shape_cast(poison) -> poison | ||
if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource())) { | ||
|
@@ -6004,10 +5961,9 @@ void vector::TransposeOp::build(OpBuilder &builder, OperationState &result, | |
|
||
OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) { | ||
// Eliminate splat constant transpose ops. | ||
if (auto attr = | ||
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getVector())) | ||
if (attr.isSplat()) | ||
return attr.reshape(getResultVectorType()); | ||
if (auto splat = | ||
llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getVector())) | ||
return splat.reshape(getResultVectorType()); | ||
|
||
// Eliminate identity transpose ops. This happens when the dimensions of the | ||
// input vector remain in their original order after the transpose operation. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: move this long folder to a function and follow the pattern:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will do.