Skip to content

Commit ffdbecc

Browse files
[mlir][bufferization] Add bufferization.alloc_tensor op
This change adds a new op `alloc_tensor` to the bufferization dialect. During bufferization, this op is always lowered to a buffer allocation (unless it is "eliminated" by a pre-processing pass). It is useful to have such an op in tensor land, because it allows users to model tensor SSA use-def chains (which drive bufferization decisions) and because tensor SSA use-def chains can be analyzed by One-Shot Bufferize, while memref values cannot. This change also replaces all uses of linalg.init_tensor in bufferization-related code with bufferization.alloc_tensor. linalg.init_tensor and bufferization.alloc_tensor are similar, but the purpose of the former one is just to carry a shape. It does not indicate a memory allocation. linalg.init_tensor is not suitable for modelling SSA use-def chains for bufferization purposes, because linalg.init_tensor is marked as not having side effects (in contrast to alloc_tensor). As such, it is legal to move linalg.init_tensor ops around/CSE them/etc. This is not desirable for alloc_tensor; it represents an explicit buffer allocation while still in tensor land and such allocations should not suddenly disappear or get moved around when running the canonicalizer/CSE/etc. BEGIN_PUBLIC No public commit message needed for presubmit. END_PUBLIC Differential Revision: https://reviews.llvm.org/D126003
1 parent 4f6ac96 commit ffdbecc

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+920
-432
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/Bufferization.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.h"
1313
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
1414
#include "mlir/Interfaces/CopyOpInterface.h"
15+
#include "mlir/Interfaces/InferTypeOpInterface.h"
1516

1617
//===----------------------------------------------------------------------===//
1718
// Bufferization Dialect

mlir/include/mlir/Dialect/Bufferization/IR/BufferizationBase.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@ def Bufferization_Dialect : Dialect {
2525
found in [bufferization](/docs/Bufferization/) and [buffer
2626
deallocation](/docs/BufferDeallocationInternals/).
2727
}];
28-
let dependentDialects = ["memref::MemRefDialect", "tensor::TensorDialect"];
28+
let dependentDialects = [
29+
"AffineDialect", "memref::MemRefDialect", "tensor::TensorDialect"
30+
];
2931

