Skip to content

Commit 8143307

Browse files
[mlir][bufferization] Generalize tensor slice rules to subset ops (#65619)
This commit generalizes the special tensor.extract_slice/tensor.insert_slice bufferization rules to tensor subset ops. Ops that insert a tensor into a tensor at a specified subset (e.g., tensor.insert_slice, tensor.scatter) can implement the `SubsetInsertionOpInterface`. Apart from adding a new op interface (extending the API), this change is NFC. The only ops that currently implement the new interface are tensor.insert_slice and tensor.parallel_insert_slice, and those ops were are supported by One-Shot Bufferize.
1 parent b8ec283 commit 8143307

File tree

12 files changed

+425
-131
lines changed

12 files changed

+425
-131
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ add_mlir_doc(BufferizationOps BufferizationOps Dialects/ -gen-dialect-doc)
33
add_mlir_interface(AllocationOpInterface)
44
add_mlir_interface(BufferDeallocationOpInterface)
55
add_mlir_interface(BufferizableOpInterface)
6+
add_mlir_interface(SubsetInsertionOpInterface)
67

78
set(LLVM_TARGET_DEFINITIONS BufferizationEnums.td)
89
mlir_tablegen(BufferizationEnums.h.inc -gen-enum-decls)
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
//===- SubsetInsertionOpInterface.h - Tensor Subsets ------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_BUFFERIZATION_IR_SUBSETINSERTIONOPINTERFACE_H_
10+
#define MLIR_DIALECT_BUFFERIZATION_IR_SUBSETINSERTIONOPINTERFACE_H_
11+
12+
#include "mlir/IR/OpDefinition.h"
13+
14+
namespace mlir {
15+
namespace bufferization {
16+
namespace detail {
17+
18+
/// Return the destination/"init" operand of the op if it implements the
19+
/// `DestinationStyleOpInterface` and has exactly one "init" operand. Asserts
20+
/// otherwise.
21+
OpOperand &defaultGetDestinationOperand(Operation *op);
22+
23+
} // namespace detail
24+
} // namespace bufferization
25+
} // namespace mlir
26+
27+
#include "mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.h.inc"
28+
29+
#endif // MLIR_DIALECT_BUFFERIZATION_IR_SUBSETINSERTIONOPINTERFACE_H_
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
//===-- SubsetInsertionOpInterface.td - Tensor Subsets -----*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef SUBSET_INSERTION_OP_INTERFACE
10+
#define SUBSET_INSERTION_OP_INTERFACE
11+
12+
include "mlir/IR/OpBase.td"
13+
14+
def SubsetInsertionOpInterface : OpInterface<"SubsetInsertionOpInterface"> {
15+
let description = [{
16+
This interface can be implemented by ops that insert a source tensor into
17+
a destination tensor.
18+
19+
The elements in the destination tensor that are overwritten by this
20+
insertion are called the "subset". How the subset is defined is up to the
21+
op. E.g., "tensor.insert_slice" defines the subset via a hyperrectangular
22+
slice. A scatter operation could define the subset via a list of indices.
23+
24+
Ops that deal with tensor subsets come in two flavours:
25+
- Insertion flavor: Ops that insert a source tensor into a destination
26+
tensor at the specified subset. Such ops usually return a new destination
27+
tensor and implement the `DestinationStyleOpInterface`. Insertion ops can
28+
implement the `SubsetInsertionOpInterface`. Example: "tensor.insert_slice"
29+
- Extraction flavor: Ops that define a tensor subset. They extract a
30+
specified subset from a tensor. There is currently no op interface for
31+
such ops. Example: "tensor.extract_slice"
32+
33+
This interface provides helper methods for efficient bufferization of
34+
subset-based tensor IR. Tensor subsets can bufferize to buffer "views"/
35+
"aliases" (in contrast to one or multiple less efficient buffer allocation).
36+
37+
This interface is queried by One-Shot Bufferize to detect cases where a
38+
seeming read-after-write is not actually a conflict because the respective
39+
ops are operating on equivalent subsets. More details can be found in the
40+
documentation of One-Shot Analysis (see `areNonConflictingSubsets`).
41+
42+
Note: This interface currently assumes that a subset op inserts a single
43+
tensor (source) into a destination tensor at a single subset.
44+
}];
45+
let cppNamespace = "::mlir::bufferization";
46+
let methods = [
47+
InterfaceMethod<
48+
/*desc=*/[{
49+
Return the source tensor operand.
50+
}],
51+
/*retType=*/"::mlir::OpOperand &",
52+
/*methodName=*/"getSourceOperand",
53+
/*args=*/(ins)
54+
>,
55+
InterfaceMethod<
56+
/*desc=*/[{
57+
Return the destination tensor operand.
58+
}],
59+
/*retType=*/"::mlir::OpOperand &",
60+
/*methodName=*/"getDestinationOperand",
61+
/*args=*/(ins),
62+
/*methodBody=*/"",
63+
/*defaultImplementation=*/[{
64+
return ::mlir::bufferization::detail::defaultGetDestinationOperand(
65+
$_op.getOperation());
66+
}]
67+
>,
68+
InterfaceMethod<
69+
/*desc=*/[{
70+
Return "true" if this operation inserts into a subset that is
71+
equivalent to the subset defined by `candidate`.
72+
73+
Two subsets are "equivalent" and "same" if they can bufferize to the
74+
same buffer views/aliases. If they are "equivalent", the tensor IR
75+
may be expressed in terms of different SSA values (but they could
76+
bufferize to MemRef SSA values that can CSE without breaking
77+
correctness). `equivalenceFn` should return "true" if the two given
78+
values are equivalent.
79+
80+
Example:
81+
```
82+
// The subset of the SubsetInsertionOpInterface op %1 is equivalent to
83+
// the subset defined by %2 (but not "same"):
84+
%0 = arith.select %c, %t, %t : tensor<?xf32>
85+
%1 = tensor.insert_slice %x into %0[0][5][1]
86+
: tensor<5xf32> into tensor<?xf32>
87+
%2 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
88+
89+
// The subset of the SubsetInsertionOpInterface op %1 is equivalent to
90+
// and "same" as the subset defined by %2.
91+
%1 = tensor.insert_slice %x into %t[0][5][1]
92+
: tensor<5xf32> into tensor<?xf32>
93+
%2 = tensor.extract_slice %t[0][5][1] : tensor<?xf32> to tensor<5xf32>
94+
```
95+
}],
96+
/*retType=*/"bool",
97+
/*methodName=*/"isEquivalentSubset",
98+
/*args=*/(ins
99+
"::mlir::Value":$candidate,
100+
"::llvm::function_ref<bool(Value, Value)>":$equivalenceFn)
101+
>,
102+
];
103+
104+
let extraClassDeclaration = [{
105+
/// Return "true" if this operation inserts into the same subset as defined
106+
/// by `candidate`.
107+
///
108+
/// Note: This function is useful outside of bufferization, where no tensor
109+
/// equivalence information is available.
110+
bool isSameSubset(OpResult candidate) {
111+
auto subsetOp = cast<::mlir::bufferization::SubsetInsertionOpInterface>(
112+
getOperation());
113+
return subsetOp.isEquivalentSubset(
114+
candidate, [](Value v1, Value v2) { return v1 == v2; });
115+
}
116+
}];
117+
}
118+
119+
#endif // SUBSET_INSERTION_OP_INTERFACE
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
//===- SubsetInsertionOpInterfaceImpl.h - Tensor subsets ------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_TENSOR_SUBSETINSERTIONOPINTERFACEIMPL_H
10+
#define MLIR_DIALECT_TENSOR_SUBSETINSERTIONOPINTERFACEIMPL_H
11+
12+
namespace mlir {
13+
class DialectRegistry;
14+
15+
namespace tensor {
16+
void registerSubsetInsertionOpInterfaceExternalModels(
17+
DialectRegistry &registry);
18+
} // namespace tensor
19+
} // namespace mlir
20+
21+
#endif // MLIR_DIALECT_TENSOR_SUBSETINSERTIONOPINTERFACEIMPL_H

mlir/include/mlir/InitAllDialects.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@
7474
#include "mlir/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.h"
7575
#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"
7676
#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
77+
#include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h"
7778
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
7879
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
7980
#include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h"
@@ -158,6 +159,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
158159
tensor::registerBufferizableOpInterfaceExternalModels(registry);
159160
tensor::registerFindPayloadReplacementOpInterfaceExternalModels(registry);
160161
tensor::registerInferTypeOpInterfaceExternalModels(registry);
162+
tensor::registerSubsetInsertionOpInterfaceExternalModels(registry);
161163
tensor::registerTilingInterfaceExternalModels(registry);
162164
tensor::registerValueBoundsOpInterfaceExternalModels(registry);
163165
vector::registerBufferizableOpInterfaceExternalModels(registry);

mlir/lib/Dialect/Bufferization/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ add_mlir_dialect_library(MLIRBufferizationDialect
44
BufferDeallocationOpInterface.cpp
55
BufferizationOps.cpp
66
BufferizationDialect.cpp
7+
SubsetInsertionOpInterface.cpp
78
UnstructuredControlFlow.cpp
89

910
ADDITIONAL_HEADER_DIRS
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
//===- SubsetInsertionOpInterface.cpp - Tensor Subsets --------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.h"
10+
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
11+
12+
#include "mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.cpp.inc"
13+
14+
using namespace mlir;
15+
16+
OpOperand &bufferization::detail::defaultGetDestinationOperand(Operation *op) {
17+
auto dstOp = dyn_cast<DestinationStyleOpInterface>(op);
18+
assert(dstOp && "getDestination must be implemented for non-DPS ops");
19+
assert(
20+
dstOp.getNumDpsInits() == 1 &&
21+
"getDestination must be implemented for ops with 0 or more than 1 init");
22+
return *dstOp.getDpsInitOperand(0);
23+
}

mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545

4646
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
4747
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
48+
#include "mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.h"
4849
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
4950
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
5051
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -531,6 +532,105 @@ static bool hasEquivalentValueInReverseUseDefChain(AnalysisState &state,
531532
.empty();
532533
}
533534

535+
/// Return "true" if `value` is originating from a subset that is equivalent to
536+
/// the subset that `subsetOp` inserts into.
537+
static bool matchesInsertDestination(const AnalysisState &state, Value value,
538+
SubsetInsertionOpInterface subsetOp) {
539+
auto matchingSubset = [&](Value val) {
540+
if (auto opResult = dyn_cast<OpResult>(val))
541+
if (subsetOp.isEquivalentSubset(opResult, [&](Value v1, Value v2) {
542+
return state.areEquivalentBufferizedValues(v1, v2);
543+
}))
544+
return true;
545+
return false;
546+
};
547+
// There may be multiple leaves at which the reverse SSA use-def chain lookup
548+
// terminates. All of them must be equivalent subsets.
549+
SetVector<Value> backwardSlice =
550+
state.findValueInReverseUseDefChain(value, matchingSubset);
551+
return static_cast<bool>(llvm::all_of(backwardSlice, matchingSubset));
552+
}
553+
554+
/// Return "true" if the given "read" and potentially conflicting "write" are
555+
/// not conflicting due to their subset relationship. The comments in this
556+
/// function are expressed in terms of tensor.extract_slice/tensor.insert_slice
557+
/// pairs, but apply to any subset ops that implement the
558+
/// `SubsetInsertionOpInterface`.
559+
static bool areNonConflictingSubsets(OpOperand *uRead,
560+
OpOperand *uConflictingWrite,
561+
const AnalysisState &state) {
562+
Operation *readingOp = uRead->getOwner();
563+
Operation *conflictingWritingOp = uConflictingWrite->getOwner();
564+
565+
// Special rules for matching ExtractSliceOp/InsertSliceOp pairs. If
566+
// uRead is an InsertSliceOp...
567+
if (auto subsetOp = dyn_cast<SubsetInsertionOpInterface>(readingOp)) {
568+
// As an example, consider the following IR.
569+
//
570+
// %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
571+
// %1 = linalg.fill %cst, %0 {inplace= [true] }
572+
// %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
573+
// {inplace= [true] }
574+
575+
if (uRead == &subsetOp.getDestinationOperand() &&
576+
matchesInsertDestination(state, uConflictingWrite->get(), subsetOp))
577+
// Case 1: The main insight is that InsertSliceOp reads only part of
578+
// the destination tensor. The overwritten area is not read. If
579+
// uConflictingWrite writes into exactly the memory location that is
580+
// being read by uRead, this is not a conflict.
581+
//
582+
// In the above example:
583+
// uRead = OpOperand 1 (%t) of tensor.insert_slice
584+
// uConflictingWrite = OpOperand 1 (%0) of linalg.fill
585+
//
586+
// The read of %t does not conflict with the write of the FillOp
587+
// (same aliases!) because the area that the FillOp operates on is
588+
// exactly the one that is *not* read via %t.
589+
return true;
590+
591+
if (uRead == &subsetOp.getSourceOperand() &&
592+
uConflictingWrite == &subsetOp.getDestinationOperand() &&
593+
matchesInsertDestination(state, uRead->get(), subsetOp))
594+
// Case 2: The read of the source tensor and the write to the dest
595+
// tensor via an InsertSliceOp is not a conflict if the read is
596+
// reading exactly that part of an equivalent tensor that the
597+
// InsertSliceOp is writing.
598+
//
599+
// In the above example:
600+
// uRead = OpOperand 0 (%1) of tensor.insert_slice
601+
// uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
602+
return true;
603+
}
604+
605+
// If uConflictingWrite is an InsertSliceOp...
606+
if (auto subsetOp =
607+
dyn_cast<SubsetInsertionOpInterface>(conflictingWritingOp))
608+
// As an example, consider the following IR.
609+
//
610+
// %0 = tensor.extract_slice %t[%a, %b][%c, %d][1, 1] {inplace = [true] }
611+
// %1 = linalg.fill %cst, %0 {inplace= [true] }
612+
// %2 = tensor.insert_slice %1 into %t[%a, %b][%c, %d][1, 1]
613+
// {inplace= [true] }
614+
// %3 = vector.transfer_read %1, %cst
615+
//
616+
// In the above example:
617+
// uRead = OpOperand 0 (%1) of vector.transfer_read
618+
// uConflictingWrite = OpOperand 1 (%t) of tensor.insert_slice
619+
// definition = %1
620+
//
621+
// This is not a conflict because the InsertSliceOp overwrites the
622+
// memory segment of %1 with the exact same data. (Effectively, there
623+
// is no memory write here.)
624+
if (uConflictingWrite == &subsetOp.getDestinationOperand() &&
625+
state.areEquivalentBufferizedValues(
626+
uRead->get(), subsetOp.getSourceOperand().get()) &&
627+
matchesInsertDestination(state, subsetOp.getSourceOperand().get(),
628+
subsetOp))
629+
return true;
630+
631+
return false;
632+
}
633+
534634
/// Given sets of uses and writes, return true if there is a RaW conflict under
535635
/// the assumption that all given reads/writes alias the same buffer and that
536636
/// all given writes bufferize inplace.
@@ -684,6 +784,12 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
684784
}
685785
}
686786

787+
// No conflict if the operands are non-conflicting subsets.
788+
if (areNonConflictingSubsets(uRead, uConflictingWrite, state)) {
789+
LLVM_DEBUG(llvm::dbgs() << " no conflict: non-conflicting subsets\n");
790+
continue;
791+
}
792+
687793
// No conflict if the op interface says so.
688794
if (auto bufferizableOp = options.dynCastBufferizableOp(readingOp)) {
689795
if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state)) {

0 commit comments

Comments
 (0)