Skip to content

Commit 4146ffb

Browse files
committed
[mlir][MLProgram] Implement BufferizableOpInterface
This commit implements the `BufferizableOpInterface` for `ml_program.global`, `ml_program.global_load` and `ml_program.global_store` so that these ops can be lowered all the way to LLVM.
1 parent 4fcd7cf commit 4146ffb

File tree

7 files changed

+204
-1
lines changed

7 files changed

+204
-1
lines changed
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
//===- BufferizableOpInterfaceImpl.h - Impl. of BufferizableOpInterface ---===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_MLPROGRAM_BUFFERIZABLEOPINTERFACEIMPL_H
10+
#define MLIR_DIALECT_MLPROGRAM_BUFFERIZABLEOPINTERFACEIMPL_H
11+
12+
namespace mlir {
13+
class DialectRegistry;
14+
15+
namespace ml_program {
16+
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
17+
} // namespace ml_program
18+
} // namespace mlir
19+
20+
#endif // MLIR_DIALECT_MLPROGRAM_BUFFERIZABLEOPINTERFACEIMPL_H

mlir/include/mlir/InitAllDialects.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
#include "mlir/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.h"
4949
#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
5050
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
51+
#include "mlir/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.h"
5152
#include "mlir/Dialect/Math/IR/Math.h"
5253
#include "mlir/Dialect/MemRef/IR/MemRef.h"
5354
#include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h"
@@ -160,6 +161,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
160161
memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
161162
memref::registerValueBoundsOpInterfaceExternalModels(registry);
162163
memref::registerMemorySlotExternalModels(registry);
164+
ml_program::registerBufferizableOpInterfaceExternalModels(registry);
163165
scf::registerBufferDeallocationOpInterfaceExternalModels(registry);
164166
scf::registerBufferizableOpInterfaceExternalModels(registry);
165167
scf::registerValueBoundsOpInterfaceExternalModels(registry);

mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,10 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
494494
<< "\n//===-------------------------------------------===//\n");
495495
}
496496

497+
// Return early if the top-level op is entirely gone.
498+
if (erasedOps.contains(op))
499+
return success();
500+
497501
// Fold all to_memref(to_tensor(x)) pairs.
498502
for (Operation *op : toMemrefOps) {
499503
rewriter.setInsertionPoint(op);

mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -459,7 +459,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
459459
}
460460

