-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[TOSA] Add TosaToMLProgram conversion #69787
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
//===-- TosaToMLProgram.h - TOSA to MLProgram dialect lowerings-*- C++ -*-===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// This file declares the passes for the TOSA to MLProgram Dialect conversion. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef MLIR_CONVERSION_TOSATOMLPROGRAM_TOSATOMLPROGRAM_H | ||
#define MLIR_CONVERSION_TOSATOMLPROGRAM_TOSATOMLPROGRAM_H | ||
|
||
#include "mlir/Pass/Pass.h" | ||
#include "mlir/Transforms/DialectConversion.h" | ||
|
||
namespace mlir { | ||
|
||
#define GEN_PASS_DECL_TOSATOMLPROGRAM | ||
|
||
namespace tosa { | ||
|
||
void populateTosaToMLProgramConversionPatterns(RewritePatternSet *patterns); | ||
|
||
} // namespace tosa | ||
} // namespace mlir | ||
|
||
#endif // MLIR_CONVERSION_TOSATOMLPROGRAM_TOSATOMLPROGRAM_H |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
add_mlir_conversion_library(MLIRTosaToMLProgram | ||
TosaToMLProgram.cpp | ||
TosaToMLProgramPass.cpp | ||
|
||
ADDITIONAL_HEADER_DIRS | ||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa | ||
${MLIR_MAIN_INCLUDE_DIR}/mlir/IR | ||
|
||
DEPENDS | ||
MLIRConversionPassIncGen | ||
|
||
LINK_LIBS PUBLIC | ||
MLIRIR | ||
MLIRMLProgramDialect | ||
MLIRPass | ||
MLIRTosaDialect | ||
MLIRTosaTransforms | ||
MLIRSupport | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
//===- TosaToMLProgram.cpp - Lowering Tosa to MLProgram Dialect------------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// These rewriters lower from the TOSA dialect to the MLProgram dialect. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "mlir/Conversion/TosaToMLProgram/TosaToMLProgram.h" | ||
#include "mlir/Dialect/MLProgram/IR/MLProgram.h" | ||
#include "mlir/Dialect/Tosa/IR/TosaOps.h" | ||
#include "mlir/IR/IRMapping.h" | ||
#include "mlir/IR/PatternMatch.h" | ||
|
||
using namespace mlir; | ||
using namespace tosa; | ||
namespace { | ||
|
||
class VariableOpConverter : public OpRewritePattern<tosa::VariableOp> { | ||
public: | ||
using OpRewritePattern<tosa::VariableOp>::OpRewritePattern; | ||
|
||
LogicalResult matchAndRewrite(tosa::VariableOp op, | ||
PatternRewriter &rewriter) const final { | ||
auto newVariable = rewriter.create<mlir::ml_program::GlobalOp>( | ||
op.getLoc(), op.getName(), op.getType(), /*is_mutable=*/true, | ||
op.getInitialValueAttr(), /*sym_visibility=*/nullptr); | ||
newVariable.setPrivate(); | ||
rewriter.replaceOp(op, newVariable); | ||
return success(); | ||
} | ||
}; | ||
|
||
class VariableWriteOpConverter | ||
: public OpRewritePattern<tosa::VariableWriteOp> { | ||
public: | ||
using OpRewritePattern<tosa::VariableWriteOp>::OpRewritePattern; | ||
|
||
LogicalResult matchAndRewrite(tosa::VariableWriteOp op, | ||
PatternRewriter &rewriter) const final { | ||
auto globalSymbolRef = | ||
SymbolRefAttr::get(rewriter.getContext(), op.getName()); | ||
auto newVariableWrite = rewriter.create<ml_program::GlobalStoreOp>( | ||
op.getLoc(), globalSymbolRef, op.getValue()); | ||
rewriter.replaceOp(op, newVariableWrite); | ||
return success(); | ||
} | ||
}; | ||
|
||
class VariableReadOpConverter : public OpRewritePattern<tosa::VariableReadOp> { | ||
public: | ||
using OpRewritePattern<tosa::VariableReadOp>::OpRewritePattern; | ||
|
||
LogicalResult matchAndRewrite(tosa::VariableReadOp op, | ||
PatternRewriter &rewriter) const final { | ||
auto globalSymbolRef = | ||
SymbolRefAttr::get(rewriter.getContext(), op.getName()); | ||
auto newVariableRead = rewriter.create<ml_program::GlobalLoadOp>( | ||
op.getLoc(), op.getType(), globalSymbolRef); | ||
rewriter.replaceOp(op, newVariableRead); | ||
|
||
return success(); | ||
} | ||
}; | ||
|
||
} // namespace | ||
|
||
void mlir::tosa::populateTosaToMLProgramConversionPatterns( | ||
RewritePatternSet *patterns) { | ||
patterns->add<VariableOpConverter, VariableWriteOpConverter, | ||
VariableReadOpConverter>(patterns->getContext()); | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
//===- TosaToMLProgramPass.cpp - Lowering Tosa to MLProgram Dialect--------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// This transformation pass legalizes the TOSA dialect to the MLProgram dialect. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#include "mlir/Conversion/TosaToMLProgram/TosaToMLProgram.h" | ||
#include "mlir/Dialect/MLProgram/IR/MLProgram.h" | ||
#include "mlir/Dialect/Tosa/IR/TosaOps.h" | ||
#include "mlir/Dialect/Tosa/Transforms/Passes.h" | ||
#include "mlir/IR/PatternMatch.h" | ||
#include "mlir/Pass/PassManager.h" | ||
#include "mlir/Transforms/DialectConversion.h" | ||
|
||
namespace mlir { | ||
#define GEN_PASS_DEF_TOSATOMLPROGRAM | ||
#include "mlir/Conversion/Passes.h.inc" | ||
} // namespace mlir | ||
|
||
using namespace mlir; | ||
using namespace tosa; | ||
|
||
namespace { | ||
struct TosaToMLProgram : public impl::TosaToMLProgramBase<TosaToMLProgram> { | ||
public: | ||
void runOnOperation() override { | ||
auto *context = &getContext(); | ||
auto moduleOp = getOperation(); | ||
|
||
RewritePatternSet patterns(context); | ||
ConversionTarget target(*context); | ||
target.addIllegalOp<tosa::VariableOp, tosa::VariableReadOp, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think only this setting may be needed and you could remove the other legal and dynamically legal settings as you are.using a partial conversion. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Back from some experiments, seems we can not remove both Error:
Am I missing something here? I propose we can only keep the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'll check here, there should be something simpler. But not blocking. |
||
tosa::VariableWriteOp>(); | ||
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); | ||
|
||
mlir::tosa::populateTosaToMLProgramConversionPatterns(&patterns); | ||
|
||
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) | ||
signalPassFailure(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You should log an error via There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sorry, I couldn't find anything about There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I actually don't think this one can fail ... I think it would only fail if this didn't converge after 10 iterations given partial conversion. But fine as is. (probably meant notifyMatchFailure) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. actually copied from TosaToLinag here: https://github.com/llvm/llvm-project/blob/d199fd76f7b76d902d8ef210d82689f299934793/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp#L70C3-L70C3. And all other TosaTo* passes are checking this if condition and do |
||
} | ||
}; | ||
} // namespace |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
// RUN: mlir-opt --tosa-to-mlprogram %s -o -| FileCheck %s | ||
|
||
module { | ||
// CHECK: ml_program.global private mutable @var_x(dense<7.000000e+00> : tensor<1xf32>) : tensor<1xf32> | ||
tosa.variable @var_x = dense<7.000000e+00> : tensor<1xf32> | ||
func.func @test_stateful_ops(%arg0: tensor<1xf32>) -> (tensor<1xf32>) { | ||
// CHECK: ml_program.global_store @var_x = %arg0 : tensor<1xf32> | ||
tosa.variable.write @var_x, %arg0 : tensor<1xf32> | ||
// CHECK: %[[LOAD:.+]] = ml_program.global_load @var_x : tensor<1xf32> | ||
%0 = tosa.variable.read @var_x : tensor<1xf32> | ||
return %0 : tensor<1xf32> | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why always mutable?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TOSA::Variable
doesn't have themutable
attribute and assumes the variable is mutable by default. https://www.mlplatform.org/tosa/tosa_spec.html#_variable