Skip to content

Commit 9eced81

Browse files
committed
[mlir][MLProgram] Implement BufferizableOpInterface
This commit implements the `BufferizableOpInterface` for `ml_program.global`, `ml_program.global_load` and `ml_program.global_store` so that these ops can be lowered all the way to LLVM.
1 parent 4fcd7cf commit 9eced81

File tree

7 files changed

+238
-1
lines changed

7 files changed

+238
-1
lines changed
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
//===- BufferizableOpInterfaceImpl.h - Impl. of BufferizableOpInterface ---===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DIALECT_MLPROGRAM_BUFFERIZABLEOPINTERFACEIMPL_H
10+
#define MLIR_DIALECT_MLPROGRAM_BUFFERIZABLEOPINTERFACEIMPL_H
11+
12+
namespace mlir {
13+
class DialectRegistry;
14+
15+
namespace ml_program {
16+
/// Registers the `BufferizableOpInterface` external models of the ml_program
/// dialect with `registry`.
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
17+
} // namespace ml_program
18+
} // namespace mlir
19+
20+
#endif // MLIR_DIALECT_MLPROGRAM_BUFFERIZABLEOPINTERFACEIMPL_H

mlir/include/mlir/InitAllDialects.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
#include "mlir/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.h"
4949
#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"
5050
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
51+
#include "mlir/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.h"
5152
#include "mlir/Dialect/Math/IR/Math.h"
5253
#include "mlir/Dialect/MemRef/IR/MemRef.h"
5354
#include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h"
@@ -160,6 +161,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
160161
memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
161162
memref::registerValueBoundsOpInterfaceExternalModels(registry);
162163
memref::registerMemorySlotExternalModels(registry);
164+
ml_program::registerBufferizableOpInterfaceExternalModels(registry);
163165
scf::registerBufferDeallocationOpInterfaceExternalModels(registry);
164166
scf::registerBufferizableOpInterfaceExternalModels(registry);
165167
scf::registerValueBoundsOpInterfaceExternalModels(registry);

mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,10 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
494494
<< "\n//===-------------------------------------------===//\n");
495495
}
496496

497+
// Return early if the top-level op is entirely gone.
498+
if (erasedOps.contains(op))
499+
return success();
500+
497501
// Fold all to_memref(to_tensor(x)) pairs.
498502
for (Operation *op : toMemrefOps) {
499503
rewriter.setInsertionPoint(op);

mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -459,7 +459,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
459459
}
460460

