Skip to content

[TOSA] Add TosaToMLProgram conversion #69787

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 10, 2023
Merged

[TOSA] Add TosaToMLProgram conversion #69787

merged 1 commit into from
Nov 10, 2023

Conversation

Jerry-Ge
Copy link
Member

@Jerry-Ge Jerry-Ge commented Oct 20, 2023

This patch adds a new pass to lower TOSA StatefulOps to corresponding ML Program ops (https://mlir.llvm.org/docs/Dialects/MLProgramOps/).

@llvmbot llvmbot added the mlir label Oct 20, 2023
@llvmbot
Copy link
Member

llvmbot commented Oct 20, 2023

@llvm/pr-subscribers-mlir

Author: Jerry-Ge (Jerry-Ge)

Changes

This patch adds a new pass to lower TOSA StatefulOps to corresponding ML Program ops (https://mlir.llvm.org/docs/Dialects/MLProgramOps/).

This is part of the TOSA statefulOps effort:

The currently lowering is:

  • tfl->ml_program

After finishing the tfl->tosa.stateful and tosa.stateful->ml_program.stateful .

  • The tfl->ml_program pass will be removed.

  • Eventually it will be tfl->tosa->ml_program


Full diff: https://github.com/llvm/llvm-project/pull/69787.diff

9 Files Affected:

  • (modified) mlir/include/mlir/Conversion/Passes.h (+1)
  • (modified) mlir/include/mlir/Conversion/Passes.td (+15-2)
  • (added) mlir/include/mlir/Conversion/TosaToMLProgram/TosaToMLProgram.h (+32)
  • (modified) mlir/lib/Conversion/CMakeLists.txt (+1)
  • (added) mlir/lib/Conversion/TosaToMLProgram/CMakeLists.txt (+19)
  • (added) mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp (+77)
  • (added) mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgramPass.cpp (+50)
  • (added) mlir/test/Conversion/TosaToMLProgram/tosa-to-mlprogram.mlir (+16)
  • (modified) utils/bazel/llvm-project-overlay/mlir/BUILD.bazel (+25)
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index e714f5070f23db8..637b69fc3f157b9 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -59,6 +59,7 @@
 #include "mlir/Conversion/TensorToSPIRV/TensorToSPIRVPass.h"
 #include "mlir/Conversion/TosaToArith/TosaToArith.h"
 #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
+#include "mlir/Conversion/TosaToMLProgram/TosaToMLProgram.h"
 #include "mlir/Conversion/TosaToSCF/TosaToSCF.h"
 #include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
 #include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index a269fb4a83af41f..6bdd9223329a8b4 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -388,8 +388,8 @@ def ConvertFuncToLLVMPass : Pass<"convert-func-to-llvm", "ModuleOp"> {
     already present in the IR will be kept as is.
 
     An LLVM datalayout string can be attached as an attribute to the module on
-    which the pass anchors. Such an attribute is attached by calling the 
-    set-module-datalayout pass. If present, an llvm::DataLayout object is 
+    which the pass anchors. Such an attribute is attached by calling the
+    set-module-datalayout pass. If present, an llvm::DataLayout object is
     created from this attribute and used in the conversion to LLVM.
 
     #### Output IR
@@ -1107,6 +1107,19 @@ def TosaToLinalgNamed
   let constructor = "tosa::createTosaToLinalgNamed()";
 }
 
+//===----------------------------------------------------------------------===//
+// TosaToMLProgram
+//===----------------------------------------------------------------------===//
+
+def TosaToMLProgram : Pass<"tosa-to-mlprogram", "ModuleOp"> {
+  let summary = "Lower TOSA to the ml_program dialect";
+  let dependentDialects = ["ml_program::MLProgramDialect"];
+  let description = [{
+    Pass that converts TOSA's variable operator operations to the equivalent
+    ml_program operations.
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // TosaToSCF
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Conversion/TosaToMLProgram/TosaToMLProgram.h b/mlir/include/mlir/Conversion/TosaToMLProgram/TosaToMLProgram.h
new file mode 100644
index 000000000000000..4c8f87e14797b66
--- /dev/null
+++ b/mlir/include/mlir/Conversion/TosaToMLProgram/TosaToMLProgram.h
@@ -0,0 +1,32 @@
+//===-- TosaToMLProgram.h - TOSA to ML_Program dialect lowerings -------------*-
+//C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the passes for the TOSA to the SCF Dialect conversion.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_TOSATOMLPROGRAM_TOSATOMLPROGRAM_H
+#define MLIR_CONVERSION_TOSATOMLPROGRAM_TOSATOMLPROGRAM_H
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+
+#define GEN_PASS_DECL_TOSATOMLPROGRAM
+#include "mlir/Conversion/Passes.h.inc"
+
+namespace tosa {
+
+void populateTosaToMLProgramConversionPatterns(RewritePatternSet *patterns);
+
+} // namespace tosa
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_TOSATOMLPROGRAM_TOSATOMLPROGRAM_H
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 35790254be137be..664804f0453509f 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -49,6 +49,7 @@ add_subdirectory(TensorToLinalg)
 add_subdirectory(TensorToSPIRV)
 add_subdirectory(TosaToArith)
 add_subdirectory(TosaToLinalg)
+add_subdirectory(TosaToMLProgram)
 add_subdirectory(TosaToSCF)
 add_subdirectory(TosaToTensor)
 add_subdirectory(UBToLLVM)
diff --git a/mlir/lib/Conversion/TosaToMLProgram/CMakeLists.txt b/mlir/lib/Conversion/TosaToMLProgram/CMakeLists.txt
new file mode 100644
index 000000000000000..82941424f1d1025
--- /dev/null
+++ b/mlir/lib/Conversion/TosaToMLProgram/CMakeLists.txt
@@ -0,0 +1,19 @@
+add_mlir_conversion_library(MLIRTosaToMLProgram
+  TosaToMLProgram.cpp
+  TosaToMLProgramPass.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/IR
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRMLProgramDialect
+  MLIRPass
+  MLIRTosaDialect
+  MLIRTosaTransforms
+  MLIRSupport
+  )
diff --git a/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp b/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp
new file mode 100644
index 000000000000000..4fe83fe9d81fcae
--- /dev/null
+++ b/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp
@@ -0,0 +1,77 @@
+//===- TosaToMLProgram.cpp - Lowering Tosa to MLProgram Dialect
+//-----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// These rewriters lower from the TOSA dialect to the MLProgram dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/TosaToMLProgram/TosaToMLProgram.h"
+#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+using namespace tosa;
+namespace {
+
+class VariableOpConverter : public OpRewritePattern<tosa::VariableOp> {
+public:
+  using OpRewritePattern<tosa::VariableOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::VariableOp op,
+                                PatternRewriter &rewriter) const final {
+    auto newVariable = rewriter.create<mlir::ml_program::GlobalOp>(
+        op.getLoc(), op.getName(), op.getType(), /*is_mutable=*/true,
+        op.getInitialValueAttr(), nullptr);
+    newVariable.setPrivate();
+    rewriter.replaceOp(op, newVariable);
+    return success();
+  }
+};
+
+class VariableWriteOpConverter
+    : public OpRewritePattern<tosa::VariableWriteOp> {
+public:
+  using OpRewritePattern<tosa::VariableWriteOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::VariableWriteOp op,
+                                PatternRewriter &rewriter) const final {
+    auto globalSymbolRef =
+        SymbolRefAttr::get(rewriter.getContext(), op.getName());
+    auto newVariableWrite = rewriter.create<ml_program::GlobalStoreOp>(
+        op.getLoc(), globalSymbolRef, op.getValue());
+    rewriter.replaceOp(op, newVariableWrite);
+    return success();
+  }
+};
+
+class VariableReadOpConverter : public OpRewritePattern<tosa::VariableReadOp> {
+public:
+  using OpRewritePattern<tosa::VariableReadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::VariableReadOp op,
+                                PatternRewriter &rewriter) const final {
+    auto globalSymbolRef =
+        SymbolRefAttr::get(rewriter.getContext(), op.getName());
+    auto newVariableRead = rewriter.create<ml_program::GlobalLoadOp>(
+        op.getLoc(), op.getType(), globalSymbolRef);
+    rewriter.replaceOp(op, newVariableRead);
+
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::tosa::populateTosaToMLProgramConversionPatterns(
+    RewritePatternSet *patterns) {
+  patterns->add<VariableOpConverter, VariableWriteOpConverter,
+                VariableReadOpConverter>(patterns->getContext());
+}
\ No newline at end of file
diff --git a/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgramPass.cpp b/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgramPass.cpp
new file mode 100644
index 000000000000000..1b67decdabd35be
--- /dev/null
+++ b/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgramPass.cpp
@@ -0,0 +1,50 @@
+//===- TosaToMLProgramPass.cpp - Lowering Tosa to MLProgram Dialect
+//-------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This transformation pass legalizes the TOSA dialect to the MLProgram dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/TosaToMLProgram/TosaToMLProgram.h"
+#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_TOSATOMLPROGRAM
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace tosa;
+
+namespace {
+struct TosaToMLProgram : public impl::TosaToMLProgramBase<TosaToMLProgram> {
+public:
+  void runOnOperation() override {
+    auto *context = &getContext();
+    auto moduleOp = getOperation();
+
+    RewritePatternSet patterns(context);
+    ConversionTarget target(*context);
+    target.addLegalDialect<ml_program::MLProgramDialect>();
+    target.addIllegalOp<tosa::VariableOp, tosa::VariableReadOp,
+                        tosa::VariableWriteOp>();
+    target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
+
+    mlir::tosa::populateTosaToMLProgramConversionPatterns(&patterns);
+
+    if (failed(applyPartialConversion(moduleOp, target, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
diff --git a/mlir/test/Conversion/TosaToMLProgram/tosa-to-mlprogram.mlir b/mlir/test/Conversion/TosaToMLProgram/tosa-to-mlprogram.mlir
new file mode 100644
index 000000000000000..b76358515d9aa79
--- /dev/null
+++ b/mlir/test/Conversion/TosaToMLProgram/tosa-to-mlprogram.mlir
@@ -0,0 +1,16 @@
+// RUN: mlir-opt --split-input-file --tosa-to-mlprogram %s -verify-diagnostics -o -| FileCheck %s
+
+
+// -----
+
+module {
+  // CHECK: ml_program.global private mutable @var_x(dense<7.000000e+00> : tensor<1xf32>) : tensor<1xf32>
+  tosa.variable @var_x = dense<7.000000e+00> : tensor<1xf32>
+  func.func @test_stateful_ops(%arg0: tensor<1xf32>) -> (tensor<1xf32>) {
+    // CHECK: ml_program.global_store @var_x = %arg0 : tensor<1xf32>
+    tosa.variable.write @var_x, %arg0 : tensor<1xf32>
+    // CHECK: %[[LOAD:.+]] = ml_program.global_load @var_x : tensor<1xf32>
+    %0 = tosa.variable.read @var_x : tensor<1xf32>
+    return %0 : tensor<1xf32>
+  }
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index eb670ad50163c38..3b2b9f2660164c7 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -3730,6 +3730,7 @@ cc_library(
         ":TensorToSPIRV",
         ":TosaToArith",
         ":TosaToLinalg",
+        ":TosaToMLProgram",
         ":TosaToSCF",
         ":TosaToTensor",
         ":UBToLLVM",
@@ -11054,6 +11055,30 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "TosaToMLProgram",
+    srcs = glob([
+        "lib/Conversion/TosaToMLProgram/*.cpp",
+        "lib/Conversion/TosaToMLProgram/*.h",
+    ]),
+    hdrs = glob([
+        "include/mlir/Conversion/TosaToMLProgram/*.h",
+    ]),
+    includes = [
+        "include",
+        "lib/Conversion/TosaToMLProgram",
+    ],
+    deps = [
+        ":ConversionPassIncGen",
+        ":FuncDialect",
+        ":IR",
+        ":Pass",
+        ":MLProgramDialect",
+        ":TosaDialect",
+        ":Transforms",
+    ],
+)
+
 cc_library(
     name = "TosaToSCF",
     srcs = glob([

@github-actions
Copy link

github-actions bot commented Oct 20, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

mlir::tosa::populateTosaToMLProgramConversionPatterns(&patterns);

if (failed(applyPartialConversion(moduleOp, target, std::move(patterns))))
signalPassFailure();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should log an error via emitMatchFailure

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry, I couldn't find anything about emitMatchFailure [Link]. When I tried to build it, it also fails. Other passes are also using signalPassFailure()

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually don't think this one can fail ... I think it would only fail if this didn't converge after 10 iterations given partial conversion. But fine as is.

(probably meant notifyMatchFailure)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

actually copied from TosaToLinag here: https://github.com/llvm/llvm-project/blob/d199fd76f7b76d902d8ef210d82689f299934793/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp#L70C3-L70C3.

And all other TosaTo* passes are checking this if condition and do signalPassFailure
If notifyMatchFailure is the correct one to use here, I suggest I can do another patch to update all the TosaTo* passes?

Copy link
Member

@jpienaar jpienaar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall looks good, just a tweak here and there. I think you could remove the TFL parts in the description here as this is standalone (that being said you probably want to then check the bazel side given you want it to be usable TF side)

LogicalResult matchAndRewrite(tosa::VariableOp op,
PatternRewriter &rewriter) const final {
auto newVariable = rewriter.create<mlir::ml_program::GlobalOp>(
op.getLoc(), op.getName(), op.getType(), /*is_mutable=*/true,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why always mutable?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TOSA::Variable doesn't have the mutable attribute and assumes the variable is mutable by default. https://www.mlplatform.org/tosa/tosa_spec.html#_variable

RewritePatternSet patterns(context);
ConversionTarget target(*context);
target.addLegalDialect<ml_program::MLProgramDialect>();
target.addIllegalOp<tosa::VariableOp, tosa::VariableReadOp,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think only this setting may be needed and you could remove the other legal and dynamically legal settings as you are.using a partial conversion.

Copy link
Member Author

@Jerry-Ge Jerry-Ge Oct 24, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Back from some experiments, seems we can not remove both target.addLegalDialect<ml_program::MLProgramDialect>(); and target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });

Error:

staging/llvm-project/mlir/test/Conversion/TosaToMLProgram/tosa-to-mlprogram.mlir:4 offset :5:3: error: failed to legalize operation 'tosa.variable' that was explicitly marked illegal

Am I missing something here?

I propose we can only keep the target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll check here, there should be something simpler. But not blocking.

@Jerry-Ge
Copy link
Member Author

cc @eric-k256

Copy link
Member

@jpienaar jpienaar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LG, thanks!

let dependentDialects = ["ml_program::MLProgramDialect"];
let description = [{
Pass that converts TOSA's variable operator operations to the equivalent
ml_program operations.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MLProgram (At least I think that's how the dialect doc writes it)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ack, done.

@@ -0,0 +1,31 @@
//===-- TosaToMLProgram.h - TOSA to ML_Program dialect lowerings-*- C++ -*-===//
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here too

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ack, done.

namespace mlir {

#define GEN_PASS_DECL_TOSATOMLPROGRAM
#include "mlir/Conversion/Passes.h.inc"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this needed here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not needed. verified. removed it.

mlir::tosa::populateTosaToMLProgramConversionPatterns(&patterns);

if (failed(applyPartialConversion(moduleOp, target, std::move(patterns))))
signalPassFailure();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually don't think this one can fail ... I think it would only fail if this didn't converge after 10 iterations given partial conversion. But fine as is.

(probably meant notifyMatchFailure)

@@ -0,0 +1,16 @@
// RUN: mlir-opt --split-input-file --tosa-to-mlprogram %s -verify-diagnostics -o -| FileCheck %s
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lets drop --split-input-file -verify-diagnostics and line 4 here, will result in nicer links to failed CHECKs and there are no errors being verified here.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done.

The TOSA variable ops and ml_program ops offer similar functionality.
The tosa-to-mlprogram pass defines legalizations from the TOSA op to
the MLProgram equivalent op.

tosa.variable maps to ml_program.global.
tosa.variable_read maps to ml_program.global_load
tosa.varaible_write maps to ml_program.global_store

Signed-off-by: Jerry Ge <[email protected]>
@jpienaar jpienaar merged commit 6e6352f into llvm:main Nov 10, 2023
zahiraam pushed a commit to zahiraam/llvm-project that referenced this pull request Nov 20, 2023
This patch adds a new pass to lower TOSA StatefulOps to corresponding ML
Program ops (https://mlir.llvm.org/docs/Dialects/MLProgramOps/).

Signed-off-by: Jerry Ge <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants