Skip to content

Commit e55e36d

Browse files
authored
[mlir] alloc-to-alloca conversion for memref (#65335)
Introduce a simple conversion of a memref.alloc/dealloc pair into an alloca in the same scope. Expose it as a transform op and a pattern. Allocas typically lower to stack allocations as opposed to alloc/dealloc that lower to significantly more expensive malloc/free calls. In addition, this can be combined with allocation hoisting from loops to further improve performance.
1 parent a4605af commit e55e36d

File tree

7 files changed

+171
-2
lines changed

7 files changed

+171
-2
lines changed

mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,23 @@ def MemrefToLLVMTypeConverterOp : Op<Transform_Dialect,
5151
let assemblyFormat = "attr-dict";
5252
}
5353

54+
def ApplyAllocToAllocaOp : Op<Transform_Dialect,
55+
"apply_patterns.memref.alloc_to_alloca",
56+
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface, ["populatePatternsWithState"]>]> {
57+
let description = [{
58+
Collects patterns to rewrite scoped dynamic allocation (`alloc`/`dealloc`
59+
pairs) into automatic allocation (`alloca`) in the same scope, for memrefs
60+
of static shape.
61+
62+
The `size_limit` attribute controls the maximum allocated memory (in bytes,
63+
subject to data layout) for which the pattern applies.
64+
}];
65+
66+
let arguments = (ins
67+
OptionalAttr<I64Attr>:$size_limit);
68+
let assemblyFormat = "(`size_limit` `(` $size_limit^ `)`)? attr-dict";
69+
}
70+
5471
def ApplyExpandOpsPatternsOp : Op<Transform_Dialect,
5572
"apply_patterns.memref.expand_ops",
5673
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {

mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#define MLIR_DIALECT_MEMREF_TRANSFORMS_TRANSFORMS_H
1616

1717
#include "mlir/Support/LogicalResult.h"
18+
#include "llvm/ADT/STLFunctionalExtras.h"
1819

1920
namespace mlir {
2021
class OpBuilder;
@@ -31,6 +32,7 @@ class NarrowTypeEmulationConverter;
3132
namespace memref {
3233
class AllocOp;
3334
class AllocaOp;
35+
class DeallocOp;
3436

3537
//===----------------------------------------------------------------------===//
3638
// Patterns
@@ -196,6 +198,15 @@ FailureOr<Value> replaceWithIndependentOp(RewriterBase &rewriter,
196198
memref::AllocaOp allocaOp,
197199
ValueRange independencies);
198200

201+
/// Replaces the given `alloc` with the corresponding `alloca` and returns it if
202+
/// the following conditions are met:
203+
/// - the corresponding dealloc is available in the same block as the alloc;
204+
/// - the filter, if provided, succeeds on the alloc/dealloc pair.
205+
/// Otherwise returns nullptr and leaves the IR unchanged.
206+
memref::AllocaOp allocToAlloca(
207+
RewriterBase &rewriter, memref::AllocOp alloc,
208+
function_ref<bool(memref::AllocOp, memref::DeallocOp)> filter = nullptr);
209+
199210
} // namespace memref
200211
} // namespace mlir
201212

mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,18 @@ def PatternDescriptorOpInterface : OpInterface<"PatternDescriptorOpInterface"> {
260260
/*name=*/"populatePatterns",
261261
/*arguments=*/(ins "::mlir::RewritePatternSet &":$patterns)
262262
>,
263+
InterfaceMethod<
264+
/*desc=*/[{
265+
Populate rewrite patterns into the given pattern set taking into account
266+
the transform state.
267+
}],
268+
/*returnType=*/"void",
269+
/*name=*/"populatePatternsWithState",
270+
/*arguments=*/(ins "::mlir::RewritePatternSet &":$patterns,
271+
"::mlir::transform::TransformState &":$state),
272+
/*methodBody=*/"",
273+
/*defaultImplementation=*/[{ $_op.populatePatterns(patterns); }]
274+
>
263275
];
264276
}
265277

mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h"
1010

