File tree Expand file tree Collapse file tree 3 files changed +39
-0
lines changed
include/mlir/Dialect/StandardOps/IR
lib/Dialect/StandardOps/IR Expand file tree Collapse file tree 3 files changed +39
-0
lines changed Original file line number Diff line number Diff line change @@ -2234,6 +2234,7 @@ def LoadOp : Std_Op<"load",
2234
2234
operand_range getIndices() { return {operand_begin() + 1, operand_end()}; }
2235
2235
}];
2236
2236
2237
+ let hasCanonicalizer = 1;
2237
2238
let hasFolder = 1;
2238
2239
2239
2240
let assemblyFormat = "$memref `[` $indices `]` attr-dict `:` type($memref)";
Original file line number Diff line number Diff line change @@ -2293,6 +2293,30 @@ OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
2293
2293
return OpFoldResult ();
2294
2294
}
2295
2295
2296
+ namespace {
2297
+ // / Fold a load on a tensor_to_memref operation into an extract_element on the
2298
+ // / corresponding tensor.
2299
+ struct LoadOfTensorToMemref : public OpRewritePattern <LoadOp> {
2300
+ using OpRewritePattern<LoadOp>::OpRewritePattern;
2301
+
2302
+ LogicalResult matchAndRewrite (LoadOp load,
2303
+ PatternRewriter &rewriter) const override {
2304
+ auto tensorToMemref = load.memref ().getDefiningOp <TensorToMemrefOp>();
2305
+ if (!tensorToMemref)
2306
+ return failure ();
2307
+
2308
+ rewriter.replaceOpWithNewOp <ExtractElementOp>(load, tensorToMemref.tensor (),
2309
+ load.indices ());
2310
+ return success ();
2311
+ }
2312
+ };
2313
+ } // end anonymous namespace.
2314
+
2315
+ void LoadOp::getCanonicalizationPatterns (OwningRewritePatternList &results,
2316
+ MLIRContext *context) {
2317
+ results.insert <LoadOfTensorToMemref>(context);
2318
+ }
2319
+
2296
2320
// ===----------------------------------------------------------------------===//
2297
2321
// MemRefCastOp
2298
2322
// ===----------------------------------------------------------------------===//
Original file line number Diff line number Diff line change @@ -45,6 +45,20 @@ func @dim_of_tensor_load(%arg0: memref<?xf32>) -> index {
45
45
return %1 : index
46
46
}
47
47
48
+ // Test case: Folding of load(tensor_to_memref(%v, %idxs))
49
+ // -> extract_element(%v, %idx)
50
+ // CHECK-LABEL: func @load_from_tensor_to_memref(
51
+ // CHECK-SAME: %[[IDX0:[0-9a-z]+]]: index, %[[IDX1:[0-9a-z]+]]: index
52
+ // CHECK-SAME: %[[TENSOR:[0-9a-z]+]]: tensor<?x?xf32>
53
+ // CHECK: %[[RES:.*]] = extract_element %[[TENSOR]][%[[IDX0]], %[[IDX1]]]
54
+ // CHECK-NOT: load
55
+ // CHECK: return %[[RES]] : f32
56
+ func @load_from_tensor_to_memref (%arg0: index , %arg1: index , %arg2: tensor <?x?xf32 >) -> f32 {
57
+ %0 = tensor_to_memref %arg2 : memref <?x?xf32 >
58
+ %1 = load %0 [%arg0 , %arg1 ] : memref <?x?xf32 >
59
+ return %1 : f32
60
+ }
61
+
48
62
// Test case: Folding of dim(dynamic_tensor_from_elements %idx) -> %idx
49
63
// CHECK-LABEL: func @dim_of_dynamic_tensor_from_elements(
50
64
// CHECK-SAME: %[[IDX0:[0-9a-z]+]]: index, %[[IDX1:[0-9a-z]+]]: index
You can’t perform that action at this time.
0 commit comments