Skip to content

Commit d3e1fd6

Browse files
[mlir][LLVM] Improve llvm.extractvalue folder (#136861)
Continue the traversal on the SSA chain of the inserted value for additional folding opportunities.
1 parent 98eb476 commit d3e1fd6

File tree

2 files changed

+55
-8
lines changed

2 files changed

+55
-8
lines changed

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1885,11 +1885,40 @@ OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
18851885

18861886
auto insertValueOp = getContainer().getDefiningOp<InsertValueOp>();
18871887
OpFoldResult result = {};
1888+
ArrayRef<int64_t> extractPos = getPosition();
1889+
bool switchedToInsertedValue = false;
18881890
while (insertValueOp) {
1889-
if (getPosition() == insertValueOp.getPosition())
1891+
ArrayRef<int64_t> insertPos = insertValueOp.getPosition();
1892+
auto extractPosSize = extractPos.size();
1893+
auto insertPosSize = insertPos.size();
1894+
1895+
// Case 1: Exact match of positions.
1896+
if (extractPos == insertPos)
18901897
return insertValueOp.getValue();
1891-
unsigned min =
1892-
std::min(getPosition().size(), insertValueOp.getPosition().size());
1898+
1899+
// Case 2: Insert position is a prefix of extract position. Continue
1900+
// traversal with the inserted value. Example:
1901+
// ```
1902+
// %0 = llvm.insertvalue %arg1, %undef[0] : !llvm.struct<(i32, i32, i32)>
1903+
// %1 = llvm.insertvalue %arg2, %0[1] : !llvm.struct<(i32, i32, i32)>
1904+
// %2 = llvm.insertvalue %arg3, %1[2] : !llvm.struct<(i32, i32, i32)>
1905+
// %3 = llvm.insertvalue %2, %foo[0]
1906+
// : !llvm.struct<(struct<(i32, i32, i32)>, i64)>
1907+
// %4 = llvm.extractvalue %3[0, 0]
1908+
// : !llvm.struct<(struct<(i32, i32, i32)>, i64)>
1909+
// ```
1910+
// In the above example, %4 is folded to %arg1.
1911+
if (extractPosSize > insertPosSize &&
1912+
extractPos.take_front(insertPosSize) == insertPos) {
1913+
insertValueOp = insertValueOp.getValue().getDefiningOp<InsertValueOp>();
1914+
extractPos = extractPos.drop_front(insertPosSize);
1915+
switchedToInsertedValue = true;
1916+
continue;
1917+
}
1918+
1919+
// Case 3: Try to continue the traversal with the container value.
1920+
unsigned min = std::min(extractPosSize, insertPosSize);
1921+
18931922
// If one is fully prefix of the other, stop propagating back as it will
18941923
// miss dependencies. For instance, %3 should not fold to %f0 in the
18951924
// following example:
@@ -1900,15 +1929,17 @@ OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
19001929
// !llvm.array<4 x !llvm.array<4 x f32>>
19011930
// %3 = llvm.extractvalue %2[0, 0] : !llvm.array<4 x !llvm.array<4 x f32>>
19021931
// ```
1903-
if (getPosition().take_front(min) ==
1904-
insertValueOp.getPosition().take_front(min))
1932+
if (extractPos.take_front(min) == insertPos.take_front(min))
19051933
return result;
1906-
19071934
// If neither a prefix, nor the exact position, we can extract out of the
19081935
// value being inserted into. Moreover, we can try again if that operand
19091936
// is itself an insertvalue expression.
1910-
getContainerMutable().assign(insertValueOp.getContainer());
1911-
result = getResult();
1937+
if (!switchedToInsertedValue) {
1938+
// Do not swap out the container operand if we decided earlier to
1939+
// continue the traversal with the inserted value (Case 2).
1940+
getContainerMutable().assign(insertValueOp.getContainer());
1941+
result = getResult();
1942+
}
19121943
insertValueOp = insertValueOp.getContainer().getDefiningOp<InsertValueOp>();
19131944
}
19141945
return result;

mlir/test/Dialect/LLVMIR/canonicalize.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,22 @@ llvm.func @fold_extractvalue() -> i32 {
5757

5858
// -----
5959

60+
// CHECK-LABEL: fold_extractvalue(
61+
// CHECK-SAME: %[[arg1:.*]]: i32, %[[arg2:.*]]: i32, %[[arg3:.*]]: i32)
62+
// CHECK-NEXT: llvm.return %[[arg1]] : i32
63+
llvm.func @fold_extractvalue(%arg1: i32, %arg2: i32, %arg3: i32) -> i32{
64+
%3 = llvm.mlir.undef : !llvm.struct<(struct<(i32, i32, i32)>, struct<(i32, i32)>)>
65+
%5 = llvm.mlir.undef : !llvm.struct<(i32, i32, i32)>
66+
%6 = llvm.insertvalue %arg1, %5[0] : !llvm.struct<(i32, i32, i32)>
67+
%7 = llvm.insertvalue %arg1, %6[1] : !llvm.struct<(i32, i32, i32)>
68+
%8 = llvm.insertvalue %arg1, %7[2] : !llvm.struct<(i32, i32, i32)>
69+
%11 = llvm.insertvalue %8, %3[0] : !llvm.struct<(struct<(i32, i32, i32)>, struct<(i32, i32)>)>
70+
%13 = llvm.extractvalue %11[0, 0] : !llvm.struct<(struct<(i32, i32, i32)>, struct<(i32, i32)>)>
71+
llvm.return %13 : i32
72+
}
73+
74+
// -----
75+
6076
// CHECK-LABEL: no_fold_extractvalue
6177
llvm.func @no_fold_extractvalue(%arr: !llvm.array<4 x f32>) -> f32 {
6278
%f0 = arith.constant 0.0 : f32

0 commit comments

Comments
 (0)