11+
#include "mlir/Analysis/DataLayoutAnalysis.h"
1112
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
1213
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1314
#include "mlir/Dialect/Arith/IR/Arith.h"
@@ -64,6 +65,42 @@ StringRef transform::MemrefToLLVMTypeConverterOp::getTypeConverterType() {
6465
// Apply...PatternsOp
6566
//===----------------------------------------------------------------------===//
6667

68+
namespace {
69+
class AllocToAllocaPattern : public OpRewritePattern<memref::AllocOp> {
70+
public:
71+
explicit AllocToAllocaPattern(Operation *analysisRoot, int64_t maxSize = 0)
72+
: OpRewritePattern<memref::AllocOp>(analysisRoot->getContext()),
73+
dataLayoutAnalysis(analysisRoot), maxSize(maxSize) {}
74+
75+
LogicalResult matchAndRewrite(memref::AllocOp op,
76+
PatternRewriter &rewriter) const override {
77+
return success(memref::allocToAlloca(
78+
rewriter, op, [this](memref::AllocOp alloc, memref::DeallocOp dealloc) {
79+
MemRefType type = alloc.getMemref().getType();
80+
if (!type.hasStaticShape())
81+
return false;
82+
83+
const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove(alloc);
84+
int64_t elementSize = dataLayout.getTypeSize(type.getElementType());
85+
return maxSize == 0 || type.getNumElements() * elementSize < maxSize;
86+
}));
87+
}
88+
89+
private:
90+
DataLayoutAnalysis dataLayoutAnalysis;
91+
int64_t maxSize;
92+
};
93+
} // namespace
94+
95+
void transform::ApplyAllocToAllocaOp::populatePatterns(
96+
RewritePatternSet &patterns) {}
97+
98+
void transform::ApplyAllocToAllocaOp::populatePatternsWithState(
99+
RewritePatternSet &patterns, transform::TransformState &state) {
100+
patterns.insert<AllocToAllocaPattern>(
101+
state.getTopLevel(), static_cast<int64_t>(getSizeLimit().value_or(0)));
102+
}
103+
67104
void transform::ApplyExpandOpsPatternsOp::populatePatterns(
68105
RewritePatternSet &patterns) {
69106
memref::populateExpandOpsPatterns(patterns);

mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,3 +178,27 @@ FailureOr<Value> memref::replaceWithIndependentOp(RewriterBase &rewriter,
178178
replacement->getDefiningOp());
179179
return replacement;
180180
}
181+
182+
memref::AllocaOp memref::allocToAlloca(
183+
RewriterBase &rewriter, memref::AllocOp alloc,
184+
function_ref<bool(memref::AllocOp, memref::DeallocOp)> filter) {
185+
memref::DeallocOp dealloc = nullptr;
186+
for (Operation &candidate :
187+
llvm::make_range(alloc->getIterator(), alloc->getBlock()->end())) {
188+
dealloc = dyn_cast<memref::DeallocOp>(candidate);
189+
if (dealloc && dealloc.getMemref() == alloc.getMemref() &&
190+
(!filter || filter(alloc, dealloc))) {
191+
break;
192+
}
193+
}
194+
195+
if (!dealloc)
196+
return nullptr;
197+
198+
OpBuilder::InsertionGuard guard(rewriter);
199+
rewriter.setInsertionPoint(alloc);
200+
auto alloca = rewriter.replaceOpWithNewOp<memref::AllocaOp>(
201+
alloc, alloc.getMemref().getType(), alloc.getOperands());
202+
rewriter.eraseOp(dealloc);
203+
return alloca;
204+
}

mlir/lib/Dialect/Transform/IR/TransformOps.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -378,8 +378,8 @@ DiagnosedSilenceableFailure transform::ApplyPatternsOp::applyToOne(
378378
RewritePatternSet patterns(ctx);
379379
if (!getRegion().empty()) {
380380
for (Operation &op : getRegion().front()) {
381-
cast<transform::PatternDescriptorOpInterface>(&op).populatePatterns(
382-
patterns);
381+
cast<transform::PatternDescriptorOpInterface>(&op)
382+
.populatePatternsWithState(patterns, state);
383383
}
384384
}
385385

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
// RUN: mlir-opt --test-transform-dialect-interpreter=debug-transform-root-tag=all %s | FileCheck %s --check-prefixes=CHECK,ALL
2+
// RUN: mlir-opt --test-transform-dialect-interpreter=debug-transform-root-tag=small %s | FileCheck %s --check-prefixes=CHECK,SMALL
3+
4+
func.func private @callee(memref<*xf32>)
5+
6+
// CHECK-LABEL: @large_alloc
7+
func.func @large_alloc() {
8+
// SMALL: memref.alloc()
9+
// ALL: memref.alloca
10+
%0 = memref.alloc() : memref<100x100xf32>
11+
%1 = memref.cast %0 : memref<100x100xf32> to memref<*xf32>
12+
call @callee(%1) : (memref<*xf32>) -> ()
13+
// SMALL: memref.dealloc
14+
// ALL-NOT: memref.dealloc
15+
memref.dealloc %0 : memref<100x100xf32>
16+
return
17+
}
18+
19+
// CHECK-LABEL: @small_alloc
20+
func.func @small_alloc() {
21+
// CHECK: memref.alloca
22+
%0 = memref.alloc() : memref<2x2xf32>
23+
%1 = memref.cast %0 : memref<2x2xf32> to memref<*xf32>
24+
call @callee(%1) : (memref<*xf32>) -> ()
25+
// CHECK-NOT: memref.dealloc
26+
memref.dealloc %0 : memref<2x2xf32>
27+
return
28+
}
29+
30+
// CHECK-LABEL: @no_dealloc
31+
func.func @no_dealloc() {
32+
// CHECK: memref.alloc()
33+
%0 = memref.alloc() : memref<2x2xf32>
34+
%1 = memref.cast %0 : memref<2x2xf32> to memref<*xf32>
35+
call @callee(%1) : (memref<*xf32>) -> ()
36+
return
37+
}
38+
39+
// CHECK-LABEL: @mismatching_scope
40+
func.func @mismatching_scope() {
41+
// CHECK: memref.alloc()
42+
%0 = memref.alloc() : memref<2x2xf32>
43+
%1 = memref.cast %0 : memref<2x2xf32> to memref<*xf32>
44+
call @callee(%1) : (memref<*xf32>) -> ()
45+
scf.execute_region {
46+
memref.dealloc %0 : memref<2x2xf32>
47+
scf.yield
48+
}
49+
return
50+
}
51+
52+
transform.sequence failures(propagate) attributes {transform.target_tag = "all"} {
53+
^bb0(%arg0: !transform.any_op):
54+
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
55+
transform.apply_patterns to %0 {
56+
transform.apply_patterns.memref.alloc_to_alloca
57+
} : !transform.any_op
58+
transform.yield
59+
}
60+
61+
transform.sequence failures(propagate) attributes {transform.target_tag = "small"} {
62+
^bb0(%arg0: !transform.any_op):
63+
%0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
64+
transform.apply_patterns to %0 {
65+
transform.apply_patterns.memref.alloc_to_alloca size_limit(32)
66+
} : !transform.any_op
67+
transform.yield
68+
}

0 commit comments

Comments
 (0)