Skip to content

Commit 461dafd

Browse files
[mlir][bufferization] Add OneShotBufferize transform op
This commit allows for One-Shot Bufferize to be used through the transform dialect. No op handle is currently returned for the bufferized IR. Differential Revision: https://reviews.llvm.org/D125098
1 parent cedfb54 commit 461dafd

File tree

10 files changed

+391
-0
lines changed

10 files changed

+391
-0
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
add_subdirectory(IR)
2+
add_subdirectory(TransformOps)
23
add_subdirectory(Transforms)
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
//===- BufferizationTransformOps.h - Buff. transf. ops ----------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_BUFFERIZATION_TRANSFORMOPS_BUFFERIZATIONTRANSFORMOPS_H
10+
#define MLIR_DIALECT_BUFFERIZATION_TRANSFORMOPS_BUFFERIZATIONTRANSFORMOPS_H
11+
12+
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
13+
#include "mlir/IR/OpImplementation.h"
14+
15+
//===----------------------------------------------------------------------===//
16+
// Bufferization Transform Operations
17+
//===----------------------------------------------------------------------===//
18+
19+
#define GET_OP_CLASSES
20+
#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h.inc"
21+
22+
namespace mlir {
23+
class DialectRegistry;
24+
25+
namespace bufferization {
26+
void registerTransformDialectExtension(DialectRegistry &registry);
27+
} // namespace bufferization
28+
} // namespace mlir
29+
30+
#endif // MLIR_DIALECT_BUFFERIZATION_TRANSFORMOPS_BUFFERIZATIONTRANSFORMOPS_H
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
//===- BufferizationTransformOps.td - Buff. transf. ops ----*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef BUFFERIZATION_TRANSFORM_OPS
10+
#define BUFFERIZATION_TRANSFORM_OPS
11+
12+
include "mlir/Dialect/Transform/IR/TransformDialect.td"
13+
include "mlir/Dialect/Transform/IR/TransformEffects.td"
14+
include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
15+
include "mlir/Dialect/PDL/IR/PDLTypes.td"
16+
include "mlir/Interfaces/SideEffectInterfaces.td"
17+
include "mlir/IR/OpBase.td"
18+
19+
def OneShotBufferizeOp
20+
: Op<Transform_Dialect, "bufferization.one_shot_bufferize",
21+
[DeclareOpInterfaceMethods<TransformOpInterface>,
22+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
23+
let description = [{
24+
Indicates that the given `target` op should be bufferized with One-Shot
25+
Bufferize. The bufferization can be configured with various attributes that
26+
corresponding to options in `BufferizationOptions` and the
27+
`one-shot-bufferize` pass. More information can be found in the pass
28+
documentation.
29+
30+
If `target_is_module` is set, `target` must be a module. In that case the
31+
`target` handle can be reused by other transform ops. When bufferizing other
32+
ops, the `target` handled is freed after bufferization and can no longer be
33+
used.
34+
35+
Note: Only ops that implement `BufferizableOpInterface` are bufferized. All
36+
other ops are ignored if `allow_unknown_ops`. If `allow_unknown_ops` is
37+
unset, this transform fails when an unknown/non-bufferizable op is found.
38+
Many ops implement `BufferizableOpInterface` via an external model. These
39+
external models must be registered when applying this transform op;
40+
otherwise, said ops would be considered non-bufferizable.
41+
}];
42+
43+
let arguments = (
44+
ins PDL_Operation:$target,
45+
DefaultValuedAttr<BoolAttr, "false">:$allow_return_allocs,
46+
DefaultValuedAttr<BoolAttr, "false">:$allow_unknown_ops,
47+
DefaultValuedAttr<BoolAttr, "false">:$bufferize_function_boundaries,
48+
DefaultValuedAttr<BoolAttr, "true">:$create_deallocs,
49+
DefaultValuedAttr<BoolAttr, "true">:$target_is_module,
50+
DefaultValuedAttr<BoolAttr, "false">:$test_analysis_only,
51+
DefaultValuedAttr<BoolAttr, "false">:$print_conflicts);
52+
53+
let results = (outs);
54+
55+
let assemblyFormat = "$target attr-dict";
56+
}
57+
58+
#endif // BUFFERIZATION_TRANSFORM_OPS
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
set(LLVM_TARGET_DEFINITIONS BufferizationTransformOps.td)
2+
mlir_tablegen(BufferizationTransformOps.h.inc -gen-op-decls)
3+
mlir_tablegen(BufferizationTransformOps.cpp.inc -gen-op-defs)
4+
add_public_tablegen_target(MLIRBufferizationTransformOpsIncGen)

mlir/include/mlir/InitAllDialects.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
2424
#include "mlir/Dialect/Async/IR/Async.h"
2525
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
26+
#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
2627
#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
2728
#include "mlir/Dialect/Complex/IR/Complex.h"
2829
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
@@ -107,6 +108,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
107108
// clang-format on
108109

