Skip to content

Commit 3a223f4

Browse files
[mlir][Bufferization] Add support for controlled bufferization of alloc_tensor (#70957)
This revision adds support to `transform.structured.bufferize_to_allocation` to bufferize `bufferization.alloc_tensor()` ops. This is useful as a means path to control the bufferization of `tensor.empty` ops that have bene previously `bufferization.empty_tensor_to_alloc_tensor`'ed.
1 parent 65bad23 commit 3a223f4

File tree

3 files changed

+60
-0
lines changed

3 files changed

+60
-0
lines changed

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include <utility>
1313

1414
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
15+
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
1516
#include "mlir/Dialect/Linalg/Utils/Utils.h"
1617
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1718
#include "mlir/Dialect/SCF/Utils/Utils.h"
@@ -28,6 +29,7 @@
2829

2930
namespace mlir {
3031
namespace bufferization {
32+
class AllocTensorOp;
3133
class OneShotAnalysisState;
3234
} // namespace bufferization
3335

@@ -110,6 +112,18 @@ Value bufferizeToAllocation(RewriterBase &rewriter,
110112
vector::MaskOp maskOp, Attribute memorySpace = {},
111113
Operation *insertionPoint = nullptr);
112114

115+
/// Materialize a buffer allocation for the given bufferization.alloc_tensor op
116+
/// and lower the op to memref.alloc + memref.tensor_store.
117+
///
118+
/// In addition to rewriting the IR, this function returns the newly allocated
119+
/// buffer. The `insertionPoint` parameter can be used to specify a custom
120+
/// insertion point for the buffer allocation.
121+
Value bufferizeToAllocation(RewriterBase &rewriter,
122+
const BufferizeToAllocationOptions &options,
123+
bufferization::AllocTensorOp allocTensorOp,
124+
Attribute memorySpace = {},
125+
Operation *insertionPoint = nullptr);
126+
113127
/// Bufferize the given op with tensor semantics and materialize the result in
114128
/// a newly allocated buffer.
115129
///

mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,27 @@ Value linalg::bufferizeToAllocation(
317317
return alloc;
318318
}
319319

320+
Value linalg::bufferizeToAllocation(
321+
RewriterBase &rewriter, const linalg::BufferizeToAllocationOptions &options,
322+
bufferization::AllocTensorOp allocTensorOp, Attribute memorySpace,
323+
Operation *insertionPoint) {
324+
Location loc = allocTensorOp.getLoc();
325+
OpBuilder::InsertionGuard g(rewriter);
326+
rewriter.setInsertionPoint(insertionPoint ? insertionPoint : allocTensorOp);
327+
bufferization::BufferizationOptions bufferizationOptions;
328+
329+
// Create buffer allocation.
330+
Value alloc = createAllocationForTensor(
331+
rewriter, loc, allocTensorOp.getResult(), options, memorySpace);
332+
333+
// Create bufferization.to_tensor with "restrict" and "writable". The returned
334+
// tensor is a new buffer allocation, so it does not alias with any buffer.
335+
Value toTensorOp = rewriter.create<bufferization::ToTensorOp>(
336+
loc, alloc, /*restrict=*/true, /*writable=*/true);
337+
rewriter.replaceOp(allocTensorOp, toTensorOp);
338+
return alloc;
339+
}
340+
320341
/// Lower tensor.from_elements to a sequence of chained tensor.insert.
321342
FailureOr<Operation *> mlir::linalg::rewriteInDestinationPassingStyle(
322343
RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp) {
@@ -454,6 +475,8 @@ Value linalg::bufferizeToAllocation(
454475
return bufferizeToAllocation(rewriter, options, padOp, memorySpace);
455476
if (auto maskOp = dyn_cast<vector::MaskOp>(op))
456477
return bufferizeToAllocation(rewriter, options, maskOp, memorySpace);
478+
if (auto allocTensorOp = dyn_cast<bufferization::AllocTensorOp>(op))
479+
return bufferizeToAllocation(rewriter, options, allocTensorOp, memorySpace);
457480

458481
// Only bufferizable ops are supported.
459482
auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);

mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,3 +215,26 @@ func.func @buffer_loop_hoisting(%lb: index, %ub: index, %step: index, %f: f32, %
215215
}
216216
return
217217
}
218+
219+
// -----
220+
221+
module attributes {transform.with_named_sequence} {
222+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
223+
%alloc_tensor = transform.structured.match ops{["bufferization.alloc_tensor"]} in %arg1
224+
: (!transform.any_op) -> !transform.op<"bufferization.alloc_tensor">
225+
%2, %new = transform.structured.bufferize_to_allocation %alloc_tensor
226+
{alloc_op = "memref.alloca"}
227+
: !transform.op<"bufferization.alloc_tensor">
228+
transform.yield
229+
}
230+
}
231+
232+
// Expect `bufferization.bufferize_to_allocation` to create an alloc.
233+
// CHECK-LABEL: func.func @empty_to_tensor_alloc()
234+
func.func @empty_to_tensor_alloc() -> tensor<2x2xf32> {
235+
// CHECK-NEXT: %[[alloca:.*]] = memref.alloca() : memref<2x2xf32>
236+
// CHECK-NEXT: %[[tensor:.*]] = bufferization.to_tensor %[[alloca]] restrict writable : memref<2x2xf32>
237+
// CHECK-NEXT: return %[[tensor]] : tensor<2x2xf32>
238+
%0 = bufferization.alloc_tensor() : tensor<2x2xf32>
239+
return %0 : tensor<2x2xf32>
240+
}

0 commit comments

Comments
 (0)