Skip to content

Commit af5c471

Browse files
authored
[mlir][Vector] Add vector.extract(vector.shuffle) folder (#115105)
This PR adds a folder for extracting an element from a vector shuffle. It turns something like: ``` %shuffle = vector.shuffle %a, %b [0, 8, 7, 15] : vector<8xf32>, vector<8xf32> %extract = vector.extract %shuffle[3] : f32 from vector<4xf32> ``` into: ``` %extract = vector.extract %b[7] : f32 from vector<8xf32> ```
1 parent 30d8000 commit af5c471

File tree

2 files changed

+61
-0
lines changed

2 files changed

+61
-0
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1705,6 +1705,47 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
17051705
return extractOp.getResult();
17061706
}
17071707

1708+
/// Fold extractOp coming from ShuffleOp.
1709+
///
1710+
/// Example:
1711+
///
1712+
/// %shuffle = vector.shuffle %a, %b [0, 8, 7, 15]
1713+
/// : vector<8xf32>, vector<8xf32>
1714+
/// %extract = vector.extract %shuffle[3] : f32 from vector<4xf32>
1715+
/// ->
1716+
/// %extract = vector.extract %b[7] : f32 from vector<8xf32>
1717+
///
1718+
static Value foldExtractFromShuffle(ExtractOp extractOp) {
1719+
// Dynamic positions are not folded as the resulting code would be more
1720+
// complex than the input code.
1721+
if (extractOp.hasDynamicPosition())
1722+
return Value();
1723+
1724+
auto shuffleOp = extractOp.getVector().getDefiningOp<ShuffleOp>();
1725+
if (!shuffleOp)
1726+
return Value();
1727+
1728+
// TODO: 0-D or multi-dimensional vectors not supported yet.
1729+
if (shuffleOp.getResultVectorType().getRank() != 1)
1730+
return Value();
1731+
1732+
int64_t inputVecSize = shuffleOp.getV1().getType().getShape()[0];
1733+
auto shuffleMask = shuffleOp.getMask();
1734+
int64_t extractIdx = extractOp.getStaticPosition()[0];
1735+
int64_t shuffleIdx = shuffleMask[extractIdx];
1736+
1737+
// Find the shuffled vector to extract from based on the shuffle index.
1738+
if (shuffleIdx < inputVecSize) {
1739+
extractOp.setOperand(0, shuffleOp.getV1());
1740+
extractOp.setStaticPosition({shuffleIdx});
1741+
} else {
1742+
extractOp.setOperand(0, shuffleOp.getV2());
1743+
extractOp.setStaticPosition({shuffleIdx - inputVecSize});
1744+
}
1745+
1746+
return extractOp.getResult();
1747+
}
1748+
17081749
// Fold extractOp with source coming from ShapeCast op.
17091750
static Value foldExtractFromShapeCast(ExtractOp extractOp) {
17101751
// TODO: Canonicalization for dynamic position not implemented yet.
@@ -1953,6 +1994,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor) {
19531994
return res;
19541995
if (auto res = foldExtractFromBroadcast(*this))
19551996
return res;
1997+
if (auto res = foldExtractFromShuffle(*this))
1998+
return res;
19561999
if (auto res = foldExtractFromShapeCast(*this))
19572000
return res;
19582001
if (auto val = foldExtractFromExtractStrided(*this))

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -740,6 +740,24 @@ func.func @fold_extract_broadcast(%a : vector<1xf32>) -> vector<8xf32> {
740740
%r = vector.extract %b[0] : vector<8xf32> from vector<1x8xf32>
741741
return %r : vector<8xf32>
742742
}
743+
// -----
744+
745+
// CHECK-LABEL: @fold_extract_shuffle
746+
// CHECK-SAME: %[[A:.*]]: vector<8xf32>, %[[B:.*]]: vector<8xf32>
747+
// CHECK-NOT: vector.shuffle
748+
// CHECK: vector.extract %[[A]][0] : f32 from vector<8xf32>
749+
// CHECK: vector.extract %[[B]][0] : f32 from vector<8xf32>
750+
// CHECK: vector.extract %[[A]][7] : f32 from vector<8xf32>
751+
// CHECK: vector.extract %[[B]][7] : f32 from vector<8xf32>
752+
func.func @fold_extract_shuffle(%a : vector<8xf32>, %b : vector<8xf32>)
753+
-> (f32, f32, f32, f32) {
754+
%shuffle = vector.shuffle %a, %b [0, 8, 7, 15] : vector<8xf32>, vector<8xf32>
755+
%e0 = vector.extract %shuffle[0] : f32 from vector<4xf32>
756+
%e1 = vector.extract %shuffle[1] : f32 from vector<4xf32>
757+
%e2 = vector.extract %shuffle[2] : f32 from vector<4xf32>
758+
%e3 = vector.extract %shuffle[3] : f32 from vector<4xf32>
759+
return %e0, %e1, %e2, %e3 : f32, f32, f32, f32
760+
}
743761

744762
// -----
745763

0 commit comments

Comments
 (0)