Skip to content

Commit aa08e66

Browse files
committed
[MLIR][LLVM] Fold extract of constant
1 parent 131a3cf commit aa08e66

File tree

2 files changed

+33
-0
lines changed

2 files changed

+33
-0
lines changed

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1932,6 +1932,19 @@ OpFoldResult LLVM::ExtractValueOp::fold(FoldAdaptor adaptor) {
19321932
getContainerMutable().set(extractValueOp.getContainer());
19331933
return getResult();
19341934
}
1935+
1936+
{
1937+
DenseElementsAttr constval;
1938+
matchPattern(getContainer().getOperand(), m_Constant(&constval));
1939+
if (constval && constval.getElementType() == getType()) {
1940+
if (constval.isSplat()) {
1941+
return constval.getSplatValue();
1942+
} else if (getPosition().size() == 1) {
1943+
return constval.getValues()[getPosition()[0]];
1944+
}
1945+
}
1946+
}
1947+
19351948
auto insertValueOp = getContainer().getDefiningOp<InsertValueOp>();
19361949
OpFoldResult result = {};
19371950
while (insertValueOp) {

mlir/test/Dialect/LLVMIR/canonicalize.mlir

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,26 @@ llvm.func @fold_extract_extractvalue(%arr: !llvm.struct<(i64, array<1 x ptr<1>>)
9494
llvm.return %b : !llvm.ptr<1>
9595
}
9696

97+
// -----
98+
// CHECK-LABEL: fold_extract_const
99+
// CHECK-NOT: extractvalue
100+
// CHECK: llvm.mlir.constant(-5.0)
101+
llvm.func @fold_extract_const() -> f64 {
102+
%a = llvm.mlir.constant(dense<[-8.900000e+01, 5.000000e-01]> : tensor<2xf64>) : !llvm.array<2 x f64>
103+
%b = llvm.extractvalue %a[1] : !llvm.array<2 x f64>
104+
llvm.return %b : f64
105+
}
106+
107+
// -----
108+
// CHECK-LABEL: fold_extract_splat
109+
// CHECK-NOT: extractvalue
110+
// CHECK: llvm.mlir.constant(-8.9)
111+
llvm.func @fold_extract_const() -> f64 {
112+
%a = llvm.mlir.constant(dense<[-8.900000e+01> : tensor<2xf64>) : !llvm.array<2 x f64>
113+
%b = llvm.extractvalue %a[1] : !llvm.array<2 x f64>
114+
llvm.return %b : f64
115+
}
116+
97117
// -----
98118

99119
// CHECK-LABEL: fold_bitcast

0 commit comments

Comments
 (0)