461461
// Bufferize all other ops.
462-
for (Operation &op : moduleOp.getOps()) {
462+
for (Operation &op : llvm::make_early_inc_range(moduleOp.getOps())) {
463463
// Functions were already bufferized.
464464
if (isa<func::FuncOp>(&op))
465465
continue;
Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
//===- BufferizableOpInterfaceImpl.cpp - Impl. of BufferizableOpInterface -===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
14+
15+
using namespace mlir;
16+
using namespace mlir::bufferization;
17+
using namespace mlir::ml_program;
18+
19+
namespace mlir {
20+
namespace ml_program {
21+
namespace {
22+
23+
template <typename Interface, typename Op>
24+
struct ExternalModelBase
25+
: public BufferizableOpInterface::ExternalModel<Interface, Op> {
26+
27+
AliasingValueList getAliasingValues(Operation *, OpOperand &,
28+
const AnalysisState &) const {
29+
return {};
30+
}
31+
32+
BufferRelation bufferRelation(Operation *, OpResult,
33+
const AnalysisState &) const {
34+
return BufferRelation::Unknown;
35+
}
36+
};
37+
38+
/// Bufferization of ml_program.global into a memref.global
39+
struct GlobalOpInterface
40+
: public ExternalModelBase<GlobalOpInterface, GlobalOp> {
41+
42+
bool bufferizesToMemoryRead(Operation *, OpOperand &,
43+
const AnalysisState &) const {
44+
return false;
45+
}
46+
47+
bool bufferizesToMemoryWrite(Operation *, OpOperand &,
48+
const AnalysisState &) const {
49+
return false;
50+
}
51+
52+
bool hasTensorSemantics(Operation *) const { return true; }
53+
54+
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
55+
const BufferizationOptions &) const {
56+
auto globalOp = cast<GlobalOp>(op);
57+
if (!globalOp.getValue().has_value())
58+
return globalOp.emitError("global op must have a value");
59+
60+
auto tensorType = cast<TensorType>(globalOp.getType());
61+
auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
62+
63+
replaceOpWithNewBufferizedOp<memref::GlobalOp>(
64+
rewriter, globalOp, globalOp.getSymName(),
65+
/*sym_visibility=*/globalOp.getSymVisibilityAttr(),
66+
/*type=*/cast<MemRefType>(memrefType),
67+
/*initial_value=*/globalOp.getValue().value(),
68+
/*constant=*/!globalOp.getIsMutable(),
69+
/*alignment=*/nullptr);
70+
return success();
71+
}
72+
};
73+
74+
/// Bufferization of ml_program.global_load into a memref.get_global
75+
struct GlobalLoadOpInterface
76+
: public ExternalModelBase<GlobalLoadOpInterface, GlobalLoadOp> {
77+
78+
bool bufferizesToMemoryRead(Operation *, OpOperand &,
79+
const AnalysisState &) const {
80+
return false;
81+
}
82+
83+
bool bufferizesToMemoryWrite(Operation *, OpOperand &,
84+
const AnalysisState &) const {
85+
return false;
86+
}
87+
88+
bool isWritable(Operation *, Value, const AnalysisState &) const {
89+
return false;
90+
}
91+
92+
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
93+
const BufferizationOptions &) const {
94+
auto globalLoadOp = cast<GlobalLoadOp>(op);
95+
96+
auto tensorType = cast<TensorType>(globalLoadOp.getType());
97+
auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
98+
99+
replaceOpWithNewBufferizedOp<memref::GetGlobalOp>(
100+
rewriter, globalLoadOp, memrefType,
101+
globalLoadOp.getGlobalAttr().getLeafReference());
102+
103+
return success();
104+
}
105+
};
106+
107+
/// Bufferization of ml_program.global_store into a memref.get_global and
108+
/// memcpy
109+
struct GlobalStoreOpInterface
110+
: public ExternalModelBase<GlobalStoreOpInterface, GlobalStoreOp> {
111+
112+
bool bufferizesToMemoryRead(Operation *, OpOperand &,
113+
const AnalysisState &) const {
114+
return false;
115+
}
116+
117+
bool bufferizesToMemoryWrite(Operation *, OpOperand &,
118+
const AnalysisState &) const {
119+
return true;
120+
}
121+
122+
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
123+
const BufferizationOptions &options) const {
124+
auto globalStoreOp = cast<GlobalStoreOp>(op);
125+
126+
auto tensorType = cast<TensorType>(globalStoreOp.getValue().getType());
127+
auto memrefType = getMemRefTypeWithStaticIdentityLayout(tensorType);
128+
129+
auto loc = globalStoreOp.getLoc();
130+
auto targetMemref = rewriter.create<memref::GetGlobalOp>(
131+
loc, memrefType, globalStoreOp.getGlobalAttr().getLeafReference());
132+
133+
auto sourceMemref = getBuffer(rewriter, globalStoreOp.getValue(), options);
134+
if (failed(sourceMemref)) {
135+
return failure();
136+
}
137+
138+
auto memcpy =
139+
options.createMemCpy(rewriter, loc, sourceMemref.value(), targetMemref);
140+
if (failed(memcpy)) {
141+
return failure();
142+
}
143+
rewriter.eraseOp(globalStoreOp);
144+
145+
return success();
146+
}
147+
};
148+
} // namespace
149+
150+
/// Adds an extension to `registry` that attaches the external models defined
/// in this file to `ml_program.global`, `ml_program.global_load` and
/// `ml_program.global_store`.
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
  registry.addExtension(
      +[](MLIRContext *context, MLProgramDialect * /*dialect*/) {
        GlobalOp::attachInterface<GlobalOpInterface>(*context);
        GlobalLoadOp::attachInterface<GlobalLoadOpInterface>(*context);
        GlobalStoreOp::attachInterface<GlobalStoreOpInterface>(*context);
      });
}
157+
} // namespace ml_program
158+
} // namespace mlir

