[mlir][MLProgram] Add MLProgram to MemRef bufferization pass #75103


Merged: 1 commit into llvm:main on Jan 30, 2024

Conversation

@ryanpholt (Contributor) commented Dec 11, 2023

There is currently no lowering out of `ml_program` in the LLVM repository. This change adds a lowering to `memref` so that it can be lowered all the way to LLVM. The lowering was taken from the reference backend in torch-mlir.

I had tried implementing the `BufferizableOpInterface` for `ml_program` instead of adding a new pass, but that did not work because One-Shot Bufferize does not visit module-level ops like `ml_program.global`.
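To make the effect concrete, here is a simplified before/after sketch of what the new pass does, derived from the `bufferize.mlir` test in this PR (names abbreviated; the exact printed IR may differ):

```mlir
// Before: a module-level ml_program global and a load of it.
ml_program.global private mutable @global(dense<0> : tensor<i64>) : tensor<i64>

func.func @load() -> tensor<i64> {
  %0 = ml_program.global_load @global : tensor<i64>
  return %0 : tensor<i64>
}

// After --mlprogram-bufferize (roughly): the global becomes a
// memref.global, and the load becomes get_global + alloc + copy,
// so the produced tensor does not alias any other buffer.
memref.global "private" @global : memref<i64> = dense<0>

func.func @load() -> tensor<i64> {
  %g = memref.get_global @global : memref<i64>
  %alloc = memref.alloc() : memref<i64>
  memref.copy %g, %alloc : memref<i64> to memref<i64>
  %t = bufferization.to_tensor %alloc restrict : memref<i64>
  return %t : tensor<i64>
}
```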

@llvmbot (Member) commented Dec 11, 2023

@llvm/pr-subscribers-mlir-bufferization
@llvm/pr-subscribers-mlir-mlprogram

@llvm/pr-subscribers-mlir

Author: None (ryan-holt-1)

Changes

There is currently no lowering out of `ml_program` in the LLVM repository. This change adds a lowering to `memref` so that it can be lowered all the way to LLVM. The lowering was taken from the reference backend in torch-mlir.

I had tried implementing the `BufferizableOpInterface` for `ml_program` instead of adding a new pass, but that did not work because One-Shot Bufferize does not visit global ops that appear outside of a function.


Full diff: https://github.com/llvm/llvm-project/pull/75103.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/MLProgram/Transforms/Passes.h (+2)
  • (modified) mlir/include/mlir/Dialect/MLProgram/Transforms/Passes.td (+8)
  • (added) mlir/lib/Dialect/MLProgram/Transforms/Bufferize.cpp (+146)
  • (modified) mlir/lib/Dialect/MLProgram/Transforms/CMakeLists.txt (+1)
  • (added) mlir/test/Dialect/MLProgram/bufferize.mlir (+81)
diff --git a/mlir/include/mlir/Dialect/MLProgram/Transforms/Passes.h b/mlir/include/mlir/Dialect/MLProgram/Transforms/Passes.h
index 894e35e52724e..75c107d917188 100644
--- a/mlir/include/mlir/Dialect/MLProgram/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/MLProgram/Transforms/Passes.h
@@ -23,6 +23,8 @@ namespace ml_program {
 // Registration
 //===----------------------------------------------------------------------===//
 
+std::unique_ptr<OperationPass<ModuleOp>> createMLProgramBufferizePass();
+
 std::unique_ptr<OperationPass<ModuleOp>> createMLProgramPipelineGlobalsPass();
 
 /// Generate the code for registering passes.
diff --git a/mlir/include/mlir/Dialect/MLProgram/Transforms/Passes.td b/mlir/include/mlir/Dialect/MLProgram/Transforms/Passes.td
index defe8191cb905..617c24a4d8641 100644
--- a/mlir/include/mlir/Dialect/MLProgram/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/MLProgram/Transforms/Passes.td
@@ -11,6 +11,14 @@
 
 include "mlir/Pass/PassBase.td"
 
+def MLProgramBufferize: Pass<"mlprogram-bufferize", "ModuleOp"> {
+  let summary = "Bufferize the MLProgram dialect ops";
+  let constructor = "mlir::ml_program::createMLProgramBufferizePass()";
+  let dependentDialects = [
+    "bufferization::BufferizationDialect", "memref::MemRefDialect", 
+  ];
+}
+
 def MLProgramPipelineGlobals : Pass<"mlprogram-pipeline-globals", "ModuleOp"> {
   let summary = "Optimize `ml_program` global operations for read and store";
   let description = [{
diff --git a/mlir/lib/Dialect/MLProgram/Transforms/Bufferize.cpp b/mlir/lib/Dialect/MLProgram/Transforms/Bufferize.cpp
new file mode 100644
index 0000000000000..c462550c706e2
--- /dev/null
+++ b/mlir/lib/Dialect/MLProgram/Transforms/Bufferize.cpp
@@ -0,0 +1,146 @@
+//===- Bufferize.cpp - MLProgram bufferize pass ---------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a bufferization pass for the MLProgram dialect
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MLProgram/Transforms/Passes.h"
+
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/BuiltinTypes.h"
+
+namespace mlir {
+namespace ml_program {
+#define GEN_PASS_DEF_MLPROGRAMBUFFERIZE
+#include "mlir/Dialect/MLProgram/Transforms/Passes.h.inc"
+
+static LogicalResult bufferizeMLProgramGlobalOp(GlobalOp globalOp,
+                                                OpBuilder &builder) {
+  if (!globalOp.getValue().has_value())
+    return globalOp.emitError("global op must have a value");
+
+  auto tensorType = cast<RankedTensorType>(globalOp.getType());
+  auto memrefType =
+      MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+
+  builder.setInsertionPointToStart(
+      globalOp->getParentOfType<ModuleOp>().getBody());
+  builder.create<memref::GlobalOp>(
+      globalOp.getLoc(), globalOp.getSymName(),
+      /*sym_visibility=*/globalOp.getSymVisibilityAttr(),
+      /*type=*/memrefType,
+      /*initial_value=*/globalOp.getValue().value(),
+      /*constant=*/!globalOp.getIsMutable(),
+      /*alignment=*/nullptr);
+  return success();
+}
+
+static LogicalResult bufferizeMLProgramGlobalLoadOp(GlobalLoadOp globalLoadOp,
+                                                    OpBuilder &builder) {
+  auto loc = globalLoadOp.getLoc();
+  auto tensorType = cast<RankedTensorType>(globalLoadOp.getType());
+  auto memrefType =
+      MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+
+  builder.setInsertionPoint(globalLoadOp);
+  Value globalVal = builder.create<memref::GetGlobalOp>(
+      loc, memrefType, globalLoadOp.getGlobalAttr().getLeafReference());
+
+  // We need a copy to guarantee that the produced tensor does not alias with
+  // any other buffer.
+  Value alloc = builder.create<memref::AllocOp>(loc, memrefType, ValueRange{});
+  builder.create<memref::CopyOp>(globalLoadOp->getLoc(), globalVal, alloc);
+
+  globalVal = builder.create<bufferization::ToTensorOp>(loc, tensorType, alloc,
+                                                        /*restrict=*/true);
+  globalLoadOp->getResult(0).replaceAllUsesWith(globalVal);
+  return success();
+}
+
+static LogicalResult
+bufferizeMLProgramGlobalStoreOp(GlobalStoreOp globalStoreOp,
+                                OpBuilder &builder) {
+  auto loc = globalStoreOp.getLoc();
+  auto tensorType = cast<RankedTensorType>(globalStoreOp.getValue().getType());
+  auto memrefType =
+      MemRefType::get(tensorType.getShape(), tensorType.getElementType());
+
+  builder.setInsertionPoint(globalStoreOp);
+  Value memref = builder.create<memref::GetGlobalOp>(
+      loc, memrefType, globalStoreOp.getGlobalAttr().getLeafReference());
+  Value copyValue = builder.create<bufferization::ToMemrefOp>(
+      loc, memrefType, globalStoreOp.getValue());
+  builder.create<memref::CopyOp>(loc, copyValue, memref);
+  return success();
+}
+
+namespace {
+/// Converts MLProgram operations that work on tensor-type operands or results
+/// to work on buffers.
+class MLProgramBufferize
+    : public impl::MLProgramBufferizeBase<MLProgramBufferize> {
+  void runOnOperation() override {
+    auto module = getOperation();
+    OpBuilder builder(module.getBodyRegion());
+    SmallVector<Operation *> toErase;
+
+    auto walkResult = module.walk([&](GlobalOp op) {
+      if (auto type = dyn_cast<RankedTensorType>(op.getType())) {
+        if (!type.hasStaticShape()) {
+          // If the ml_program.global has dynamically shaped tensor.
+          op.emitError(
+              "unimplemented: global op bufferization with dynamic shape");
+          return WalkResult::interrupt();
+        }
+      } else {
+        // If the ml_program.global is of non-tensor type.
+        op.emitError("unsupported global op type");
+        return WalkResult::interrupt();
+      }
+
+      if (failed(bufferizeMLProgramGlobalOp(op, builder))) {
+        op.emitError("bufferization for this op failed");
+        return WalkResult::interrupt();
+      }
+      toErase.push_back(op);
+      return WalkResult::advance();
+    });
+
+    if (walkResult.wasInterrupted())
+      return signalPassFailure();
+
+    module.walk([&](GlobalLoadOp op) {
+      if (failed(bufferizeMLProgramGlobalLoadOp(op, builder))) {
+        op.emitError("bufferization for this op failed");
+        return;
+      }
+      toErase.push_back(op);
+    });
+
+    module.walk([&](GlobalStoreOp op) {
+      if (failed(bufferizeMLProgramGlobalStoreOp(op, builder))) {
+        op.emitError("bufferization for this op failed");
+        return;
+      }
+      toErase.push_back(op);
+    });
+
+    for (auto *op : llvm::reverse(toErase))
+      op->erase();
+  }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<ModuleOp>> createMLProgramBufferizePass() {
+  return std::make_unique<MLProgramBufferize>();
+}
+} // namespace ml_program
+} // namespace mlir
diff --git a/mlir/lib/Dialect/MLProgram/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MLProgram/Transforms/CMakeLists.txt
index db567b62e0e74..dc14bf212434f 100644
--- a/mlir/lib/Dialect/MLProgram/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/MLProgram/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(MLIRMLProgramTransforms
+  Bufferize.cpp
   PipelineGlobalOps.cpp
 
   ADDITIONAL_HEADER_DIRS
diff --git a/mlir/test/Dialect/MLProgram/bufferize.mlir b/mlir/test/Dialect/MLProgram/bufferize.mlir
new file mode 100644
index 0000000000000..5dc71803dc0cf
--- /dev/null
+++ b/mlir/test/Dialect/MLProgram/bufferize.mlir
@@ -0,0 +1,81 @@
+// RUN: mlir-opt %s --mlprogram-bufferize -split-input-file -verify-diagnostics | FileCheck %s
+
+// CHECK-LABEL: @global 
+ml_program.global private mutable @global(dense<0> : tensor<i64>) : tensor<i64>
+
+// CHECK-LABEL: @global_load_store
+func.func @global_load_store() -> i64 {
+  // CHECK-DAG: %[[CST127:.+]] = arith.constant 127
+  // CHECK-DAG: %[[GLOBAL_1:.+]] = memref.get_global @global
+  // CHECK-DAG: %[[NEW_ALLOC:.+]] = memref.alloc
+  // CHECK:     memref.copy %[[GLOBAL_1]], %[[NEW_ALLOC]]
+  // CHECK:     %[[TENSOR:.+]] = bufferization.to_tensor %[[NEW_ALLOC]]
+  // CHECK:     %[[EXTRACTED:.+]] = tensor.extract %[[TENSOR]][]
+  // CHECK:     %[[NEW_VALUE:.+]] = arith.muli %[[EXTRACTED]], %[[CST127]]
+  // CHECK:     %[[INSERTED:.+]] = tensor.insert %[[NEW_VALUE]] into %[[TENSOR]][]
+  // CHECK:     %[[GLOBAL_2:.+]] = memref.get_global @global
+  // CHECK:     %[[MEMREF:.+]] = bufferization.to_memref %[[INSERTED]]
+  // CHECK:     memref.copy %[[MEMREF]], %[[GLOBAL_2]]
+  // CHECK:     return %[[NEW_VALUE]]
+  %c127_i64 = arith.constant 127 : i64
+  %0 = ml_program.global_load @global : tensor<i64>
+  %extracted = tensor.extract %0[] : tensor<i64>
+  %1 = arith.muli %extracted, %c127_i64 : i64
+  %inserted = tensor.insert %1 into %0[] : tensor<i64>
+  ml_program.global_store @global = %inserted : tensor<i64>
+  return %1 : i64
+}
+
+// -----
+
+// expected-error @below {{unsupported global op type}}
+ml_program.global private mutable @global(0 : i64) : i64
+
+func.func @global_scalar() -> i64 {
+  %c127_i64 = arith.constant 127 : i64
+  %0 = ml_program.global_load @global : i64
+  %1 = arith.muli %0, %c127_i64 : i64
+  ml_program.global_store @global = %1 : i64
+  return %1 : i64
+}
+
+// -----
+
+// expected-error @below {{unsupported global op type}}
+ml_program.global private mutable @global(dense<0> : memref<i64>) : memref<i64>
+
+func.func @global_memref() -> i64 {
+  %c127_i64 = arith.constant 127 : i64
+  %0 = ml_program.global_load @global : memref<i64>
+  %extracted = memref.load %0[] : memref<i64>
+  %1 = arith.muli %extracted, %c127_i64 : i64
+  memref.store %1, %0[] : memref<i64>
+  ml_program.global_store @global = %0 : memref<i64>
+  return %1 : i64
+}
+
+// -----
+
+// expected-error @below {{invalid tensor element type}}
+ml_program.global private mutable @global(dense<0> : tensor<memref<i64>>) : tensor<memref<i64>>
+
+func.func @global_tensor_of_memref() -> i64 {
+  %c127_i64 = arith.constant 127 : i64
+  return %c127_i64 : i64
+}
+
+// -----
+
+// expected-error @below {{unimplemented: global op bufferization with dynamic shape}}
+ml_program.global private mutable @global(dense<0> : tensor<1xi64>) : tensor<?xi64>
+
+func.func @global_dynamic_shape() -> i64 {
+  %c127_i64 = arith.constant 127 : i64
+  %c0 = arith.constant 0 : index
+  %0 = ml_program.global_load @global : tensor<?xi64>
+  %extracted = tensor.extract %0[%c0] : tensor<?xi64>
+  %1 = arith.muli %extracted, %c127_i64 : i64
+  %inserted = tensor.insert %1 into %0[%c0] : tensor<?xi64>
+  ml_program.global_store @global = %inserted : tensor<?xi64>
+  return %1 : i64
+}

@matthias-springer
Copy link
Member

Instead of a separate pass, can you implement the `BufferizableOpInterface` for these three ops? Then it would compose better with the One-Shot Bufferize infrastructure. We keep those interface implementations in `<Dialect>/Transforms/BufferizableOpInterfaceImpl.cpp` files.

The implementation would be similar to flow.dispatch.tensor.load/store in IREE (https://github.com/openxla/iree/blob/main/compiler/src/iree/compiler/Codegen/Interfaces/BufferizationInterfaces.cpp#L164). I think you also won't need to create a copy when loading a tensor. The implementation of GlobalLoadOp::isWritable can return false, so that every write would trigger a copy during One-Shot Bufferize.
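For comparison, under the interface-based approach a load would not need the eager copy. A hypothetical sketch of the bufferized form (not IR from this PR) could be:

```mlir
// With GlobalLoadOp::isWritable returning false, One-Shot Bufferize can
// read the global's buffer directly; a copy is inserted by the analysis
// only when a subsequent write would conflict.
%g = memref.get_global @global : memref<i64>
%t = bufferization.to_tensor %g : memref<i64>
```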

@ryanpholt (Contributor, Author)

@matthias-springer I agree that implementing the interface would be much better. I had tried that, but it did not work because One-Shot Bufferize would not visit module-level ops such as `ml_program.global` that appear outside of a function. Bufferizing `ml_program.global_load` and `ml_program.global_store` worked fine, but `ml_program.global` was skipped. Any thoughts on how I could get that to work?

@ryanpholt (Contributor, Author)

@stellaraccident

@matthias-springer (Member)

One-Shot Bufferize skips over ops that do not have tensor semantics. Until now, we defined "tensor semantics" as "has a tensor operand and/or result" (with some special rules for functions). `ml_program.global` does not have any operands or results, so it was skipped.

#75273 should fix the issue: I added a hasTensorSemantics interface method to BufferizableOpInterface. The default implementation behaves the same as before. But you can override the behavior for ml_program.global and then it should be picked up by the bufferization.
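The op in question illustrates this. Taken from the test file in this PR, it carries its tensor type only in attributes, with no SSA operands or results for the old check to see:

```mlir
// No operands, no results: the tensor<i64> appears only in the type
// and initial-value attribute, so the operand/result-based
// tensor-semantics check skipped this op.
ml_program.global private mutable @global(dense<0> : tensor<i64>) : tensor<i64>
```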

@stellaraccident (Contributor) left a comment

This looks good to me. Thank you.

I am not an expert on bufferization, but it looks like Matthias' review was good on that.

Do you need me to merge this?

@ryanpholt (Contributor, Author)

Thanks for the review @stellaraccident. I'm currently waiting on #75273 to land so that I can implement this in terms of the BufferizableOpInterface instead of a separate pass as suggested by Matthias.

@stellaraccident stellaraccident self-requested a review December 19, 2023 16:53
@stellaraccident (Contributor)

Ok, cool. Feel free to ask for a re-review, or just have Matthias look.

@matthias-springer (Member)

> Thanks for the review @stellaraccident. I'm currently waiting on #75273 to land so that I can implement this in terms of the BufferizableOpInterface instead of a separate pass as suggested by Matthias.

#75273 is ready to land from my side. Would be great if somebody can review it.

matthias-springer added a commit to matthias-springer/llvm-project that referenced this pull request Jan 12, 2024
Add a new interface method to `BufferizableOpInterface`: `hasTensorSemanticsForBufferization`. This method returns "true" if the op has tensor semantics and should be bufferized.

Until now, we assumed that an op has tensor semantics if it has tensor operands and/or tensor op results. However, there are ops like `ml_program.global` that do not have any results/operands but must still be bufferized (llvm#75103). The new interface method can return "true" for such ops.

This change also decouples `bufferization::bufferizeOp` a bit from the func dialect.
matthias-springer added a commit to matthias-springer/llvm-project that referenced this pull request Jan 16, 2024
matthias-springer added a commit that referenced this pull request Jan 16, 2024
…s` (#75273)

Add a new interface method to `BufferizableOpInterface`:
`hasTensorSemantics`. This method returns "true" if the op has tensor
semantics and should be bufferized.

Until now, we assumed that an op has tensor semantics if it has tensor
operands and/or tensor op results. However, there are ops like
`ml_program.global` that do not have any results/operands but must still
be bufferized (#75103). The new interface method can return "true" for
such ops.

This change also decouples `bufferization::bufferizeOp` a bit from the
func dialect.
@matthias-springer (Member)

Sorry that it took so long. #75273 has been merged.

justinfargnoli pushed a commit to justinfargnoli/llvm-project that referenced this pull request Jan 28, 2024
…s` (llvm#75273)

Add a new interface method to `BufferizableOpInterface`:
`hasTensorSemantics`. This method returns "true" if the op has tensor
semantics and should be bufferized.

Until now, we assumed that an op has tensor semantics if it has tensor
operands and/or tensor op results. However, there are ops like
`ml_program.global` that do not have any results/operands but must still
be bufferized (llvm#75103). The new interface method can return "true" for
such ops.

This change also decouples `bufferization::bufferizeOp` a bit from the
func dialect.
@ryanpholt ryanpholt force-pushed the mlprogram-bufferize branch from 3c58380 to e5461d8 Compare January 29, 2024 07:39
@llvmbot llvmbot added the mlir:bufferization Bufferization infrastructure label Jan 29, 2024
@ryanpholt ryanpholt force-pushed the mlprogram-bufferize branch 2 times, most recently from b69999a to 157daec Compare January 29, 2024 07:57
@ryanpholt (Contributor, Author)

Thanks @matthias-springer! I've updated the PR. I had to make a couple of small changes to bufferization to account for the fact that a module-level op (like `ml_program.global`) may be deleted. It works now, though I'm not entirely sure I've implemented `bufferizesToMemoryRead` and `bufferizesToMemoryWrite` correctly.

github-actions bot commented Jan 29, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@ryanpholt ryanpholt force-pushed the mlprogram-bufferize branch 2 times, most recently from 6bf6435 to 4146ffb Compare January 29, 2024 08:02
This commit implements the `BufferizableOpInterface` for
`ml_program.global`, `ml_program.global_load` and
`ml_program.global_store` so that these ops can be lowered all the way
to LLVM.
@ryanpholt ryanpholt force-pushed the mlprogram-bufferize branch from 9eced81 to 73feff8 Compare January 30, 2024 13:09
@ryanpholt (Contributor, Author)

Do you mind merging this for me, @matthias-springer? I don't have commit access.

@matthias-springer matthias-springer merged commit fa10121 into llvm:main Jan 30, 2024
@joker-eph (Collaborator)

CI is broken: https://lab.llvm.org/buildbot/#/builders/61/builds/53723; I have to revert.

@matthias-springer (Member)

Looks like a missing MLIRBufferizationDialect dependency in mlir/lib/Dialect/MLProgram/Transforms/CMakeLists.txt. If you haven't reverted yet, I can try to fix it now.

@joker-eph (Collaborator)

Actually, I apparently can't from my phone; if someone can click the buttons for me, that'd be nice!

@joker-eph (Collaborator)

(Or feel free to fix forward @matthias-springer )

matthias-springer added a commit that referenced this pull request Jan 30, 2024
After #75103, `MLProgramTransforms` depends on `BufferizationDialect`.
Also fix an unrelated compile error in `GreedyPatternRewriteDriver.cpp`.
(This was not failing on CI. I may be running an old compiler locally.)
@ryanpholt (Contributor, Author)

Oops, thanks Matthias.
