Skip to content

Commit 913286b

Browse files
[mlir][linalg] Add SubsetInsertionOpInterface to linalg.copy (#67524)
This commit enables empty tensor elimination on `linalg.copy` ops.
1 parent 977289e commit 913286b

File tree

6 files changed

+94
-0
lines changed

6 files changed

+94
-0
lines changed
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
//===- SubsetInsertionOpInterfaceImpl.h - Tensor subsets --------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_LINALG_SUBSETINSERTIONOPINTERFACEIMPL_H
10+
#define MLIR_DIALECT_LINALG_SUBSETINSERTIONOPINTERFACEIMPL_H
11+
12+
namespace mlir {
13+
class DialectRegistry;
14+
15+
namespace linalg {
16+
void registerSubsetInsertionOpInterfaceExternalModels(
17+
DialectRegistry &registry);
18+
} // namespace linalg
19+
} // namespace mlir
20+
21+
#endif // MLIR_DIALECT_LINALG_SUBSETINSERTIONOPINTERFACEIMPL_H

mlir/include/mlir/InitAllDialects.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
#include "mlir/Dialect/Linalg/IR/Linalg.h"
4646
#include "mlir/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.h"
4747
#include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
48+
#include "mlir/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.h"
4849
#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
4950
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
5051
#include "mlir/Dialect/Math/IR/Math.h"
@@ -148,6 +149,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
148149
cf::registerBufferDeallocationOpInterfaceExternalModels(registry);
149150
gpu::registerBufferDeallocationOpInterfaceExternalModels(registry);
150151
linalg::registerBufferizableOpInterfaceExternalModels(registry);
152+
linalg::registerSubsetInsertionOpInterfaceExternalModels(registry);
151153
linalg::registerTilingInterfaceExternalModels(registry);
152154
linalg::registerValueBoundsOpInterfaceExternalModels(registry);
153155
memref::registerAllocationOpInterfaceExternalModels(registry);

mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
110110
// be replaced, but the transformation may not be beneficial.
111111
if (!state.isInPlace(source))
112112
return WalkResult::skip();
113+
113114
// All values that are needed to create the replacement op.
114115
SmallVector<Value> neededValues =
115116
op.getValuesNeededToBuildSubsetExtraction();

mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
2727
Split.cpp
2828
SplitReduction.cpp
2929
SubsetHoisting.cpp
30+
SubsetInsertionOpInterfaceImpl.cpp
3031
SwapExtractSliceWithFillPatterns.cpp
3132
Tiling.cpp
3233
TilingInterfaceImpl.cpp
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
//===- SubsetInsertionOpInterfaceImpl.cpp - Tensor subsets ----------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.h"
10+
11+
#include "mlir/Dialect/Bufferization/IR/SubsetInsertionOpInterface.h"
12+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
13+
14+
using namespace mlir;
15+
using namespace mlir::bufferization;
16+
using namespace mlir::linalg;
17+
18+
namespace {
19+
struct LinalgCopyOpInterface
20+
: public SubsetInsertionOpInterface::ExternalModel<LinalgCopyOpInterface,
21+
linalg::CopyOp> {
22+
OpOperand &getSourceOperand(Operation *op) const {
23+
auto copyOp = cast<CopyOp>(op);
24+
assert(copyOp.getInputs().size() == 1 && "expected single input");
25+
return copyOp.getInputsMutable()[0];
26+
}
27+
28+
bool
29+
isEquivalentSubset(Operation *op, Value candidate,
30+
function_ref<bool(Value, Value)> equivalenceFn) const {
31+
auto copyOp = cast<CopyOp>(op);
32+
assert(copyOp.getOutputs().size() == 1 && "expected single output");
33+
return equivalenceFn(candidate, copyOp.getOutputs()[0]);
34+
}
35+
36+
Value buildSubsetExtraction(Operation *op, OpBuilder &builder,
37+
Location loc) const {
38+
auto copyOp = cast<CopyOp>(op);
39+
assert(copyOp.getOutputs().size() == 1 && "expected single output");
40+
return copyOp.getOutputs()[0];
41+
}
42+
43+
SmallVector<Value>
44+
getValuesNeededToBuildSubsetExtraction(Operation *op) const {
45+
auto copyOp = cast<CopyOp>(op);
46+
assert(copyOp.getOutputs().size() == 1 && "expected single output");
47+
return {copyOp.getOutputs()[0]};
48+
}
49+
};
50+
} // namespace
51+
52+
void mlir::linalg::registerSubsetInsertionOpInterfaceExternalModels(
53+
DialectRegistry &registry) {
54+
registry.addExtension(+[](MLIRContext *ctx, linalg::LinalgDialect *dialect) {
55+
linalg::CopyOp::attachInterface<LinalgCopyOpInterface>(*ctx);
56+
});
57+
}

mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,3 +305,15 @@ func.func @materialize_in_destination(%t: tensor<5xf32>, %f: f32) -> tensor<5xf3
305305
return %1 : tensor<5xf32>
306306
}
307307

308+
// -----
309+
310+
// CHECK-LABEL: func @linalg_copy(
311+
// CHECK-SAME: %[[m:.*]]: memref<5xf32, strided<[?], offset: ?>>,
312+
// CHECK: linalg.fill {{.*}} outs(%[[m]]
313+
// CHECK: return %[[m]]
314+
func.func @linalg_copy(%t: tensor<5xf32>, %f: f32) -> tensor<5xf32> {
315+
%0 = tensor.empty() : tensor<5xf32>
316+
%filled = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32>
317+
%1 = linalg.copy ins(%filled : tensor<5xf32>) outs(%t : tensor<5xf32>) -> tensor<5xf32>
318+
return %1 : tensor<5xf32>
319+
}

0 commit comments

Comments
 (0)