[mlir][Bufferization] Add support for controlled bufferization of alloc_tensor #70957

Merged

14 changes: 14 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -12,6 +12,7 @@
#include <utility>

#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
@@ -28,6 +29,7 @@

namespace mlir {
namespace bufferization {
class AllocTensorOp;
class OneShotAnalysisState;
} // namespace bufferization

@@ -110,6 +112,18 @@ Value bufferizeToAllocation(RewriterBase &rewriter,
vector::MaskOp maskOp, Attribute memorySpace = {},
Operation *insertionPoint = nullptr);

/// Materialize a buffer allocation for the given bufferization.alloc_tensor op
/// and lower the op to that allocation + bufferization.to_tensor (restrict, writable).
///
/// In addition to rewriting the IR, this function returns the newly allocated
/// buffer. The `insertionPoint` parameter can be used to specify a custom
/// insertion point for the buffer allocation.
Value bufferizeToAllocation(RewriterBase &rewriter,
const BufferizeToAllocationOptions &options,
bufferization::AllocTensorOp allocTensorOp,
Attribute memorySpace = {},
Operation *insertionPoint = nullptr);

/// Bufferize the given op with tensor semantics and materialize the result in
/// a newly allocated buffer.
///
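For reference, here is a minimal C++ sketch of driving the new overload directly, outside the transform dialect. The helper below is hypothetical and not part of this patch; it assumes BufferizeToAllocationOptions exposes an allocOp choice between memref.alloc and memref.alloca, mirroring the alloc_op attribute used in the test further down.

// Editorial sketch (not part of this patch): lower every
// bufferization.alloc_tensor in a function to a stack allocation via the new
// overload. `lowerAllocTensorsToAlloca` and the `allocOp` option choice are
// assumptions for illustration only.
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

static void lowerAllocTensorsToAlloca(func::FuncOp func) {
  IRRewriter rewriter(func.getContext());
  linalg::BufferizeToAllocationOptions options;
  options.allocOp =
      linalg::BufferizeToAllocationOptions::AllocOp::MemrefAlloca;

  // Collect first; the rewrite replaces (and erases) the matched ops.
  SmallVector<bufferization::AllocTensorOp> worklist;
  func.walk([&](bufferization::AllocTensorOp op) { worklist.push_back(op); });

  for (bufferization::AllocTensorOp op : worklist)
    (void)linalg::bufferizeToAllocation(rewriter, options, op,
                                        /*memorySpace=*/Attribute());
}
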
@@ -317,6 +317,27 @@ Value linalg::bufferizeToAllocation(
return alloc;
}

Value linalg::bufferizeToAllocation(
RewriterBase &rewriter, const linalg::BufferizeToAllocationOptions &options,
bufferization::AllocTensorOp allocTensorOp, Attribute memorySpace,
Operation *insertionPoint) {
Location loc = allocTensorOp.getLoc();
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(insertionPoint ? insertionPoint : allocTensorOp);
bufferization::BufferizationOptions bufferizationOptions;

// Create buffer allocation.
Value alloc = createAllocationForTensor(
rewriter, loc, allocTensorOp.getResult(), options, memorySpace);

// Create bufferization.to_tensor with "restrict" and "writable". The returned
// tensor is a new buffer allocation, so it does not alias with any buffer.
Value toTensorOp = rewriter.create<bufferization::ToTensorOp>(
loc, alloc, /*restrict=*/true, /*writable=*/true);
rewriter.replaceOp(allocTensorOp, toTensorOp);
return alloc;
}

/// Lower tensor.from_elements to a sequence of chained tensor.insert.
FailureOr<Operation *> mlir::linalg::rewriteInDestinationPassingStyle(
RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp) {
@@ -454,6 +475,8 @@ Value linalg::bufferizeToAllocation(
return bufferizeToAllocation(rewriter, options, padOp, memorySpace);
if (auto maskOp = dyn_cast<vector::MaskOp>(op))
return bufferizeToAllocation(rewriter, options, maskOp, memorySpace);
if (auto allocTensorOp = dyn_cast<bufferization::AllocTensorOp>(op))
return bufferizeToAllocation(rewriter, options, allocTensorOp, memorySpace);

// Only bufferizable ops are supported.
auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
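The insertionPoint parameter documented in the new declaration lets a caller place the allocation somewhere other than directly in front of the alloc_tensor. A hedged sketch (hypothetical helper, same includes as the sketch above, assumes a statically shaped alloc_tensor so the allocation needs no size operands at the hoisted position):

// Editorial sketch (not part of this patch): hoist the buffer for a
// statically shaped bufferization.alloc_tensor to the start of the entry
// block of the enclosing func.func. With dynamic sizes, the size operands
// would also have to dominate the insertion point.
static Value hoistAllocTensorBuffer(RewriterBase &rewriter,
                                    bufferization::AllocTensorOp op) {
  auto func = op->getParentOfType<func::FuncOp>();
  assert(func && "expected alloc_tensor to be nested in a func.func");
  Operation *insertionPoint = &func.getBody().front().front();
  linalg::BufferizeToAllocationOptions options; // default options (assumed to pick memref.alloc)
  return linalg::bufferizeToAllocation(rewriter, options, op,
                                       /*memorySpace=*/Attribute(),
                                       insertionPoint);
}
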
23 changes: 23 additions & 0 deletions mlir/test/Dialect/Bufferization/Transforms/transform-ops.mlir
@@ -215,3 +215,26 @@ func.func @buffer_loop_hoisting(%lb: index, %ub: index, %step: index, %f: f32, %
}
return
}

// -----

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%alloc_tensor = transform.structured.match ops{["bufferization.alloc_tensor"]} in %arg1
: (!transform.any_op) -> !transform.op<"bufferization.alloc_tensor">
%2, %new = transform.structured.bufferize_to_allocation %alloc_tensor
{alloc_op = "memref.alloca"}
: !transform.op<"bufferization.alloc_tensor">
transform.yield
}
}

// Expect `transform.structured.bufferize_to_allocation` to replace the alloc_tensor with a memref.alloca.
// CHECK-LABEL: func.func @empty_to_tensor_alloc()
func.func @empty_to_tensor_alloc() -> tensor<2x2xf32> {
// CHECK-NEXT: %[[alloca:.*]] = memref.alloca() : memref<2x2xf32>
// CHECK-NEXT: %[[tensor:.*]] = bufferization.to_tensor %[[alloca]] restrict writable : memref<2x2xf32>
// CHECK-NEXT: return %[[tensor]] : tensor<2x2xf32>
%0 = bufferization.alloc_tensor() : tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}