3032
let extraClassDeclaration = [{
3133
/// An attribute that can override writability of buffers of tensor function

mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,128 @@
1212
include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.td"
1313
include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
1414
include "mlir/Dialect/Bufferization/IR/BufferizationBase.td"
15+
include "mlir/Interfaces/InferTypeOpInterface.td"
1516
include "mlir/Interfaces/SideEffectInterfaces.td"
1617
include "mlir/Interfaces/CopyOpInterface.td"
1718

1819
class Bufferization_Op<string mnemonic, list<Trait> traits = []>
1920
: Op<Bufferization_Dialect, mnemonic, traits>;
2021

22+
//===----------------------------------------------------------------------===//
23+
// AllocTensorOp
24+
//===----------------------------------------------------------------------===//
25+
26+
def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
27+
[BufferizableOpInterface,
28+
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
29+
let summary = "buffer allocation in tensor land";
30+
31+
let description = [{
32+
`bufferization.alloc_tensor` is an operation that bufferizes to a buffer
33+
allocation of a given shape. The shape could be dynamic or static.
34+
Reading from the result of an `alloc_tensor` op yields an undefined value.
35+
36+
`alloc_tensor` is a helper op for bufferization. It marks the beginning of
37+
a new tensor SSA use-def chain and is used to control in-place bufferization
38+
decisions during One-Shot Bufferize.
39+
}];
40+
41+
let arguments =
42+
(ins Variadic<Index>:$sizes, I64ArrayAttr:$static_sizes);
43+
44+
let results = (outs AnyTensor:$result);
45+
46+
let assemblyFormat = [{
47+
custom<OperandsOrIntegersSizesList>($sizes, $static_sizes) attr-dict
48+
`:` type($result)
49+
}];
50+
51+
let extraClassDeclaration = [{
52+
LogicalResult bufferize(RewriterBase &rewriter, BufferizationState &state);
53+
54+
bool isMemoryWrite(OpResult opResult, const AnalysisState &state) const {
55+
// AllocTensorOps allocate but do not write.
56+
return false;
57+
}
58+
59+
static StringRef getStaticSizesAttrName() {
60+
return "static_sizes";
61+
}
62+
63+
RankedTensorType getType() {
64+
return getResult().getType().cast<RankedTensorType>();
65+
}
66+
67+
// Infer the shape of the result tensor given the static shapes
68+
// and element type of the result tensor.
69+
static Type inferResultType(ArrayRef<int64_t> staticSizes, Type elementType,
70+
Attribute encoding = {});
71+
72+
// Return true if the size of the tensor is dynamic at `idx`
73+
bool isDynamicSize(unsigned idx) {
74+
APInt v = *(static_sizes().getAsValueRange<IntegerAttr>().begin() + idx);
75+
return ShapedType::isDynamic(v.getSExtValue());
76+
}
77+
78+
// Assert that the size of the result tensor is static at `idx`
79+
// and return the shape.
80+
int64_t getStaticSize(unsigned idx) {
81+
assert(!isDynamicSize(idx) && "expected static size");
82+
APInt v = *(static_sizes().
83+
template getAsValueRange<IntegerAttr>().begin() + idx);
84+
return v.getSExtValue();
85+
}
86+
87+
// Return the argument position that contains the dynamic size of
88+
// the tensor at dimension `idx`. Asserts that the shape is
89+
// dynamic at that `idx`.
90+
unsigned getIndexOfDynamicSize(unsigned idx) {
91+
assert(isDynamicSize(idx) && "expected dynamic size");
92+
return std::count_if(
93+
static_sizes().getValue().begin(),
94+
static_sizes().getValue().begin() + idx,
95+
[&](Attribute attr) {
96+
return ShapedType::isDynamic(attr.cast<IntegerAttr>().getInt());
97+
});
98+
}
99+
100+
// Return both static and dynamic sizes as a list of `OpFoldResult`.
101+
SmallVector<OpFoldResult> getMixedSizes();
102+
103+
// Return the Value of the dynamic size of the tensor at dimension
104+
// `idx`. Asserts that the shape is dynamic at that `idx`.
105+
Value getDynamicSize(unsigned idx) {
106+
return getOperand(getIndexOfDynamicSize(idx));
107+
}
108+
}];
109+
110+
let builders = [
111+
OpBuilder<(ins "ValueRange":$shape,
112+
"ArrayRef<int64_t>":$staticShape, "Type":$elementType),
113+
[{
114+
build($_builder, $_state,
115+
AllocTensorOp::inferResultType(staticShape, elementType),
116+
shape, $_builder.getI64ArrayAttr(staticShape));
117+
}]>,
118+
OpBuilder<(ins "ValueRange":$shape, "Type":$elementType),
119+
[{
120+
SmallVector<int64_t, 4> staticShape(
121+
shape.size(), ShapedType::kDynamicSize);
122+
build($_builder, $_state, shape, staticShape, elementType);
123+
}]>,
124+
OpBuilder<(ins "ArrayRef<int64_t>":$staticShape, "Type":$elementType),
125+
[{
126+
build($_builder, $_state, ValueRange{}, staticShape, elementType);
127+
}]>,
128+
OpBuilder<(ins "ArrayRef<OpFoldResult>":$sizes, "Type":$elementType,
129+
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>
130+
];
131+
132+
let hasCanonicalizer = 1;
133+
let hasCustomAssemblyFormat = 1;
134+
let hasVerifier = 1;
135+
}
136+
21137
//===----------------------------------------------------------------------===//
22138
// CloneOp
23139
//===----------------------------------------------------------------------===//
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
//===- AllocTensorElimination.h - alloc_tensor op elimination -------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ALLOCTENSORELIMINATION_H
10+
#define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ALLOCTENSORELIMINATION_H
11+
12+
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
13+
14+
namespace mlir {
15+
namespace bufferization {
16+
17+
/// A function that matches anchor OpOperands for AllocTensorOp elimination.
18+
/// If an OpOperand is matched, the function should populate the SmallVector
19+
/// with all values that are needed during `RewriteFn` to produce the
20+
/// replacement value.
21+
using AnchorMatchFn = std::function<bool(OpOperand &, SmallVector<Value> &)>;
22+
23+
/// A function that rewrites matched anchors.
24+
using RewriteFn = std::function<Value(OpBuilder &, Location, OpOperand &)>;
25+
26+
/// Try to eliminate AllocTensorOps inside `op`.
27+
///
28+
/// * `rewriteFunc` generates the replacement for the AllocTensorOp.
29+
/// * Only AllocTensorOps that are anchored on a matching OpOperand as per
30+
/// `anchorMatchFunc` are considered. "Anchored" means that there is a path
31+
/// on the reverse SSA use-def chain, starting from the OpOperand and always
32+
/// following the aliasing OpOperand, that eventually ends at a single
33+
/// AllocTensorOp.
34+
LogicalResult eliminateAllocTensors(RewriterBase &rewriter, Operation *op,
35+
bufferization::AnalysisState &state,
36+
AnchorMatchFn anchorMatchFunc,
37+
RewriteFn rewriteFunc);
38+
39+
/// Try to eliminate AllocTensorOps inside `op` that are anchored on an
40+
/// InsertSliceOp, i.e., if it is eventually inserted into another tensor
41+
/// (and some other conditions are met).
42+
LogicalResult insertSliceAnchoredAllocTensorEliminationStep(
43+
RewriterBase &rewriter, Operation *op, bufferization::AnalysisState &state);
44+
45+
} // namespace bufferization
46+
} // namespace mlir
47+
48+
#endif // MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ALLOCTENSORELIMINATION_H

mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,13 @@ createPromoteBuffersToStackPass(unsigned maxAllocSizeInBytes = 1024,
6464
std::unique_ptr<Pass>
6565
createPromoteBuffersToStackPass(std::function<bool(Value)> isSmallAlloc);
6666

67+
/// Create a pass that tries to eliminate alloc_tensor ops that are anchored on
68+
/// insert_slice ops.
69+
std::unique_ptr<Pass> createAllocTensorEliminationPass();
70+
71+
/// Create a pass that bufferizes ops from the bufferization dialect.
72+
std::unique_ptr<Pass> createBufferizationBufferizePass();
73+
6774
//===----------------------------------------------------------------------===//
6875
// Registration
6976
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Bufferization/Transforms/Passes.td

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,11 @@ def FinalizingBufferize : Pass<"finalizing-bufferize", "func::FuncOp"> {
149149
let constructor = "mlir::bufferization::createFinalizingBufferizePass()";
150150
}
151151

152+
def BufferizationBufferize : Pass<"bufferization-bufferize", "func::FuncOp"> {
153+
let summary = "Bufferize the `bufferization` dialect";
154+
let constructor = "mlir::bufferization::createBufferizationBufferizePass()";
155+
}
156+
152157
def OneShotBufferize : Pass<"one-shot-bufferize", "ModuleOp"> {
153158
let summary = "One-Shot Bufferize";
154159
let description = [{
@@ -309,4 +314,16 @@ def PromoteBuffersToStack : Pass<"promote-buffers-to-stack", "func::FuncOp"> {
309314
];
310315
}
311316

317+
def AllocTensorElimination : Pass<"eliminate-alloc-tensors"> {
318+
let summary = "Try to eliminate all alloc_tensor ops.";
319+
let description = [{
320+
This pass tries to eliminate all insert_slice op-anchored alloc_tensor ops.
321+
I.e., when a value that is equivalent to an alloc_tensor op is inserted into
322+
another tensor, this pass tries to rewrite the IR in such a way that the
323+
destination tensor of the insert_slice op is used directly instead of the
324+
alloc_tensor result.
325+
}];
326+
let constructor = "mlir::bufferization::createAllocTensorEliminationPass()";
327+
}
328+
312329
#endif // MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_PASSES

mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,16 @@ class Linalg_Op<string mnemonic, list<Trait> traits = []> :
2727
def Linalg_InitTensorOp : Linalg_Op<"init_tensor",
2828
[NoSideEffect,
2929
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
30-
let summary = "operation to define a tensor of particular value";
30+
let summary = "operation to define a tensor of particular shape";
3131

3232
let description = [{
33-
`linalg.init_tensor` is an operation that materializes a tensor of
34-
a given shape. The shape could be dynamic or static.
33+
`linalg.init_tensor` is an operation that defines a tensor of a particular
34+
shape. The shape could be dynamic or static. The contents of the tensor are
35+
unspecified and the only purpose of the op result is to materialize the
36+
specified shape in IR and make it available to other transformations.
37+
38+
Note: This op can be lowered to a `bufferization.alloc_tensor`, at which
39+
point it turns into an explicit buffer allocation.
3540
}];
3641

3742
let arguments =

mlir/include/mlir/Dialect/Linalg/Passes.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,8 @@ createConvertLinalgToParallelLoopsPass();
6262
std::unique_ptr<OperationPass<func::FuncOp>>
6363
createConvertLinalgToAffineLoopsPass();
6464

65-
/// Create a pass that tries to eliminate init_tensor ops that are anchored on
66-
/// insert_slice ops.
67-
std::unique_ptr<Pass> createLinalgInitTensorEliminationPass();
65+
/// Create a pass that rewrites init_tensor to alloc_tensor.
66+
std::unique_ptr<Pass> createLinalgInitTensorToAllocTensorPass();
6867

6968
/// Create a pass to convert Linalg operations which work on tensors to use
7069
/// buffers instead.

mlir/include/mlir/Dialect/Linalg/Passes.td

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,16 +24,14 @@ def ConvertElementwiseToLinalg : Pass<"convert-elementwise-to-linalg", ""> {
2424
let dependentDialects = ["linalg::LinalgDialect", "memref::MemRefDialect"];
2525
}
2626

27-
def LinalgInitTensorElimination : Pass<"linalg-eliminate-init-tensors"> {
28-
let summary = "Try to eliminate all init_tensor ops.";
27+
def LinalgInitTensorToAllocTensor : Pass<"linalg-init-tensor-to-alloc-tensor"> {
28+
let summary = "Replace all init_tensor ops by alloc_tensor ops.";
2929
let description = [{
30-
This pass tries to eliminate all insert_slice op-anchored init_tensor ops.
31-
I.e., when a value that is aliasing with an init_tensor op is inserted into
32-
another tensor, this pass tries to rewrite the IR in such a way that the
33-
destination tensor of the insert_slice op is used directly instead of the
34-
init_tensor result.
30+
init_tensor ops return a tensor of unspecified contents whose only purpose
31+
is to carry the tensor shape. This pass converts such ops to
32+
bufferization.alloc_tensor ops, which bufferize to buffer allocations.
3533
}];
36-
let constructor = "mlir::createLinalgInitTensorEliminationPass()";
34+
let constructor = "mlir::createLinalgInitTensorToAllocTensorPass()";
3735
}
3836

3937
def LinalgFoldUnitExtentDims : Pass<"linalg-fold-unit-extent-dims", ""> {

mlir/include/mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -9,43 +9,11 @@
99
#ifndef MLIR_DIALECT_LINALG_BUFFERIZABLEOPINTERFACEIMPL_H
1010
#define MLIR_DIALECT_LINALG_BUFFERIZABLEOPINTERFACEIMPL_H
1111

12-
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
13-
1412
namespace mlir {
1513
class DialectRegistry;
1614

1715
namespace linalg {
18-
19-
/// A function that matches anchor OpOperands for InitTensorOp elimination.
20-
/// If an OpOperand is matched, the function should populate the SmallVector
21-
/// with all values that are needed during `RewriteFn` to produce the
22-
/// replacement value.
23-
using AnchorMatchFn = std::function<bool(OpOperand &, SmallVector<Value> &)>;
24-
25-
/// A function that rewrites matched anchors.
26-
using RewriteFn = std::function<Value(OpBuilder &, Location, OpOperand &)>;
27-
28-
/// Try to eliminate InitTensorOps inside `op`.
29-
///
30-
/// * `rewriteFunc` generates the replacement for the InitTensorOp.
31-
/// * Only InitTensorOps that are anchored on a matching OpOperand as per
32-
/// `anchorMatchFunc` are considered. "Anchored" means that there is a path
33-
/// on the reverse SSA use-def chain, starting from the OpOperand and always
34-
/// following the aliasing OpOperand, that eventually ends at a single
35-
/// InitTensorOp.
36-
LogicalResult eliminateInitTensors(RewriterBase &rewriter, Operation *op,
37-
bufferization::AnalysisState &state,
38-
AnchorMatchFn anchorMatchFunc,
39-
RewriteFn rewriteFunc);
40-
41-
/// Try to eliminate InitTensorOps inside `op` that are anchored on an
42-
/// InsertSliceOp, i.e., if it is eventually inserted into another tensor
43-
/// (and some other conditions are met).
44-
LogicalResult insertSliceAnchoredInitTensorEliminationStep(
45-
RewriterBase &rewriter, Operation *op, bufferization::AnalysisState &state);
46-
4716
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
48-
4917
} // namespace linalg
5018
} // namespace mlir
5119

mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9+
#include "mlir/Dialect/Affine/IR/AffineOps.h"
910
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
1011
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1112
#include "mlir/Dialect/Tensor/IR/Tensor.h"

0 commit comments

Comments
 (0)