mlir/lib/Dialect/MLProgram/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
add_mlir_dialect_library(MLIRMLProgramTransforms
2+
BufferizableOpInterfaceImpl.cpp
23
PipelineGlobalOps.cpp
34

45
ADDITIONAL_HEADER_DIRS
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
// RUN: mlir-opt %s -one-shot-bufferize -split-input-file | FileCheck %s
2+
3+
// CHECK-LABEL: memref.global "private" @global
4+
ml_program.global private mutable @global(dense<0> : tensor<i64>) : tensor<i64>
5+
6+
// CHECK-LABEL: func.func @global_load_store
7+
// Round trip through a global: the loaded value is multiplied by 127 and
// written back. The expected output applies the update to a fresh allocation
// that was copied from the global, then copies that allocation into the
// global.
func.func @global_load_store() -> i64 {
8+
// CHECK-DAG: %[[CST127:.*]] = arith.constant 127
9+
// CHECK-DAG: %[[GLOBAL_1:.*]] = memref.get_global @global
10+
// CHECK: %[[VALUE:.*]] = memref.load %[[GLOBAL_1]][]
11+
// CHECK: %[[NEW_VALUE:.*]] = arith.muli %[[VALUE]], %[[CST127]]
12+
// CHECK: %[[ALLOC:.*]] = memref.alloc()
13+
// CHECK: memref.copy %[[GLOBAL_1]], %[[ALLOC]]
14+
// CHECK: memref.store %[[NEW_VALUE]], %[[ALLOC]][]
15+
// CHECK: %[[GLOBAL_2:.*]] = memref.get_global @global
16+
// CHECK: memref.copy %[[ALLOC]], %[[GLOBAL_2]]
17+
// CHECK: return %[[NEW_VALUE]]
18+
%c127 = arith.constant 127 : i64
19+
%0 = ml_program.global_load @global : tensor<i64>
20+
%extracted = tensor.extract %0[] : tensor<i64>
21+
%1 = arith.muli %extracted, %c127 : i64
22+
%inserted = tensor.insert %1 into %0[] : tensor<i64>
23+
ml_program.global_store @global = %inserted : tensor<i64>
24+
return %1 : i64
25+
}
26+
27+
// -----
28+
29+
// CHECK-LABEL: memref.global "private" @global
30+
ml_program.global private mutable @global(dense<0> : tensor<i64>) : tensor<i64>
31+
32+
// CHECK-LABEL: func.func @raw_hazard
33+
// Two loads of the same global: one is updated with tensor.insert, the other
// is read with tensor.extract, and the updated tensor is then stored. The
// expected output performs the memref.load from the global before the updated
// allocation is copied back into it.
func.func @raw_hazard() -> i64 {
34+
// CHECK-DAG: %[[CST127:.*]] = arith.constant 127
35+
// CHECK-DAG: %[[GLOBAL_1:.*]] = memref.get_global @global
36+
// CHECK-DAG: %[[GLOBAL_2:.*]] = memref.get_global @global
37+
// CHECK-DAG: %[[ALLOC:.*]] = memref.alloc()
38+
// CHECK: memref.copy %[[GLOBAL_1]], %[[ALLOC]]
39+
// CHECK: memref.store %[[CST127]], %[[ALLOC]][]
40+
// CHECK: %[[VAL:.*]] = memref.load %[[GLOBAL_2]][]
41+
// CHECK: %[[GLOBAL_3:.*]] = memref.get_global @global
42+
// CHECK: memref.copy %[[ALLOC]], %[[GLOBAL_3]]
43+
// CHECK: return %[[VAL]]
44+
%c127 = arith.constant 127 : i64
45+
%0 = ml_program.global_load @global : tensor<i64>
46+
%1 = ml_program.global_load @global : tensor<i64>
47+
%inserted = tensor.insert %c127 into %0[] : tensor<i64>
48+
%extracted = tensor.extract %1[] : tensor<i64>
49+
ml_program.global_store @global = %inserted : tensor<i64>
50+
return %extracted : i64
51+
}
52+

0 commit comments

Comments
 (0)