Skip to content

Commit 6af81ea

Browse files
author
Stephan Herhut
committed
[mlir][std] Fold load(tensor_to_memref) into extract_element
This canonicalization is useful to resolve loads into scalar values when doing partial bufferization. Differential Revision: https://reviews.llvm.org/D91855
1 parent ffb3fd8 commit 6af81ea

File tree

3 files changed

+39
-0
lines changed

3 files changed

+39
-0
lines changed

mlir/include/mlir/Dialect/StandardOps/IR/Ops.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2234,6 +2234,7 @@ def LoadOp : Std_Op<"load",
22342234
operand_range getIndices() { return {operand_begin() + 1, operand_end()}; }
22352235
}];
22362236

2237+
let hasCanonicalizer = 1;
22372238
let hasFolder = 1;
22382239

22392240
let assemblyFormat = "$memref `[` $indices `]` attr-dict `:` type($memref)";

mlir/lib/Dialect/StandardOps/IR/Ops.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2293,6 +2293,30 @@ OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
22932293
return OpFoldResult();
22942294
}
22952295

2296+
namespace {
2297+
/// Fold a load on a tensor_to_memref operation into an extract_element on the
2298+
/// corresponding tensor.
2299+
struct LoadOfTensorToMemref : public OpRewritePattern<LoadOp> {
2300+
using OpRewritePattern<LoadOp>::OpRewritePattern;
2301+
2302+
LogicalResult matchAndRewrite(LoadOp load,
2303+
PatternRewriter &rewriter) const override {
2304+
auto tensorToMemref = load.memref().getDefiningOp<TensorToMemrefOp>();
2305+
if (!tensorToMemref)
2306+
return failure();
2307+
2308+
rewriter.replaceOpWithNewOp<ExtractElementOp>(load, tensorToMemref.tensor(),
2309+
load.indices());
2310+
return success();
2311+
}
2312+
};
2313+
} // end anonymous namespace.
2314+
2315+
void LoadOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
2316+
MLIRContext *context) {
2317+
results.insert<LoadOfTensorToMemref>(context);
2318+
}
2319+
22962320
//===----------------------------------------------------------------------===//
22972321
// MemRefCastOp
22982322
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Standard/canonicalize.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,20 @@ func @dim_of_tensor_load(%arg0: memref<?xf32>) -> index {
4545
return %1 : index
4646
}
4747

48+
// Test case: Folding of load(tensor_to_memref(%v, %idxs))
49+
// -> extract_element(%v, %idx)
50+
// CHECK-LABEL: func @load_from_tensor_to_memref(
51+
// CHECK-SAME: %[[IDX0:[0-9a-z]+]]: index, %[[IDX1:[0-9a-z]+]]: index
52+
// CHECK-SAME: %[[TENSOR:[0-9a-z]+]]: tensor<?x?xf32>
53+
// CHECK: %[[RES:.*]] = extract_element %[[TENSOR]][%[[IDX0]], %[[IDX1]]]
54+
// CHECK-NOT: load
55+
// CHECK: return %[[RES]] : f32
56+
func @load_from_tensor_to_memref(%arg0: index, %arg1: index, %arg2: tensor<?x?xf32>) -> f32 {
57+
%0 = tensor_to_memref %arg2 : memref<?x?xf32>
58+
%1 = load %0[%arg0, %arg1] : memref<?x?xf32>
59+
return %1 : f32
60+
}
61+
4862
// Test case: Folding of dim(dynamic_tensor_from_elements %idx) -> %idx
4963
// CHECK-LABEL: func @dim_of_dynamic_tensor_from_elements(
5064
// CHECK-SAME: %[[IDX0:[0-9a-z]+]]: index, %[[IDX1:[0-9a-z]+]]: index

0 commit comments

Comments
 (0)