@@ -1885,11 +1885,40 @@ OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
1885
1885
1886
1886
auto insertValueOp = getContainer ().getDefiningOp <InsertValueOp>();
1887
1887
OpFoldResult result = {};
1888
+ ArrayRef<int64_t > extractPos = getPosition ();
1889
+ bool switchedToInsertedValue = false ;
1888
1890
while (insertValueOp) {
1889
- if (getPosition () == insertValueOp.getPosition ())
1891
+ ArrayRef<int64_t > insertPos = insertValueOp.getPosition ();
1892
+ auto extractPosSize = extractPos.size ();
1893
+ auto insertPosSize = insertPos.size ();
1894
+
1895
+ // Case 1: Exact match of positions.
1896
+ if (extractPos == insertPos)
1890
1897
return insertValueOp.getValue ();
1891
- unsigned min =
1892
- std::min (getPosition ().size (), insertValueOp.getPosition ().size ());
1898
+
1899
+ // Case 2: Insert position is a prefix of extract position. Continue
1900
+ // traversal with the inserted value. Example:
1901
+ // ```
1902
+ // %0 = llvm.insertvalue %arg1, %undef[0] : !llvm.struct<(i32, i32, i32)>
1903
+ // %1 = llvm.insertvalue %arg2, %0[1] : !llvm.struct<(i32, i32, i32)>
1904
+ // %2 = llvm.insertvalue %arg3, %1[2] : !llvm.struct<(i32, i32, i32)>
1905
+ // %3 = llvm.insertvalue %2, %foo[0]
1906
+ // : !llvm.struct<(struct<(i32, i32, i32)>, i64)>
1907
+ // %4 = llvm.extractvalue %3[0, 0]
1908
+ // : !llvm.struct<(struct<(i32, i32, i32)>, i64)>
1909
+ // ```
1910
+ // In the above example, %4 is folded to %arg1.
1911
+ if (extractPosSize > insertPosSize &&
1912
+ extractPos.take_front (insertPosSize) == insertPos) {
1913
+ insertValueOp = insertValueOp.getValue ().getDefiningOp <InsertValueOp>();
1914
+ extractPos = extractPos.drop_front (insertPosSize);
1915
+ switchedToInsertedValue = true ;
1916
+ continue ;
1917
+ }
1918
+
1919
+ // Case 3: Try to continue the traversal with the container value.
1920
+ unsigned min = std::min (extractPosSize, insertPosSize);
1921
+
1893
1922
// If one is fully prefix of the other, stop propagating back as it will
1894
1923
// miss dependencies. For instance, %3 should not fold to %f0 in the
1895
1924
// following example:
@@ -1900,15 +1929,17 @@ OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
1900
1929
// !llvm.array<4 x !llvm.array<4 x f32>>
1901
1930
// %3 = llvm.extractvalue %2[0, 0] : !llvm.array<4 x !llvm.array<4 x f32>>
1902
1931
// ```
1903
- if (getPosition ().take_front (min) ==
1904
- insertValueOp.getPosition ().take_front (min))
1932
+ if (extractPos.take_front (min) == insertPos.take_front (min))
1905
1933
return result;
1906
-
1907
1934
// If neither a prefix, nor the exact position, we can extract out of the
1908
1935
// value being inserted into. Moreover, we can try again if that operand
1909
1936
// is itself an insertvalue expression.
1910
- getContainerMutable ().assign (insertValueOp.getContainer ());
1911
- result = getResult ();
1937
+ if (!switchedToInsertedValue) {
1938
+ // Do not swap out the container operand if we decided earlier to
1939
+ // continue the traversal with the inserted value (Case 2).
1940
+ getContainerMutable ().assign (insertValueOp.getContainer ());
1941
+ result = getResult ();
1942
+ }
1912
1943
insertValueOp = insertValueOp.getContainer ().getDefiningOp <InsertValueOp>();
1913
1944
}
1914
1945
return result;
0 commit comments