-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][Vector] Add support for poison indices to Extract/InsertOp
#123488
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,6 +19,7 @@ | |
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" | ||
#include "mlir/Dialect/MemRef/IR/MemRef.h" | ||
#include "mlir/Dialect/Tensor/IR/Tensor.h" | ||
#include "mlir/Dialect/UB/IR/UBOps.h" | ||
#include "mlir/Dialect/Utils/IndexingUtils.h" | ||
#include "mlir/Dialect/Utils/StructuredOpsUtils.h" | ||
#include "mlir/IR/AffineExpr.h" | ||
|
@@ -1274,6 +1275,13 @@ OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) { | |
return srcElements[posIdx]; | ||
} | ||
|
||
// Returns `true` if `index` is an acceptable static position, i.e. it is | ||
// either equal to `poisonValue` or within the half-open range [0, maxIndex). | ||
// | ||
// Note: despite "Positive" in the name, an index of 0 is accepted. The | ||
// comparison against `poisonValue` is evaluated first, so `poisonValue` is | ||
// accepted even when it lies outside [0, maxIndex). | ||
static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue, | ||
int64_t maxIndex) { | ||
return index == poisonValue || (index >= 0 && index < maxIndex); | ||
} | ||
|
||
//===----------------------------------------------------------------------===// | ||
// ExtractOp | ||
//===----------------------------------------------------------------------===// | ||
|
@@ -1355,11 +1363,12 @@ LogicalResult vector::ExtractOp::verify() { | |
for (auto [idx, pos] : llvm::enumerate(position)) { | ||
if (auto attr = dyn_cast<Attribute>(pos)) { | ||
int64_t constIdx = cast<IntegerAttr>(attr).getInt(); | ||
if (constIdx < 0 || constIdx >= getSourceVectorType().getDimSize(idx)) { | ||
if (!isValidPositiveIndexOrPoison( | ||
constIdx, kPoisonIndex, getSourceVectorType().getDimSize(idx))) { | ||
return emitOpError("expected position attribute #") | ||
<< (idx + 1) | ||
<< " to be a non-negative integer smaller than the " | ||
"corresponding vector dimension"; | ||
"corresponding vector dimension or poison (-1)"; | ||
} | ||
} | ||
} | ||
|
@@ -1977,12 +1986,26 @@ static Value foldScalarExtractFromFromElements(ExtractOp extractOp) { | |
return fromElementsOp.getElements()[flatIndex]; | ||
} | ||
|
||
OpFoldResult ExtractOp::fold(FoldAdaptor) { | ||
/// Fold an insert or extract operation into a poison value when a poison index | ||
/// is found at any dimension of the static position. | ||
/// | ||
/// Returns a `ub::PoisonAttr` obtained from `context` if `poisonVal` occurs | ||
/// anywhere in `staticPos`. Otherwise returns a default-constructed | ||
/// `ub::PoisonAttr`, which callers in this file test for with | ||
/// `if (auto res = ...)` before using the result. | ||
static ub::PoisonAttr | ||
foldPoisonIndexInsertExtractOp(MLIRContext *context, | ||
ArrayRef<int64_t> staticPos, int64_t poisonVal) { | ||
// No poison index in the static position: there is nothing to fold. | ||
if (!llvm::is_contained(staticPos, poisonVal)) | ||
return ub::PoisonAttr(); | ||
|
||
return ub::PoisonAttr::get(context); | ||
} | ||
|
||
OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) { | ||
// Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v. | ||
// Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type | ||
// mismatch). | ||
if (getNumIndices() == 0 && getVector().getType() == getResult().getType()) | ||
return getVector(); | ||
if (auto res = foldPoisonIndexInsertExtractOp( | ||
getContext(), adaptor.getStaticPosition(), kPoisonIndex)) | ||
return res; | ||
if (succeeded(foldExtractOpFromExtractChain(*this))) | ||
return getResult(); | ||
if (auto res = ExtractFromInsertTransposeChainState(*this).fold()) | ||
|
@@ -2249,6 +2272,21 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp, | |
resultType.getNumElements())); | ||
return success(); | ||
} | ||
|
||
/// Canonicalize an insert or extract operation into a poison value when a | ||
/// poison index is found at any dimension of the static position. | ||
/// | ||
/// If `foldPoisonIndexInsertExtractOp` yields a poison attribute for `op`, | ||
/// `op` is replaced with a `ub::PoisonOp` of the same result type and | ||
/// `success()` is returned. Otherwise `op` is left untouched and `failure()` | ||
/// is returned. | ||
/// | ||
/// `OpTy` must provide `getContext()`, `getStaticPosition()`, `getType()` and | ||
/// a `kPoisonIndex` member. | ||
template <typename OpTy> | ||
LogicalResult | ||
canonicalizePoisonIndexInsertExtractOp(OpTy op, PatternRewriter &rewriter) { | ||
if (auto poisonAttr = foldPoisonIndexInsertExtractOp( | ||
op.getContext(), op.getStaticPosition(), OpTy::kPoisonIndex)) { | ||
rewriter.replaceOpWithNewOp<ub::PoisonOp>(op, op.getType(), poisonAttr); | ||
return success(); | ||
} | ||
|
||
return failure(); | ||
} | ||
|
||
} // namespace | ||
|
||
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results, | ||
|
@@ -2257,6 +2295,7 @@ void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results, | |
ExtractOpFromBroadcast, ExtractOpFromCreateMask>(context); | ||
results.add(foldExtractFromShapeCastToShapeCast); | ||
results.add(foldExtractFromFromElements); | ||
results.add(canonicalizePoisonIndexInsertExtractOp<ExtractOp>); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I posted another comment about this: we probably don't need to duplicate something that's already in the folder here. But no strong opinion; if you think we want to keep it here, we can, and then we can discuss in another PR whether we want folders in the canonicalizer like this or not (other folders seem to also be here, which is weird to me). There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The current state, AFAICT, is that the canonicalizer does not apply foldings as it uses |
||
} | ||
|
||
static void populateFromInt64AttrArray(ArrayAttr arrayAttr, | ||
|
@@ -2600,7 +2639,7 @@ LogicalResult ShuffleOp::verify() { | |
int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) + | ||
(v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0)); | ||
for (auto [idx, maskPos] : llvm::enumerate(mask)) { | ||
if (maskPos != kMaskPoisonValue && (maskPos < 0 || maskPos >= indexSize)) | ||
if (!isValidPositiveIndexOrPoison(maskPos, kPoisonIndex, indexSize)) | ||
return emitOpError("mask index #") << (idx + 1) << " out of range"; | ||
} | ||
return success(); | ||
|
@@ -2882,7 +2921,8 @@ LogicalResult InsertOp::verify() { | |
for (auto [idx, pos] : llvm::enumerate(position)) { | ||
if (auto attr = pos.dyn_cast<Attribute>()) { | ||
int64_t constIdx = cast<IntegerAttr>(attr).getInt(); | ||
if (constIdx < 0 || constIdx >= destVectorType.getDimSize(idx)) { | ||
if (!isValidPositiveIndexOrPoison(constIdx, kPoisonIndex, | ||
destVectorType.getDimSize(idx))) { | ||
return emitOpError("expected position attribute #") | ||
<< (idx + 1) | ||
<< " to be a non-negative integer smaller than the " | ||
|
@@ -3020,6 +3060,7 @@ void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results, | |
MLIRContext *context) { | ||
results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat, | ||
InsertOpConstantFolder>(context); | ||
results.add(canonicalizePoisonIndexInsertExtractOp<InsertOp>); | ||
} | ||
|
||
OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) { | ||
|
@@ -3028,6 +3069,10 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) { | |
// (type mismatch). | ||
if (getNumIndices() == 0 && getSourceType() == getType()) | ||
return getSource(); | ||
if (auto res = foldPoisonIndexInsertExtractOp( | ||
getContext(), adaptor.getStaticPosition(), kPoisonIndex)) | ||
return res; | ||
|
||
return {}; | ||
} | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -37,4 +37,5 @@ add_mlir_library(MLIRTransforms | |
MLIRSideEffectInterfaces | ||
MLIRSupport | ||
MLIRTransformUtils | ||
MLIRUBDialect | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Lazy me started with this approach to refactor and declare the poison index value for different operations (shuffle, insert, extract, ...), thinking that I would turn this into an interface eventually. Giving it another thought, I feel like using an interface with a `getPoisonIndexValue` method that returns `-1` could be overkill? WDYT?
I was also wondering if the `OpFoldResult` implementation would have a place somewhere to accommodate this declaration. cc: @matthias-springer
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't have an opinion just now, but will have in due course 😂 (when I start interacting with this more).
To me this is an implementation detail that can always be updated.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No opinion here, but -1 is a number that someone can very easily write by mistake.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`-1` is the value used by LLVM, so I'm just sticking to that to prevent unnecessary conversion bugs. We use `numeric_limits::min` for dynamic shapes, so there is no conflict there. I can't think of an alternative that would make a difference. Any negative number would lead to a verification error or UB, so...