461461
// Bufferize all other ops.
462-
for (Operation &op : moduleOp.getOps()) {
462+
for (Operation &op : llvm::make_early_inc_range(moduleOp.getOps())) {
463463
// Functions were already bufferized.
464464
if (isa<func::FuncOp>(&op))
465465
continue;
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
10+
11+
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
12+
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
13+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
14+
15+
using namespace mlir;
16+
using namespace mlir::bufferization;
17+
using namespace mlir::ml_program;
18+
19+
namespace mlir {
20+
namespace ml_program {
21+
22+
template <typename Interface, typename Op>
23+
struct ExternalModelBase
24+
: public BufferizableOpInterface::ExternalModel<Interface, Op> {
25+
26+
AliasingValueList getAliasingValues(Operation *, OpOperand &,
27+
const AnalysisState &) const {
28+
return {};
29+
}
30+
31+
BufferRelation bufferRelation(Operation *, OpResult,
32+
const AnalysisState &) const {
33+
return BufferRelation::Unknown;
34+
}
35+
};
36+
37+
/// Bufferization of ml_program.global into a memref.global
38+
struct GlobalOpInterface
39+
: public ExternalModelBase<GlobalOpInterface, GlobalOp> {
40+
41+
bool bufferizesToMemoryRead(Operation *, OpOperand &,
42+
const AnalysisState &) const {
43+
return false;
44+
}
45+
46+
bool bufferizesToMemoryWrite(Operation *, OpOperand &,
47+
const AnalysisState &) const {
48+
return false;
49+
}
50+
51+
bool hasTensorSemantics(Operation *) const { return true; }
52+
53+
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
54+
const BufferizationOptions &) const {
55+
auto globalOp = cast<GlobalOp>(op);
56+
if (!globalOp.getValue().has_value())
57+
return globalOp.emitError("global op must have a value");
58+
59+
auto tensorType = cast<TensorType>(globalOp.getType());
60+
auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
61+
62+
replaceOpWithNewBufferizedOp<memref::GlobalOp>(
63+
rewriter, globalOp, globalOp.getSymName(),
64+
/*sym_visibility=*/globalOp.getSymVisibilityAttr(),
65+
/*type=*/cast<MemRefType>(memrefType),
66+
/*initial_value=*/globalOp.getValue().value(),
67+
/*constant=*/!globalOp.getIsMutable(),
68+
/*alignment=*/nullptr);
69+
return success();
70+
}
71+
};
72+
73+
/// Bufferization of ml_program.global_load into a memref.get_global
74+
struct GlobalLoadOpInterface
75+
: public ExternalModelBase<GlobalLoadOpInterface, GlobalLoadOp> {
76+
77+
bool bufferizesToMemoryRead(Operation *, OpOperand &,
78+
const AnalysisState &) const {
79+
return false;
80+
}
81+
82+
bool bufferizesToMemoryWrite(Operation *, OpOperand &,
83+
const AnalysisState &) const {
84+
return false;
85+
}
86+
87+
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
88+
const BufferizationOptions &) const {
89+
auto globalLoadOp = cast<GlobalLoadOp>(op);
90+
91+
auto tensorType = cast<TensorType>(globalLoadOp.getType());
92+
auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
93+
94+
replaceOpWithNewBufferizedOp<memref::GetGlobalOp>(
95+
rewriter, globalLoadOp, memrefType,
96+
globalLoadOp.getGlobalAttr().getLeafReference());
97+
98+
return success();
99+
}
100+
};
101+
102+
/// Bufferization of ml_program.global_store into a memref.get_global and
103+
/// memcpy
104+
struct GlobalStoreOpInterface
105+
: public ExternalModelBase<GlobalStoreOpInterface, GlobalStoreOp> {
106+
107+
bool bufferizesToMemoryRead(Operation *, OpOperand &,
108+
const AnalysisState &) const {
109+
return false;
110+
}
111+
112+
bool bufferizesToMemoryWrite(Operation *, OpOperand &,
113+
const AnalysisState &) const {
114+
return true;
115+
}
116+
117+
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
118+
const BufferizationOptions &options) const {
119+
auto globalStoreOp = cast<GlobalStoreOp>(op);
120+
121+
auto tensorType = cast<TensorType>(globalStoreOp.getValue().getType());
122+
auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
123+
124+
auto loc = globalStoreOp.getLoc();
125+
auto targetMemref = rewriter.create<memref::GetGlobalOp>(
126+
loc, memrefType, globalStoreOp.getGlobalAttr().getLeafReference());
127+
128+
auto sourceMemref = getBuffer(rewriter, globalStoreOp.getValue(), options);
129+
if (failed(sourceMemref)) {
130+
return failure();
131+
}
132+
133+
auto memcpy =
134+
options.createMemCpy(rewriter, loc, sourceMemref.value(), targetMemref);
135+
if (failed(memcpy)) {
136+
return failure();
137+
}
138+
rewriter.eraseOp(globalStoreOp);
139+
140+
return success();
141+
}
142+
};
143+
144+
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
145+
registry.addExtension(+[](MLIRContext *ctx, MLProgramDialect *) {
146+
GlobalOp::attachInterface<GlobalOpInterface>(*ctx);
147+
GlobalLoadOp::attachInterface<GlobalLoadOpInterface>(*ctx);
148+
GlobalStoreOp::attachInterface<GlobalStoreOpInterface>(*ctx);
149+
});
150+
}
151+
152+
} // namespace ml_program
153+
} // namespace mlir

mlir/lib/Dialect/MLProgram/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
add_mlir_dialect_library(MLIRMLProgramTransforms
2+
BufferizableOpInterfaceImpl.cpp
23
PipelineGlobalOps.cpp
34

45
ADDITIONAL_HEADER_DIRS
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// RUN: mlir-opt %s -one-shot-bufferize -split-input-file | FileCheck %s
2+
3+
// CHECK-LABEL: memref.global "private" @global
4+
ml_program.global private mutable @global(dense<0> : tensor<i64>) : tensor<i64>
5+
6+
// CHECK-LABEL: @global_load_store
7+
func.func @global_load_store() -> i64 {
8+
// CHECK-DAG: %[[CST127:.+]] = arith.constant 127
9+
// CHECK-DAG: %[[GLOBAL_1:.+]] = memref.get_global @global
10+
// CHECK: %[[VALUE:.+]] = memref.load %[[GLOBAL_1]][]
11+
// CHECK: %[[NEW_VALUE:.+]] = arith.muli %[[VALUE]], %[[CST127]]
12+
// CHECK: memref.store %[[NEW_VALUE]], %[[GLOBAL_1]][]
13+
// CHECK: %[[GLOBAL_2:.+]] = memref.get_global @global
14+
// CHECK: memref.copy %[[GLOBAL_1]], %[[GLOBAL_2]]
15+
// CHECK: return %[[NEW_VALUE]]
16+
%c127_i64 = arith.constant 127 : i64
17+
%0 = ml_program.global_load @global : tensor<i64>
18+
%extracted = tensor.extract %0[] : tensor<i64>
19+
%1 = arith.muli %extracted, %c127_i64 : i64
20+
%inserted = tensor.insert %1 into %0[] : tensor<i64>
21+
ml_program.global_store @global = %inserted : tensor<i64>
22+
return %1 : i64
23+
}

0 commit comments

Comments
 (0)