Skip to content

Commit fd8bc37

Browse files
authored
[mlir][Vector][NFC] Run extractInsertFoldConstantOp earlier in the folder (#140814)
This PR moves `extractInsertFoldConstantOp` earlier in the folder lists of `vector.extract` and `vector.insert`. Many folders require having non-dynamic indices so `extractInsertFoldConstantOp` is a requirement for them to trigger.
1 parent 6c813e8 commit fd8bc37

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2143,11 +2143,16 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
21432143
// mismatch).
21442144
if (getNumIndices() == 0 && getVector().getType() == getResult().getType())
21452145
return getVector();
2146+
if (auto res = foldPoisonSrcExtractOp(adaptor.getVector()))
2147+
return res;
2148+
// Fold `arith.constant` indices into the `vector.extract` operation. Make
2149+
// sure that patterns requiring constant indices are added after this fold.
2150+
SmallVector<Value> operands = {getVector()};
2151+
if (auto val = extractInsertFoldConstantOp(*this, adaptor, operands))
2152+
return val;
21462153
if (auto res = foldPoisonIndexInsertExtractOp(
21472154
getContext(), adaptor.getStaticPosition(), kPoisonIndex))
21482155
return res;
2149-
if (auto res = foldPoisonSrcExtractOp(adaptor.getVector()))
2150-
return res;
21512156
if (auto res = foldDenseElementsAttrSrcExtractOp(*this, adaptor.getVector()))
21522157
return res;
21532158
if (succeeded(foldExtractOpFromExtractChain(*this)))
@@ -2166,9 +2171,6 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
21662171
return val;
21672172
if (auto val = foldScalarExtractFromFromElements(*this))
21682173
return val;
2169-
SmallVector<Value> operands = {getVector()};
2170-
if (auto val = extractInsertFoldConstantOp(*this, adaptor, operands))
2171-
return val;
21722174
return OpFoldResult();
21732175
}
21742176

@@ -3145,6 +3147,8 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
31453147
// (type mismatch).
31463148
if (getNumIndices() == 0 && getValueToStoreType() == getType())
31473149
return getValueToStore();
3150+
// Fold `arith.constant` indices into the `vector.insert` operation. Make
3151+
// sure that patterns requiring constant indices are added after this fold.
31483152
SmallVector<Value> operands = {getValueToStore(), getDest()};
31493153
if (auto val = extractInsertFoldConstantOp(*this, adaptor, operands))
31503154
return val;

0 commit comments

Comments
 (0)