
Commit 3c58380

[mlir][MLProgram] Add MLProgram to MemRef bufferization pass

There is currently no lowering out of `ml_program` in the LLVM repository. This change adds a lowering to `memref` so that it can be lowered all the way to LLVM. The lowering was taken from the reference backend in torch-mlir: llvm/torch-mlir@f416953. I had tried implementing the `BufferizableOpInterface` instead of adding a new pass, but that did not work because OneShotBufferize does not visit module-level ops like `ml_program.global`.
1 parent 97efd8a commit 3c58380
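
For illustration, a minimal before/after sketch of the lowering this pass performs on a mutable global and a load of it (adapted from the test added below; the symbol names here are illustrative, and the exact printed form of `bufferization.to_tensor` can vary between MLIR versions):

ml_program.global private mutable @state(dense<0> : tensor<i64>) : tensor<i64>

func.func @read() -> tensor<i64> {
  %0 = ml_program.global_load @state : tensor<i64>
  return %0 : tensor<i64>
}

becomes roughly:

memref.global "private" @state : memref<i64> = dense<0>

func.func @read() -> tensor<i64> {
  %0 = memref.get_global @state : memref<i64>
  // Copy into a fresh allocation so the produced tensor does not alias the global.
  %alloc = memref.alloc() : memref<i64>
  memref.copy %0, %alloc : memref<i64> to memref<i64>
  %1 = bufferization.to_tensor %alloc restrict : memref<i64>
  return %1 : tensor<i64>
}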

File tree

4 files changed: +241 −0 lines

mlir/include/mlir/Dialect/MLProgram/Transforms/Passes.td

Lines changed: 17 additions & 0 deletions
@@ -11,6 +11,23 @@

include "mlir/Pass/PassBase.td"

def MLProgramBufferize: Pass<"mlprogram-bufferize", "ModuleOp"> {
  let summary = "Bufferize the MLProgram dialect ops";
  let description = [{
    This pass bufferizes ops in the `ml_program` dialect. It is implemented as a
    standalone pass because One-Shot Bufferize does not handle module-level ops
    like `ml_program.global`. Therefore, this pass should run just before
    One-Shot Bufferize.

    This pass is intended to be a generic lowering of `ml_program` ops to allow
    for them to be lowered all the way to LLVM. Users may want a more specialized
    lowering depending on how they manage global state in their system.
  }];
  let dependentDialects = [
    "bufferization::BufferizationDialect", "memref::MemRefDialect",
  ];
}

def MLProgramPipelineGlobals : Pass<"mlprogram-pipeline-globals", "ModuleOp"> {
  let summary = "Optimize `ml_program` global operations for read and store";
  let description = [{
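
As the description notes, the new pass is meant to be scheduled immediately before One-Shot Bufferize. A sketch of a typical test/driver invocation (the exact One-Shot Bufferize options depend on the surrounding pipeline):

// RUN: mlir-opt %s --mlprogram-bufferize --one-shot-bufferize | FileCheck %s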
mlir/lib/Dialect/MLProgram/Transforms/Bufferize.cpp

Lines changed: 142 additions & 0 deletions
@@ -0,0 +1,142 @@

//===- Bufferize.cpp - MLProgram bufferize pass ---------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a bufferization pass for the MLProgram dialect
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/MLProgram/Transforms/Passes.h"

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinTypes.h"

namespace mlir {
namespace ml_program {
#define GEN_PASS_DEF_MLPROGRAMBUFFERIZE
#include "mlir/Dialect/MLProgram/Transforms/Passes.h.inc"

static LogicalResult bufferizeMLProgramGlobalOp(GlobalOp globalOp,
                                                OpBuilder &builder) {
  if (!globalOp.getValue().has_value())
    return globalOp.emitError("global op must have a value");

  auto tensorType = cast<RankedTensorType>(globalOp.getType());
  auto memrefType =
      MemRefType::get(tensorType.getShape(), tensorType.getElementType());

  builder.setInsertionPointToStart(
      globalOp->getParentOfType<ModuleOp>().getBody());
  builder.create<memref::GlobalOp>(
      globalOp.getLoc(), globalOp.getSymName(),
      /*sym_visibility=*/globalOp.getSymVisibilityAttr(),
      /*type=*/memrefType,
      /*initial_value=*/globalOp.getValue().value(),
      /*constant=*/!globalOp.getIsMutable(),
      /*alignment=*/nullptr);
  return success();
}

static LogicalResult bufferizeMLProgramGlobalLoadOp(GlobalLoadOp globalLoadOp,
                                                    OpBuilder &builder) {
  auto loc = globalLoadOp.getLoc();
  auto tensorType = cast<RankedTensorType>(globalLoadOp.getType());
  auto memrefType =
      MemRefType::get(tensorType.getShape(), tensorType.getElementType());

  builder.setInsertionPoint(globalLoadOp);
  Value globalVal = builder.create<memref::GetGlobalOp>(
      loc, memrefType, globalLoadOp.getGlobalAttr().getLeafReference());

  // We need a copy to guarantee that the produced tensor does not alias with
  // any other buffer.
  Value alloc = builder.create<memref::AllocOp>(loc, memrefType, ValueRange{});
  builder.create<memref::CopyOp>(globalLoadOp->getLoc(), globalVal, alloc);

  globalVal = builder.create<bufferization::ToTensorOp>(loc, tensorType, alloc,
                                                        /*restrict=*/true);
  globalLoadOp->getResult(0).replaceAllUsesWith(globalVal);
  return success();
}

static LogicalResult
bufferizeMLProgramGlobalStoreOp(GlobalStoreOp globalStoreOp,
                                OpBuilder &builder) {
  auto loc = globalStoreOp.getLoc();
  auto tensorType = cast<RankedTensorType>(globalStoreOp.getValue().getType());
  auto memrefType =
      MemRefType::get(tensorType.getShape(), tensorType.getElementType());

  builder.setInsertionPoint(globalStoreOp);
  Value memref = builder.create<memref::GetGlobalOp>(
      loc, memrefType, globalStoreOp.getGlobalAttr().getLeafReference());
  Value copyValue = builder.create<bufferization::ToMemrefOp>(
      loc, memrefType, globalStoreOp.getValue());
  builder.create<memref::CopyOp>(loc, copyValue, memref);
  return success();
}

namespace {
/// Converts MLProgram operations that work on tensor-type operands or results
/// to work on buffers.
class MLProgramBufferize
    : public impl::MLProgramBufferizeBase<MLProgramBufferize> {
  void runOnOperation() override {
    auto module = getOperation();
    OpBuilder builder(module.getBodyRegion());
    SmallVector<Operation *> toErase;

    auto walkResult = module.walk([&](GlobalOp op) {
      if (auto type = dyn_cast<RankedTensorType>(op.getType())) {
        if (!type.hasStaticShape()) {
          // If the ml_program.global has a dynamically shaped tensor.
          op.emitError(
              "unimplemented: global op bufferization with dynamic shape");
          return WalkResult::interrupt();
        }
      } else {
        // If the ml_program.global is of non-tensor type.
        op.emitError("unsupported global op type");
        return WalkResult::interrupt();
      }

      if (failed(bufferizeMLProgramGlobalOp(op, builder))) {
        op.emitError("bufferization for this op failed");
        return WalkResult::interrupt();
      }
      toErase.push_back(op);
      return WalkResult::advance();
    });

    if (walkResult.wasInterrupted())
      return signalPassFailure();

    module.walk([&](GlobalLoadOp op) {
      if (failed(bufferizeMLProgramGlobalLoadOp(op, builder))) {
        op.emitError("bufferization for this op failed");
        return;
      }
      toErase.push_back(op);
    });

    module.walk([&](GlobalStoreOp op) {
      if (failed(bufferizeMLProgramGlobalStoreOp(op, builder))) {
        op.emitError("bufferization for this op failed");
        return;
      }
      toErase.push_back(op);
    });

    for (auto *op : llvm::reverse(toErase))
      op->erase();
  }
};
} // namespace
} // namespace ml_program
} // namespace mlir
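
For the store side, `bufferizeMLProgramGlobalStoreOp` materializes the stored tensor as a memref and copies it into the global's buffer. A rough sketch of the resulting IR for a single store (illustrative names; the exact `bufferization.to_memref` spelling can differ between MLIR versions):

ml_program.global_store @state = %t : tensor<i64>

becomes roughly:

// Fetch the global's buffer, turn the tensor into a memref, and copy it in.
%dst = memref.get_global @state : memref<i64>
%m = bufferization.to_memref %t : memref<i64>
memref.copy %m, %dst : memref<i64> to memref<i64>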

mlir/lib/Dialect/MLProgram/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -1,4 +1,5 @@
add_mlir_dialect_library(MLIRMLProgramTransforms
  Bufferize.cpp
  PipelineGlobalOps.cpp

  ADDITIONAL_HEADER_DIRS
Lines changed: 81 additions & 0 deletions
@@ -0,0 +1,81 @@

// RUN: mlir-opt %s --mlprogram-bufferize -split-input-file -verify-diagnostics | FileCheck %s

// CHECK-LABEL: @global
ml_program.global private mutable @global(dense<0> : tensor<i64>) : tensor<i64>

// CHECK-LABEL: @global_load_store
func.func @global_load_store() -> i64 {
  // CHECK-DAG: %[[CST127:.+]] = arith.constant 127
  // CHECK-DAG: %[[GLOBAL_1:.+]] = memref.get_global @global
  // CHECK-DAG: %[[NEW_ALLOC:.+]] = memref.alloc
  // CHECK: memref.copy %[[GLOBAL_1]], %[[NEW_ALLOC]]
  // CHECK: %[[TENSOR:.+]] = bufferization.to_tensor %[[NEW_ALLOC]]
  // CHECK: %[[EXTRACTED:.+]] = tensor.extract %[[TENSOR]][]
  // CHECK: %[[NEW_VALUE:.+]] = arith.muli %[[EXTRACTED]], %[[CST127]]
  // CHECK: %[[INSERTED:.+]] = tensor.insert %[[NEW_VALUE]] into %[[TENSOR]][]
  // CHECK: %[[GLOBAL_2:.+]] = memref.get_global @global
  // CHECK: %[[MEMREF:.+]] = bufferization.to_memref %[[INSERTED]]
  // CHECK: memref.copy %[[MEMREF]], %[[GLOBAL_2]]
  // CHECK: return %[[NEW_VALUE]]
  %c127_i64 = arith.constant 127 : i64
  %0 = ml_program.global_load @global : tensor<i64>
  %extracted = tensor.extract %0[] : tensor<i64>
  %1 = arith.muli %extracted, %c127_i64 : i64
  %inserted = tensor.insert %1 into %0[] : tensor<i64>
  ml_program.global_store @global = %inserted : tensor<i64>
  return %1 : i64
}

// -----

// expected-error @below {{unsupported global op type}}
ml_program.global private mutable @global(0 : i64) : i64

func.func @global_scalar() -> i64 {
  %c127_i64 = arith.constant 127 : i64
  %0 = ml_program.global_load @global : i64
  %1 = arith.muli %0, %c127_i64 : i64
  ml_program.global_store @global = %1 : i64
  return %1 : i64
}

// -----

// expected-error @below {{unsupported global op type}}
ml_program.global private mutable @global(dense<0> : memref<i64>) : memref<i64>

func.func @global_memref() -> i64 {
  %c127_i64 = arith.constant 127 : i64
  %0 = ml_program.global_load @global : memref<i64>
  %extracted = memref.load %0[] : memref<i64>
  %1 = arith.muli %extracted, %c127_i64 : i64
  memref.store %1, %0[] : memref<i64>
  ml_program.global_store @global = %0 : memref<i64>
  return %1 : i64
}

// -----

// expected-error @below {{invalid tensor element type}}
ml_program.global private mutable @global(dense<0> : tensor<memref<i64>>) : tensor<memref<i64>>

func.func @global_tensor_of_memref() -> i64 {
  %c127_i64 = arith.constant 127 : i64
  return %c127_i64 : i64
}

// -----

// expected-error @below {{unimplemented: global op bufferization with dynamic shape}}
ml_program.global private mutable @global(dense<0> : tensor<1xi64>) : tensor<?xi64>

func.func @global_dynamic_shape() -> i64 {
  %c127_i64 = arith.constant 127 : i64
  %c0 = arith.constant 0 : index
  %0 = ml_program.global_load @global : tensor<?xi64>
  %extracted = tensor.extract %0[%c0] : tensor<?xi64>
  %1 = arith.muli %extracted, %c127_i64 : i64
  %inserted = tensor.insert %1 into %0[%c0] : tensor<?xi64>
  ml_program.global_store @global = %inserted : tensor<?xi64>
  return %1 : i64
}