-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[TOSA] Add TosaToMLProgram conversion #69787
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir Author: Jerry-Ge (Jerry-Ge) ChangesThis patch adds a new pass to lower TOSA StatefulOps to corresponding ML Program ops (https://mlir.llvm.org/docs/Dialects/MLProgramOps/). This is part of the TOSA statefulOps effort:The currently lowering is:
After finishing the tfl->tosa.stateful and tosa.stateful->ml_program.stateful .
Full diff: https://github.com/llvm/llvm-project/pull/69787.diff 9 Files Affected:
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index e714f5070f23db8..637b69fc3f157b9 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -59,6 +59,7 @@
#include "mlir/Conversion/TensorToSPIRV/TensorToSPIRVPass.h"
#include "mlir/Conversion/TosaToArith/TosaToArith.h"
#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
+#include "mlir/Conversion/TosaToMLProgram/TosaToMLProgram.h"
#include "mlir/Conversion/TosaToSCF/TosaToSCF.h"
#include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index a269fb4a83af41f..6bdd9223329a8b4 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -388,8 +388,8 @@ def ConvertFuncToLLVMPass : Pass<"convert-func-to-llvm", "ModuleOp"> {
already present in the IR will be kept as is.
An LLVM datalayout string can be attached as an attribute to the module on
- which the pass anchors. Such an attribute is attached by calling the
- set-module-datalayout pass. If present, an llvm::DataLayout object is
+ which the pass anchors. Such an attribute is attached by calling the
+ set-module-datalayout pass. If present, an llvm::DataLayout object is
created from this attribute and used in the conversion to LLVM.
#### Output IR
@@ -1107,6 +1107,19 @@ def TosaToLinalgNamed
let constructor = "tosa::createTosaToLinalgNamed()";
}
+//===----------------------------------------------------------------------===//
+// TosaToMLProgram
+//===----------------------------------------------------------------------===//
+
+def TosaToMLProgram : Pass<"tosa-to-mlprogram", "ModuleOp"> {
+ let summary = "Lower TOSA to the ml_program dialect";
+ let dependentDialects = ["ml_program::MLProgramDialect"];
+ let description = [{
+ Pass that converts TOSA's variable operator operations to the equivalent
+ ml_program operations.
+ }];
+}
+
//===----------------------------------------------------------------------===//
// TosaToSCF
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Conversion/TosaToMLProgram/TosaToMLProgram.h b/mlir/include/mlir/Conversion/TosaToMLProgram/TosaToMLProgram.h
new file mode 100644
index 000000000000000..4c8f87e14797b66
--- /dev/null
+++ b/mlir/include/mlir/Conversion/TosaToMLProgram/TosaToMLProgram.h
@@ -0,0 +1,32 @@
+//===-- TosaToMLProgram.h - TOSA to ML_Program dialect lowerings -------------*-
+//C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the passes for the TOSA to the SCF Dialect conversion.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_TOSATOMLPROGRAM_TOSATOMLPROGRAM_H
+#define MLIR_CONVERSION_TOSATOMLPROGRAM_TOSATOMLPROGRAM_H
+
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+
+#define GEN_PASS_DECL_TOSATOMLPROGRAM
+#include "mlir/Conversion/Passes.h.inc"
+
+namespace tosa {
+
+void populateTosaToMLProgramConversionPatterns(RewritePatternSet *patterns);
+
+} // namespace tosa
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_TOSATOMLPROGRAM_TOSATOMLPROGRAM_H
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 35790254be137be..664804f0453509f 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -49,6 +49,7 @@ add_subdirectory(TensorToLinalg)
add_subdirectory(TensorToSPIRV)
add_subdirectory(TosaToArith)
add_subdirectory(TosaToLinalg)
+add_subdirectory(TosaToMLProgram)
add_subdirectory(TosaToSCF)
add_subdirectory(TosaToTensor)
add_subdirectory(UBToLLVM)
diff --git a/mlir/lib/Conversion/TosaToMLProgram/CMakeLists.txt b/mlir/lib/Conversion/TosaToMLProgram/CMakeLists.txt
new file mode 100644
index 000000000000000..82941424f1d1025
--- /dev/null
+++ b/mlir/lib/Conversion/TosaToMLProgram/CMakeLists.txt
@@ -0,0 +1,19 @@
+add_mlir_conversion_library(MLIRTosaToMLProgram
+ TosaToMLProgram.cpp
+ TosaToMLProgramPass.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/IR
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRMLProgramDialect
+ MLIRPass
+ MLIRTosaDialect
+ MLIRTosaTransforms
+ MLIRSupport
+ )
diff --git a/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp b/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp
new file mode 100644
index 000000000000000..4fe83fe9d81fcae
--- /dev/null
+++ b/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp
@@ -0,0 +1,77 @@
+//===- TosaToMLProgram.cpp - Lowering Tosa to MLProgram Dialect
+//-----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// These rewriters lower from the TOSA dialect to the MLProgram dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/TosaToMLProgram/TosaToMLProgram.h"
+#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+using namespace tosa;
+namespace {
+
+class VariableOpConverter : public OpRewritePattern<tosa::VariableOp> {
+public:
+ using OpRewritePattern<tosa::VariableOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tosa::VariableOp op,
+ PatternRewriter &rewriter) const final {
+ auto newVariable = rewriter.create<mlir::ml_program::GlobalOp>(
+ op.getLoc(), op.getName(), op.getType(), /*is_mutable=*/true,
+ op.getInitialValueAttr(), nullptr);
+ newVariable.setPrivate();
+ rewriter.replaceOp(op, newVariable);
+ return success();
+ }
+};
+
+class VariableWriteOpConverter
+ : public OpRewritePattern<tosa::VariableWriteOp> {
+public:
+ using OpRewritePattern<tosa::VariableWriteOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tosa::VariableWriteOp op,
+ PatternRewriter &rewriter) const final {
+ auto globalSymbolRef =
+ SymbolRefAttr::get(rewriter.getContext(), op.getName());
+ auto newVariableWrite = rewriter.create<ml_program::GlobalStoreOp>(
+ op.getLoc(), globalSymbolRef, op.getValue());
+ rewriter.replaceOp(op, newVariableWrite);
+ return success();
+ }
+};
+
+class VariableReadOpConverter : public OpRewritePattern<tosa::VariableReadOp> {
+public:
+ using OpRewritePattern<tosa::VariableReadOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tosa::VariableReadOp op,
+ PatternRewriter &rewriter) const final {
+ auto globalSymbolRef =
+ SymbolRefAttr::get(rewriter.getContext(), op.getName());
+ auto newVariableRead = rewriter.create<ml_program::GlobalLoadOp>(
+ op.getLoc(), op.getType(), globalSymbolRef);
+ rewriter.replaceOp(op, newVariableRead);
+
+ return success();
+ }
+};
+
+} // namespace
+
+void mlir::tosa::populateTosaToMLProgramConversionPatterns(
+ RewritePatternSet *patterns) {
+ patterns->add<VariableOpConverter, VariableWriteOpConverter,
+ VariableReadOpConverter>(patterns->getContext());
+}
\ No newline at end of file
diff --git a/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgramPass.cpp b/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgramPass.cpp
new file mode 100644
index 000000000000000..1b67decdabd35be
--- /dev/null
+++ b/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgramPass.cpp
@@ -0,0 +1,50 @@
+//===- TosaToMLProgramPass.cpp - Lowering Tosa to MLProgram Dialect
+//-------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This transformation pass legalizes the TOSA dialect to the MLProgram dialect.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/TosaToMLProgram/TosaToMLProgram.h"
+#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_TOSATOMLPROGRAM
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace tosa;
+
+namespace {
+struct TosaToMLProgram : public impl::TosaToMLProgramBase<TosaToMLProgram> {
+public:
+ void runOnOperation() override {
+ auto *context = &getContext();
+ auto moduleOp = getOperation();
+
+ RewritePatternSet patterns(context);
+ ConversionTarget target(*context);
+ target.addLegalDialect<ml_program::MLProgramDialect>();
+ target.addIllegalOp<tosa::VariableOp, tosa::VariableReadOp,
+ tosa::VariableWriteOp>();
+ target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
+
+ mlir::tosa::populateTosaToMLProgramConversionPatterns(&patterns);
+
+ if (failed(applyPartialConversion(moduleOp, target, std::move(patterns))))
+ signalPassFailure();
+ }
+};
+} // namespace
diff --git a/mlir/test/Conversion/TosaToMLProgram/tosa-to-mlprogram.mlir b/mlir/test/Conversion/TosaToMLProgram/tosa-to-mlprogram.mlir
new file mode 100644
index 000000000000000..b76358515d9aa79
--- /dev/null
+++ b/mlir/test/Conversion/TosaToMLProgram/tosa-to-mlprogram.mlir
@@ -0,0 +1,16 @@
+// RUN: mlir-opt --split-input-file --tosa-to-mlprogram %s -verify-diagnostics -o -| FileCheck %s
+
+
+// -----
+
+module {
+ // CHECK: ml_program.global private mutable @var_x(dense<7.000000e+00> : tensor<1xf32>) : tensor<1xf32>
+ tosa.variable @var_x = dense<7.000000e+00> : tensor<1xf32>
+ func.func @test_stateful_ops(%arg0: tensor<1xf32>) -> (tensor<1xf32>) {
+ // CHECK: ml_program.global_store @var_x = %arg0 : tensor<1xf32>
+ tosa.variable.write @var_x, %arg0 : tensor<1xf32>
+ // CHECK: %[[LOAD:.+]] = ml_program.global_load @var_x : tensor<1xf32>
+ %0 = tosa.variable.read @var_x : tensor<1xf32>
+ return %0 : tensor<1xf32>
+ }
+}
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index eb670ad50163c38..3b2b9f2660164c7 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -3730,6 +3730,7 @@ cc_library(
":TensorToSPIRV",
":TosaToArith",
":TosaToLinalg",
+ ":TosaToMLProgram",
":TosaToSCF",
":TosaToTensor",
":UBToLLVM",
@@ -11054,6 +11055,30 @@ cc_library(
],
)
+cc_library(
+ name = "TosaToMLProgram",
+ srcs = glob([
+ "lib/Conversion/TosaToMLProgram/*.cpp",
+ "lib/Conversion/TosaToMLProgram/*.h",
+ ]),
+ hdrs = glob([
+ "include/mlir/Conversion/TosaToMLProgram/*.h",
+ ]),
+ includes = [
+ "include",
+ "lib/Conversion/TosaToMLProgram",
+ ],
+ deps = [
+ ":ConversionPassIncGen",
+ ":FuncDialect",
+ ":IR",
+ ":Pass",
+ ":MLProgramDialect",
+ ":TosaDialect",
+ ":Transforms",
+ ],
+)
+
cc_library(
name = "TosaToSCF",
srcs = glob([
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
mlir::tosa::populateTosaToMLProgramConversionPatterns(&patterns); | ||
|
||
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) | ||
signalPassFailure(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You should log an error via emitMatchFailure
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sorry, I couldn't find anything about emitMatchFailure
[Link]. When I tried to build it, it also fails. Other passes are also using signalPassFailure()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I actually don't think this one can fail ... I think it would only fail if this didn't converge after 10 iterations given partial conversion. But fine as is.
(probably meant notifyMatchFailure)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
actually copied from TosaToLinag here: https://github.com/llvm/llvm-project/blob/d199fd76f7b76d902d8ef210d82689f299934793/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp#L70C3-L70C3.
And all other TosaTo* passes are checking this if condition and do signalPassFailure
If notifyMatchFailure
is the correct one to use here, I suggest I can do another patch to update all the TosaTo* passes?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall looks good, just a tweak here and there. I think you could remove the TFL parts in the description here as this is standalone (that being said you probably want to then check the bazel side given you want it to be usable TF side)
LogicalResult matchAndRewrite(tosa::VariableOp op, | ||
PatternRewriter &rewriter) const final { | ||
auto newVariable = rewriter.create<mlir::ml_program::GlobalOp>( | ||
op.getLoc(), op.getName(), op.getType(), /*is_mutable=*/true, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why always mutable?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TOSA::Variable
doesn't have the mutable
attribute and assumes the variable is mutable by default. https://www.mlplatform.org/tosa/tosa_spec.html#_variable
RewritePatternSet patterns(context); | ||
ConversionTarget target(*context); | ||
target.addLegalDialect<ml_program::MLProgramDialect>(); | ||
target.addIllegalOp<tosa::VariableOp, tosa::VariableReadOp, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think only this setting may be needed and you could remove the other legal and dynamically legal settings as you are.using a partial conversion.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Back from some experiments, seems we can not remove both target.addLegalDialect<ml_program::MLProgramDialect>();
and target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
Error:
staging/llvm-project/mlir/test/Conversion/TosaToMLProgram/tosa-to-mlprogram.mlir:4 offset :5:3: error: failed to legalize operation 'tosa.variable' that was explicitly marked illegal
Am I missing something here?
I propose we can only keep the target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll check here, there should be something simpler. But not blocking.
cc @eric-k256 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LG, thanks!
let dependentDialects = ["ml_program::MLProgramDialect"]; | ||
let description = [{ | ||
Pass that converts TOSA's variable operator operations to the equivalent | ||
ml_program operations. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
MLProgram (At least I think that's how the dialect doc writes it)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ack, done.
@@ -0,0 +1,31 @@ | |||
//===-- TosaToMLProgram.h - TOSA to ML_Program dialect lowerings-*- C++ -*-===// |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here too
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ack, done.
namespace mlir { | ||
|
||
#define GEN_PASS_DECL_TOSATOMLPROGRAM | ||
#include "mlir/Conversion/Passes.h.inc" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this needed here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not needed. verified. removed it.
mlir::tosa::populateTosaToMLProgramConversionPatterns(&patterns); | ||
|
||
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) | ||
signalPassFailure(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I actually don't think this one can fail ... I think it would only fail if this didn't converge after 10 iterations given partial conversion. But fine as is.
(probably meant notifyMatchFailure)
@@ -0,0 +1,16 @@ | |||
// RUN: mlir-opt --split-input-file --tosa-to-mlprogram %s -verify-diagnostics -o -| FileCheck %s |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Lets drop --split-input-file -verify-diagnostics and line 4 here, will result in nicer links to failed CHECKs and there are no errors being verified here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done.
The TOSA variable ops and ml_program ops offer similar functionality. The tosa-to-mlprogram pass defines legalizations from the TOSA op to the MLProgram equivalent op. tosa.variable maps to ml_program.global. tosa.variable_read maps to ml_program.global_load tosa.varaible_write maps to ml_program.global_store Signed-off-by: Jerry Ge <[email protected]>
This patch adds a new pass to lower TOSA StatefulOps to corresponding ML Program ops (https://mlir.llvm.org/docs/Dialects/MLProgramOps/). Signed-off-by: Jerry Ge <[email protected]>
This patch adds a new pass to lower TOSA StatefulOps to corresponding ML Program ops (https://mlir.llvm.org/docs/Dialects/MLProgramOps/).