109110
// Register all dialect extensions.
111+
bufferization::registerTransformDialectExtension(registry);
110112
linalg::registerTransformDialectExtension(registry);
111113
scf::registerTransformDialectExtension(registry);
112114

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
add_subdirectory(IR)
2+
add_subdirectory(TransformOps)
23
add_subdirectory(Transforms)
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
//===- BufferizationTransformOps.h - Bufferization transform ops ----------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
10+
11+
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
12+
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
13+
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
14+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
15+
#include "mlir/Dialect/PDL/IR/PDL.h"
16+
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
17+
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
18+
19+
using namespace mlir;
20+
using namespace mlir::bufferization;
21+
using namespace mlir::transform;
22+
23+
//===----------------------------------------------------------------------===//
24+
// OneShotBufferizeOp
25+
//===----------------------------------------------------------------------===//
26+
27+
LogicalResult
28+
transform::OneShotBufferizeOp::apply(TransformResults &transformResults,
29+
TransformState &state) {
30+
OneShotBufferizationOptions options;
31+
options.allowReturnAllocs = getAllowReturnAllocs();
32+
options.allowUnknownOps = getAllowUnknownOps();
33+
options.bufferizeFunctionBoundaries = getBufferizeFunctionBoundaries();
34+
options.createDeallocs = getCreateDeallocs();
35+
options.testAnalysisOnly = getTestAnalysisOnly();
36+
options.printConflicts = getPrintConflicts();
37+
38+
ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
39+
for (Operation *target : payloadOps) {
40+
auto moduleOp = dyn_cast<ModuleOp>(target);
41+
if (getTargetIsModule() && !moduleOp)
42+
return emitError("expected ModuleOp target");
43+
if (options.bufferizeFunctionBoundaries) {
44+
if (!moduleOp)
45+
return emitError("expected ModuleOp target");
46+
if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options)))
47+
return emitError("bufferization failed");
48+
} else {
49+
if (failed(bufferization::runOneShotBufferize(target, options)))
50+
return emitError("bufferization failed");
51+
}
52+
}
53+
54+
return success();
55+
}
56+
57+
void transform::OneShotBufferizeOp::getEffects(
58+
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
59+
effects.emplace_back(MemoryEffects::Read::get(), getTarget(),
60+
TransformMappingResource::get());
61+
62+
// Handles that are not modules are not longer usable.
63+
if (!getTargetIsModule())
64+
effects.emplace_back(MemoryEffects::Free::get(), getTarget(),
65+
TransformMappingResource::get());
66+
}
67+
//===----------------------------------------------------------------------===//
68+
// Transform op registration
69+
//===----------------------------------------------------------------------===//
70+
71+
namespace {
72+
/// Registers new ops and declares PDL as dependent dialect since the additional
73+
/// ops are using PDL types for operands and results.
74+
class BufferizationTransformDialectExtension
75+
: public transform::TransformDialectExtension<
76+
BufferizationTransformDialectExtension> {
77+
public:
78+
BufferizationTransformDialectExtension() {
79+
declareDependentDialect<bufferization::BufferizationDialect>();
80+
declareDependentDialect<pdl::PDLDialect>();
81+
declareDependentDialect<memref::MemRefDialect>();
82+
registerTransformOps<
83+
#define GET_OP_LIST
84+
#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc"
85+
>();
86+
}
87+
};
88+
} // namespace
89+
90+
#define GET_OP_CLASSES
91+
#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp.inc"
92+
93+
void mlir::bufferization::registerTransformDialectExtension(
94+
DialectRegistry &registry) {
95+
registry.addExtensions<BufferizationTransformDialectExtension>();
96+
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
add_mlir_dialect_library(MLIRBufferizationTransformOps
2+
BufferizationTransformOps.cpp
3+
4+
ADDITIONAL_HEADER_DIRS
5+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Bufferization/TransformOps
6+
7+
DEPENDS
8+
MLIRBufferizationTransformOpsIncGen
9+
10+
LINK_LIBS PUBLIC
11+
MLIRIR
12+
MLIRBufferization
13+
MLIRBufferizationTransforms
14+
MLIRParser
15+
MLIRPDL
16+
MLIRSideEffectInterfaces
17+
MLIRTransformDialect
18+
)
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
// RUN: mlir-opt --test-transform-dialect-interpreter %s -split-input-file -verify-diagnostics | FileCheck %s
2+
3+
// Test One-Shot Bufferize.
4+
5+
transform.with_pdl_patterns {
6+
^bb0(%arg0: !pdl.operation):
7+
sequence %arg0 {
8+
^bb0(%arg1: !pdl.operation):
9+
%0 = pdl_match @pdl_target in %arg1
10+
transform.bufferization.one_shot_bufferize %0
11+
{target_is_module = false}
12+
}
13+
14+
pdl.pattern @pdl_target : benefit(1) {
15+
%0 = operation "func.func"
16+
rewrite %0 with "transform.dialect"
17+
}
18+
}
19+
20+
// CHECK-LABEL: func @test_function(
21+
// CHECK-SAME: %[[A:.*]]: tensor<?xf32>
22+
func.func @test_function(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf32>) {
23+
%c0 = arith.constant 0 : index
24+
25+
// CHECK: %[[A_memref:.*]] = bufferization.to_memref %[[A]]
26+
// CHECK: %[[dim:.*]] = memref.dim %[[A_memref]]
27+
// CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]])
28+
// CHECK: memref.copy %[[A_memref]], %[[alloc]]
29+
// CHECK: vector.transfer_write %{{.*}}, %[[alloc]]
30+
// CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[alloc]]
31+
%0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor<?xf32>
32+
33+
// CHECK: memref.dealloc %[[alloc]]
34+
// CHECK: return %[[res_tensor]]
35+
return %0 : tensor<?xf32>
36+
}
37+
38+
// -----
39+
40+
// Test analysis of One-Shot Bufferize only.
41+
42+
transform.with_pdl_patterns {
43+
^bb0(%arg0: !pdl.operation):
44+
sequence %arg0 {
45+
^bb0(%arg1: !pdl.operation):
46+
%0 = pdl_match @pdl_target in %arg1
47+
transform.bufferization.one_shot_bufferize %0
48+
{target_is_module = false, test_analysis_only = true}
49+
}
50+
51+
pdl.pattern @pdl_target : benefit(1) {
52+
%0 = operation "func.func"
53+
rewrite %0 with "transform.dialect"
54+
}
55+
}
56+
57+
// CHECK-LABEL: func @test_function_analysis(
58+
// CHECK-SAME: %[[A:.*]]: tensor<?xf32>
59+
func.func @test_function_analysis(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf32>) {
60+
%c0 = arith.constant 0 : index
61+
// CHECK: vector.transfer_write
62+
// CHECK-SAME: {__inplace_operands_attr__ = ["none", "false", "none"]}
63+
// CHECK-SAME: tensor<?xf32>
64+
%0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor<?xf32>
65+
return %0 : tensor<?xf32>
66+
}
67+
68+
// -----
69+
70+
// Test One-Shot Bufferize transform failure with an unknown op. This would be
71+
// allowed with `allow_unknown_ops`.
72+
73+
transform.with_pdl_patterns {
74+
^bb0(%arg0: !pdl.operation):
75+
sequence %arg0 {
76+
^bb0(%arg1: !pdl.operation):
77+
%0 = pdl_match @pdl_target in %arg1
78+
// expected-error @+1 {{bufferization failed}}
79+
transform.bufferization.one_shot_bufferize %0 {target_is_module = false}
80+
}
81+
82+
pdl.pattern @pdl_target : benefit(1) {
83+
%0 = operation "func.func"
84+
rewrite %0 with "transform.dialect"
85+
}
86+
}
87+
88+
func.func @test_unknown_op_failure() -> (tensor<?xf32>) {
89+
// expected-error @+1 {{op was not bufferized}}
90+
%0 = "test.dummy_op"() : () -> (tensor<?xf32>)
91+
return %0 : tensor<?xf32>
92+
}
93+
94+
// -----
95+
96+
// Test One-Shot Bufferize transform failure with a module op.
97+
98+
transform.with_pdl_patterns {
99+
^bb0(%arg0: !pdl.operation):
100+
sequence %arg0 {
101+
^bb0(%arg1: !pdl.operation):
102+
// %arg1 is the module
103+
transform.bufferization.one_shot_bufferize %arg1
104+
}
105+
}
106+
107+
module {
108+
// CHECK-LABEL: func @test_function(
109+
// CHECK-SAME: %[[A:.*]]: tensor<?xf32>
110+
func.func @test_function(%A : tensor<?xf32>, %v : vector<4xf32>) -> (tensor<?xf32>) {
111+
%c0 = arith.constant 0 : index
112+
113+
// CHECK: %[[A_memref:.*]] = bufferization.to_memref %[[A]]
114+
// CHECK: %[[dim:.*]] = memref.dim %[[A_memref]]
115+
// CHECK: %[[alloc:.*]] = memref.alloc(%[[dim]])
116+
// CHECK: memref.copy %[[A_memref]], %[[alloc]]
117+
// CHECK: vector.transfer_write %{{.*}}, %[[alloc]]
118+
// CHECK: %[[res_tensor:.*]] = bufferization.to_tensor %[[alloc]]
119+
%0 = vector.transfer_write %v, %A[%c0] : vector<4xf32>, tensor<?xf32>
120+
121+
// CHECK: memref.dealloc %[[alloc]]
122+
// CHECK: return %[[res_tensor]]
123+
return %0 : tensor<?xf32>
124+
}
125+
}

0 commit comments

Comments
 (0)