Skip to content

Commit c2e5142

Browse files
wsmosesDinistro
andauthored
[MLIR][LLVM] Fold extract of constant (#127927)
Co-authored-by: Christian Ulmann <[email protected]>
1 parent 1b610e6 commit c2e5142

File tree

2 files changed

+34
-0
lines changed

2 files changed

+34
-0
lines changed

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1932,6 +1932,18 @@ OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
19321932
getContainerMutable().set(extractValueOp.getContainer());
19331933
return getResult();
19341934
}
1935+
1936+
{
1937+
DenseElementsAttr constval;
1938+
matchPattern(getContainer(), m_Constant(&constval));
1939+
if (constval && constval.getElementType() == getType()) {
1940+
if (isa<SplatElementsAttr>(constval))
1941+
return constval.getSplatValue<Attribute>();
1942+
if (getPosition().size() == 1)
1943+
return constval.getValues<Attribute>()[getPosition()[0]];
1944+
}
1945+
}
1946+
19351947
auto insertValueOp = getContainer().getDefiningOp<InsertValueOp>();
19361948
OpFoldResult result = {};
19371949
while (insertValueOp) {

mlir/test/Dialect/LLVMIR/canonicalize.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,28 @@ llvm.func @fold_extract_extractvalue(%arr: !llvm.struct<(i64, array<1 x ptr<1>>)
9696

9797
// -----
9898

99+
// CHECK-LABEL: fold_extract_const
100+
// CHECK-NOT: extractvalue
101+
// CHECK: llvm.mlir.constant(5.000000e-01 : f64)
102+
llvm.func @fold_extract_const() -> f64 {
103+
%a = llvm.mlir.constant(dense<[-8.900000e+01, 5.000000e-01]> : tensor<2xf64>) : !llvm.array<2 x f64>
104+
%b = llvm.extractvalue %a[1] : !llvm.array<2 x f64>
105+
llvm.return %b : f64
106+
}
107+
108+
// -----
109+
110+
// CHECK-LABEL: fold_extract_splat
111+
// CHECK-NOT: extractvalue
112+
// CHECK: llvm.mlir.constant(-8.900000e+01 : f64)
113+
llvm.func @fold_extract_splat() -> f64 {
114+
%a = llvm.mlir.constant(dense<-8.900000e+01> : tensor<2xf64>) : !llvm.array<2 x f64>
115+
%b = llvm.extractvalue %a[1] : !llvm.array<2 x f64>
116+
llvm.return %b : f64
117+
}
118+
119+
// -----
120+
99121
// CHECK-LABEL: fold_bitcast
100122
// CHECK-SAME: %[[ARG:[[:alnum:]]+]]
101123
// CHECK-NEXT: llvm.return %[[ARG]]

0 commit comments

Comments
 (0)