-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][MLProgram] Add MLProgram to MemRef bufferization pass #75103
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-bufferization @llvm/pr-subscribers-mlir Author: None (ryan-holt-1) ChangesThere is currently no lowering out of I had tried implementing the Full diff: https://github.com/llvm/llvm-project/pull/75103.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/MLProgram/Transforms/Passes.h b/mlir/include/mlir/Dialect/MLProgram/Transforms/Passes.h
index 894e35e52724e..75c107d917188 100644
--- a/mlir/include/mlir/Dialect/MLProgram/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/MLProgram/Transforms/Passes.h
@@ -23,6 +23,8 @@ namespace ml_program {
// Registration
//===----------------------------------------------------------------------===//
+std::unique_ptr<OperationPass<ModuleOp>> createMLProgramBufferizePass();
+
std::unique_ptr<OperationPass<ModuleOp>> createMLProgramPipelineGlobalsPass();
/// Generate the code for registering passes.
diff --git a/mlir/include/mlir/Dialect/MLProgram/Transforms/Passes.td b/mlir/include/mlir/Dialect/MLProgram/Transforms/Passes.td
index defe8191cb905..617c24a4d8641 100644
--- a/mlir/include/mlir/Dialect/MLProgram/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/MLProgram/Transforms/Passes.td
@@ -11,6 +11,14 @@
include "mlir/Pass/PassBase.td"
+def MLProgramBufferize: Pass<"mlprogram-bufferize", "ModuleOp"> {
+ let summary = "Bufferize the MLProgram dialect ops";
+ let constructor = "mlir::ml_program::createMLProgramBufferizePass()";
+ let dependentDialects = [
+ "bufferization::BufferizationDialect", "memref::MemRefDialect",
+ ];
+}
+
def MLProgramPipelineGlobals : Pass<"mlprogram-pipeline-globals", "ModuleOp"> {
let summary = "Optimize `ml_program` global operations for read and store";
let description = [{
diff --git a/mlir/lib/Dialect/MLProgram/Transforms/Bufferize.cpp b/mlir/lib/Dialect/MLProgram/Transforms/Bufferize.cpp
new file mode 100644
index 0000000000000..c462550c706e2
--- /dev/null
+++ b/mlir/lib/Dialect/MLProgram/Transforms/Bufferize.cpp
@@ -0,0 +1,146 @@
+//===- Bufferize.cpp - MLProgram bufferize pass ---------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a bufferization pass for the MLProgram dialect
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MLProgram/Transforms/Passes.h"
+
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/BuiltinTypes.h"
+
+namespace mlir {
+namespace ml_program {
+#define GEN_PASS_DEF_MLPROGRAMBUFFERIZE
+#include "mlir/Dialect/MLProgram/Transforms/Passes.h.inc"
+
+static LogicalResult bufferizeMLProgramGlobalOp(GlobalOp globalOp,
+ OpBuilder &builder) {
+ if (!globalOp.getValue().has_value())
+ return globalOp.emitError("global op must have a value");
+
+ auto tensorType = cast<RankedTensorType>(globalOp.getType());
+ auto memrefType =
+ MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+
+ builder.setInsertionPointToStart(
+ globalOp->getParentOfType<ModuleOp>().getBody());
+ builder.create<memref::GlobalOp>(
+ globalOp.getLoc(), globalOp.getSymName(),
+ /*sym_visibility=*/globalOp.getSymVisibilityAttr(),
+ /*type=*/memrefType,
+ /*initial_value=*/globalOp.getValue().value(),
+ /*constant=*/!globalOp.getIsMutable(),
+ /*alignment=*/nullptr);
+ return success();
+}
+
+static LogicalResult bufferizeMLProgramGlobalLoadOp(GlobalLoadOp globalLoadOp,
+ OpBuilder &builder) {
+ auto loc = globalLoadOp.getLoc();
+ auto tensorType = cast<RankedTensorType>(globalLoadOp.getType());
+ auto memrefType =
+ MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+
+ builder.setInsertionPoint(globalLoadOp);
+ Value globalVal = builder.create<memref::GetGlobalOp>(
+ loc, memrefType, globalLoadOp.getGlobalAttr().getLeafReference());
+
+ // We need a copy to guarantee that the produced tensor does not alias with
+ // any other buffer.
+ Value alloc = builder.create<memref::AllocOp>(loc, memrefType, ValueRange{});
+ builder.create<memref::CopyOp>(globalLoadOp->getLoc(), globalVal, alloc);
+
+ globalVal = builder.create<bufferization::ToTensorOp>(loc, tensorType, alloc,
+ /*restrict=*/true);
+ globalLoadOp->getResult(0).replaceAllUsesWith(globalVal);
+ return success();
+}
+
+static LogicalResult
+bufferizeMLProgramGlobalStoreOp(GlobalStoreOp globalStoreOp,
+ OpBuilder &builder) {
+ auto loc = globalStoreOp.getLoc();
+ auto tensorType = cast<RankedTensorType>(globalStoreOp.getValue().getType());
+ auto memrefType =
+ MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+
+ builder.setInsertionPoint(globalStoreOp);
+ Value memref = builder.create<memref::GetGlobalOp>(
+ loc, memrefType, globalStoreOp.getGlobalAttr().getLeafReference());
+ Value copyValue = builder.create<bufferization::ToMemrefOp>(
+ loc, memrefType, globalStoreOp.getValue());
+ builder.create<memref::CopyOp>(loc, copyValue, memref);
+ return success();
+}
+
+namespace {
+/// Converts MLProgram operations that work on tensor-type operands or results
+/// to work on buffers.
+class MLProgramBufferize
+ : public impl::MLProgramBufferizeBase<MLProgramBufferize> {
+ void runOnOperation() override {
+ auto module = getOperation();
+ OpBuilder builder(module.getBodyRegion());
+ SmallVector<Operation *> toErase;
+
+ auto walkResult = module.walk([&](GlobalOp op) {
+ if (auto type = dyn_cast<RankedTensorType>(op.getType())) {
+ if (!type.hasStaticShape()) {
+ // If the ml_program.global has dynamically shaped tensor.
+ op.emitError(
+ "unimplemented: global op bufferization with dynamic shape");
+ return WalkResult::interrupt();
+ }
+ } else {
+ // If the ml_program.global is of non-tensor type.
+ op.emitError("unsupported global op type");
+ return WalkResult::interrupt();
+ }
+
+ if (failed(bufferizeMLProgramGlobalOp(op, builder))) {
+ op.emitError("bufferization for this op failed");
+ return WalkResult::interrupt();
+ }
+ toErase.push_back(op);
+ return WalkResult::advance();
+ });
+
+ if (walkResult.wasInterrupted())
+ return signalPassFailure();
+
+ module.walk([&](GlobalLoadOp op) {
+ if (failed(bufferizeMLProgramGlobalLoadOp(op, builder))) {
+ op.emitError("bufferization for this op failed");
+ return;
+ }
+ toErase.push_back(op);
+ });
+
+ module.walk([&](GlobalStoreOp op) {
+ if (failed(bufferizeMLProgramGlobalStoreOp(op, builder))) {
+ op.emitError("bufferization for this op failed");
+ return;
+ }
+ toErase.push_back(op);
+ });
+
+ for (auto *op : llvm::reverse(toErase))
+ op->erase();
+ }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<ModuleOp>> createMLProgramBufferizePass() {
+ return std::make_unique<MLProgramBufferize>();
+}
+} // namespace ml_program
+} // namespace mlir
diff --git a/mlir/lib/Dialect/MLProgram/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MLProgram/Transforms/CMakeLists.txt
index db567b62e0e74..dc14bf212434f 100644
--- a/mlir/lib/Dialect/MLProgram/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/MLProgram/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
add_mlir_dialect_library(MLIRMLProgramTransforms
+ Bufferize.cpp
PipelineGlobalOps.cpp
ADDITIONAL_HEADER_DIRS
diff --git a/mlir/test/Dialect/MLProgram/bufferize.mlir b/mlir/test/Dialect/MLProgram/bufferize.mlir
new file mode 100644
index 0000000000000..5dc71803dc0cf
--- /dev/null
+++ b/mlir/test/Dialect/MLProgram/bufferize.mlir
@@ -0,0 +1,81 @@
+// RUN: mlir-opt %s --mlprogram-bufferize -split-input-file -verify-diagnostics | FileCheck %s
+
+// CHECK-LABEL: @global
+ml_program.global private mutable @global(dense<0> : tensor<i64>) : tensor<i64>
+
+// CHECK-LABEL: @global_load_store
+func.func @global_load_store() -> i64 {
+ // CHECK-DAG: %[[CST127:.+]] = arith.constant 127
+ // CHECK-DAG: %[[GLOBAL_1:.+]] = memref.get_global @global
+ // CHECK-DAG: %[[NEW_ALLOC:.+]] = memref.alloc
+ // CHECK: memref.copy %[[GLOBAL_1]], %[[NEW_ALLOC]]
+ // CHECK: %[[TENSOR:.+]] = bufferization.to_tensor %[[NEW_ALLOC]]
+ // CHECK: %[[EXTRACTED:.+]] = tensor.extract %[[TENSOR]][]
+ // CHECK: %[[NEW_VALUE:.+]] = arith.muli %[[EXTRACTED]], %[[CST127]]
+ // CHECK: %[[INSERTED:.+]] = tensor.insert %[[NEW_VALUE]] into %[[TENSOR]][]
+ // CHECK: %[[GLOBAL_2:.+]] = memref.get_global @global
+ // CHECK: %[[MEMREF:.+]] = bufferization.to_memref %[[INSERTED]]
+ // CHECK: memref.copy %[[MEMREF]], %[[GLOBAL_2]]
+ // CHECK: return %[[NEW_VALUE]]
+ %c127_i64 = arith.constant 127 : i64
+ %0 = ml_program.global_load @global : tensor<i64>
+ %extracted = tensor.extract %0[] : tensor<i64>
+ %1 = arith.muli %extracted, %c127_i64 : i64
+ %inserted = tensor.insert %1 into %0[] : tensor<i64>
+ ml_program.global_store @global = %inserted : tensor<i64>
+ return %1 : i64
+}
+
+// -----
+
+// expected-error @below {{unsupported global op type}}
+ml_program.global private mutable @global(0 : i64) : i64
+
+func.func @global_scalar() -> i64 {
+ %c127_i64 = arith.constant 127 : i64
+ %0 = ml_program.global_load @global : i64
+ %1 = arith.muli %0, %c127_i64 : i64
+ ml_program.global_store @global = %1 : i64
+ return %1 : i64
+}
+
+// -----
+
+// expected-error @below {{unsupported global op type}}
+ml_program.global private mutable @global(dense<0> : memref<i64>) : memref<i64>
+
+func.func @global_memref() -> i64 {
+ %c127_i64 = arith.constant 127 : i64
+ %0 = ml_program.global_load @global : memref<i64>
+ %extracted = memref.load %0[] : memref<i64>
+ %1 = arith.muli %extracted, %c127_i64 : i64
+ memref.store %1, %0[] : memref<i64>
+ ml_program.global_store @global = %0 : memref<i64>
+ return %1 : i64
+}
+
+// -----
+
+// expected-error @below {{invalid tensor element type}}
+ml_program.global private mutable @global(dense<0> : tensor<memref<i64>>) : tensor<memref<i64>>
+
+func.func @global_tensor_of_memref() -> i64 {
+ %c127_i64 = arith.constant 127 : i64
+ return %c127_i64 : i64
+}
+
+// -----
+
+// expected-error @below {{unimplemented: global op bufferization with dynamic shape}}
+ml_program.global private mutable @global(dense<0> : tensor<1xi64>) : tensor<?xi64>
+
+func.func @global_dynamic_shape() -> i64 {
+ %c127_i64 = arith.constant 127 : i64
+ %c0 = arith.constant 0 : index
+ %0 = ml_program.global_load @global : tensor<?xi64>
+ %extracted = tensor.extract %0[%c0] : tensor<?xi64>
+ %1 = arith.muli %extracted, %c127_i64 : i64
+ %inserted = tensor.insert %1 into %0[%c0] : tensor<?xi64>
+ ml_program.global_store @global = %inserted : tensor<?xi64>
+ return %1 : i64
+}
|
Instead of a separate pass, can you implement the The implementation would be similar to |
@matthias-springer I agree that implementing the interface would be much better. I had tried doing that but it did not work because One-Shot Bufferize would not visit module-level ops that appear outside of a function like |
922ed87
to
3c58380
Compare
One-Shot Bufferize skips over ops that do not have tensor semantics. Until now we defined "tensor semantics" as "has tensor operand and/or result" (and some special rules for functions). #75273 should fix the issue: I added a |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks good to me. Thank you.
I am not an expert on bufferization, but it looks like Matthias' review was good on that.
Do you need me to merge this?
Thanks for the review @stellaraccident. I'm currently waiting on #75273 to land so that I can implement this in terms of the |
Ok, cool. Feel free to ask for a re review or just have Matthias look. |
#75273 is ready to land from my side. Would be great if somebody can review it. |
Add a new interface method to `BufferizableOpInterface`: `hasTensorSemanticsForBufferization`. This method returns "true" if the op has tensor semantics and should be bufferized. Until now, we assumed that an op has tensor semantics if it has tensor operands and/or tensor op results. However, there are ops like `ml_program.global` that do not have any results/operands but must still be bufferized (llvm#75103). The new interface method can return "true" for such ops. This change also decouples `bufferization::bufferizeOp` a bit from the func dialect.
Add a new interface method to `BufferizableOpInterface`: `hasTensorSemanticsForBufferization`. This method returns "true" if the op has tensor semantics and should be bufferized. Until now, we assumed that an op has tensor semantics if it has tensor operands and/or tensor op results. However, there are ops like `ml_program.global` that do not have any results/operands but must still be bufferized (llvm#75103). The new interface method can return "true" for such ops. This change also decouples `bufferization::bufferizeOp` a bit from the func dialect.
…s` (#75273) Add a new interface method to `BufferizableOpInterface`: `hasTensorSemantics`. This method returns "true" if the op has tensor semantics and should be bufferized. Until now, we assumed that an op has tensor semantics if it has tensor operands and/or tensor op results. However, there are ops like `ml_program.global` that do not have any results/operands but must still be bufferized (#75103). The new interface method can return "true" for such ops. This change also decouples `bufferization::bufferizeOp` a bit from the func dialect.
Sorry that it took so long. #75273 has been merged. |
…s` (llvm#75273) Add a new interface method to `BufferizableOpInterface`: `hasTensorSemantics`. This method returns "true" if the op has tensor semantics and should be bufferized. Until now, we assumed that an op has tensor semantics if it has tensor operands and/or tensor op results. However, there are ops like `ml_program.global` that do not have any results/operands but must still be bufferized (llvm#75103). The new interface method can return "true" for such ops. This change also decouples `bufferization::bufferizeOp` a bit from the func dialect.
3c58380
to
e5461d8
Compare
b69999a
to
157daec
Compare
Thanks @matthias-springer! I've updated the PR. I had to make a couple of small changes to bufferization to account for the fact that a module-level op may be deleted (like |
✅ With the latest revision this PR passed the C/C++ code formatter. |
6bf6435
to
4146ffb
Compare
38227ad
to
9eced81
Compare
This commit implements the `BufferizableOpInterface` for `ml_program.global`, `ml_program.global_load` and `ml_program.global_store` so that these ops can be lowered all the way to LLVM.
9eced81
to
73feff8
Compare
Do you mind merging this for me @matthias-springer ? I don't have commit access. |
Ci is broken: https://lab.llvm.org/buildbot/#/builders/61/builds/53723 ; I have to revert. |
Looks like a missing |
Actually I can’t on my phone apparently, if someone can click the buttons for me that’d be nice! |
(Or feel free to fix forward @matthias-springer ) |
After #75103, `MLPrgramTransforms` depends on `BufferizationDialect`. Also fix an unrelated compile error in `GreedyPatternRewriteDriver.cpp`. (This was not failing on CI. I may be running an old compiler locally.)
Oops, thanks Matthias. |
There is currently no lowering out of
ml_program
in the LLVM repository. This change adds a lowering tomemref
so that it can be lowered all the way to LLVM. This lowering was taken from the reference backend in torch-mlir.I had tried implementing the
BufferizableOpInterface
forml_program
instead of adding a new pass but that did not work becauseOneShotBufferize
does not visit module-level ops likeml_program.global
.