Skip to content

Commit 3f6c0fb

Browse files
[mlir][linalg][bufferize] Add MemCpyFn to AllocationCallbacks struct
This in preparation of decoupling BufferizableOpInterface, Comprehensive Bufferize and dialects. The goal of this CL is to make `getResultBuffer` (and other `bufferize` functions) independent of `LinalgOps`. Differential Revision: https://reviews.llvm.org/D112907
1 parent 6c6ccc7 commit 3f6c0fb

File tree

2 files changed

+28
-9
lines changed

2 files changed

+28
-9
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/ComprehensiveBufferize.h

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -172,16 +172,28 @@ Optional<Value> defaultAllocationFn(OpBuilder &b, Location loc,
172172
/// `defaultAllocationFn`.
173173
void defaultDeallocationFn(OpBuilder &b, Location loc, Value allocatedBuffer);
174174

175+
/// Default memory copy function that is used by the comprehensive bufferization
176+
/// pass. Creates a `linalg.copy` op.
177+
void defaultMemCpyFn(OpBuilder &b, Location loc, Value from, Value to);
178+
175179
/// Callback functions that are used by the comprehensive bufferization pass to
176180
/// allocate/deallocate memory. These default to use the
177181
/// `defaultAllocationFn`/`defaultDeallocationFn`, but can be overridden by the
178182
/// caller. The `deallocationFn` is gauranteed to recieve the `Value` returned
179183
/// by the `allocationFn`.
180184
struct AllocationCallbacks {
181-
std::function<Optional<Value>(OpBuilder &b, Location loc, Value shapedValue)>
182-
allocationFn = defaultAllocationFn;
183-
std::function<void(OpBuilder &b, Location loc, Value v)> deallocationFn =
184-
defaultDeallocationFn;
185+
using AllocationFn =
186+
std::function<Optional<Value>(OpBuilder &, Location, Value)>;
187+
using DeallocationFn = std::function<void(OpBuilder &, Location, Value)>;
188+
using MemCpyFn = std::function<void(OpBuilder &, Location, Value, Value)>;
189+
190+
AllocationCallbacks(AllocationFn allocFn, DeallocationFn deallocFn,
191+
MemCpyFn copyFn)
192+
: allocationFn(allocFn), deallocationFn(deallocFn), memCpyFn(copyFn) {}
193+
194+
AllocationFn allocationFn;
195+
DeallocationFn deallocationFn;
196+
MemCpyFn memCpyFn;
185197
};
186198

187199
/// Bufferize one particular op.

mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1274,7 +1274,7 @@ static Value getResultBuffer(OpBuilder &b, OpResult result,
12741274
if (!skipCopy) {
12751275
// Set insertion point now that potential alloc/dealloc are introduced.
12761276
b.setInsertionPoint(op);
1277-
b.create<CopyOp>(loc, operandBuffer, resultBuffer);
1277+
allocationFns.memCpyFn(b, loc, operandBuffer, resultBuffer);
12781278
}
12791279
return resultBuffer;
12801280
}
@@ -1669,6 +1669,11 @@ void mlir::linalg::defaultDeallocationFn(OpBuilder &b, Location loc,
16691669
b.create<memref::DeallocOp>(loc, allocatedBuffer);
16701670
}
16711671

1672+
void mlir::linalg::defaultMemCpyFn(OpBuilder &b, Location loc, Value from,
1673+
Value to) {
1674+
b.create<CopyOp>(loc, from, to);
1675+
}
1676+
16721677
LogicalResult mlir::linalg::bufferizeOp(
16731678
Operation *op, BlockAndValueMapping &bvm, BufferizationAliasInfo &aliasInfo,
16741679
AllocationCallbacks allocationFns,
@@ -2258,11 +2263,13 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
22582263
// command line option. So this is set up at the start of the pass.
22592264
if (useAlloca) {
22602265
AllocationCallbacks allocaAllocationFns = {
2261-
allocationFnUsingAlloca, [](OpBuilder &b, Location loc, Value v) {}};
2266+
allocationFnUsingAlloca, [](OpBuilder &b, Location loc, Value v) {},
2267+
defaultMemCpyFn};
22622268
allocationFns =
22632269
std::make_unique<AllocationCallbacks>(std::move(allocaAllocationFns));
22642270
} else {
2265-
allocationFns = std::make_unique<AllocationCallbacks>();
2271+
allocationFns = std::make_unique<AllocationCallbacks>(
2272+
defaultAllocationFn, defaultDeallocationFn, defaultMemCpyFn);
22662273
}
22672274
}
22682275
ModuleOp moduleOp = getOperation();
@@ -3222,7 +3229,7 @@ struct ExtractSliceOpInterface
32223229
if (alloc) {
32233230
// Do not copy if the copied data is never read.
32243231
if (isValueRead(extractSliceOp.result()))
3225-
b.create<CopyOp>(extractSliceOp.getLoc(), subView, alloc);
3232+
allocationFn.memCpyFn(b, extractSliceOp.getLoc(), subView, alloc);
32263233
subView = alloc;
32273234
}
32283235

@@ -3344,7 +3351,7 @@ struct InsertSliceOpInterface
33443351
insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
33453352
// Insert new alias.
33463353
aliasInfo.insertNewBufferAlias(subView, dstMemref);
3347-
b.create<CopyOp>(insertSliceOp.getLoc(), srcMemref, subView);
3354+
allocationFn.memCpyFn(b, insertSliceOp.getLoc(), srcMemref, subView);
33483355
}
33493356

33503357
map(bvm, insertSliceOp.result(), dstMemref);

0 commit comments

Comments
 (0)