Skip to content

[TOSA] Add TosaToMLProgram conversion #69787

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mlir/include/mlir/Conversion/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
#include "mlir/Conversion/TensorToSPIRV/TensorToSPIRVPass.h"
#include "mlir/Conversion/TosaToArith/TosaToArith.h"
#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
#include "mlir/Conversion/TosaToMLProgram/TosaToMLProgram.h"
#include "mlir/Conversion/TosaToSCF/TosaToSCF.h"
#include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
Expand Down
21 changes: 17 additions & 4 deletions mlir/include/mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -388,8 +388,8 @@ def ConvertFuncToLLVMPass : Pass<"convert-func-to-llvm", "ModuleOp"> {
already present in the IR will be kept as is.

An LLVM datalayout string can be attached as an attribute to the module on
which the pass anchors. Such an attribute is attached by calling the
set-module-datalayout pass. If present, an llvm::DataLayout object is
which the pass anchors. Such an attribute is attached by calling the
set-module-datalayout pass. If present, an llvm::DataLayout object is
created from this attribute and used in the conversion to LLVM.

#### Output IR
Expand Down Expand Up @@ -794,12 +794,12 @@ def ConvertMemRefToSPIRV : Pass<"convert-memref-to-spirv"> {
def ConvertNVVMToLLVMPass : Pass<"convert-nvvm-to-llvm"> {
let summary = "Convert NVVM to PTX with Inline Assembly in LLVM dialect";
let description = [{
This pass generates PTX instructions using inline assembly for NVVM
This pass generates PTX instructions using inline assembly for NVVM
operations implements `BasicPtxBuilderInterface`.
}];
let dependentDialects = [
"NVVM::NVVMDialect",
];
];
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1107,6 +1107,19 @@ def TosaToLinalgNamed
let constructor = "tosa::createTosaToLinalgNamed()";
}

//===----------------------------------------------------------------------===//
// TosaToMLProgram
//===----------------------------------------------------------------------===//

def TosaToMLProgram : Pass<"tosa-to-mlprogram", "ModuleOp"> {
let summary = "Lower TOSA to the MLProgram dialect";
let dependentDialects = ["ml_program::MLProgramDialect"];
let description = [{
Pass that converts TOSA's variable operator operations to the equivalent
MLProgram operations.
}];
}

//===----------------------------------------------------------------------===//
// TosaToSCF
//===----------------------------------------------------------------------===//
Expand Down
30 changes: 30 additions & 0 deletions mlir/include/mlir/Conversion/TosaToMLProgram/TosaToMLProgram.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
//===-- TosaToMLProgram.h - TOSA to MLProgram dialect lowerings-*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file declares the passes for the TOSA to MLProgram Dialect conversion.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_CONVERSION_TOSATOMLPROGRAM_TOSATOMLPROGRAM_H
#define MLIR_CONVERSION_TOSATOMLPROGRAM_TOSATOMLPROGRAM_H

#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {

#define GEN_PASS_DECL_TOSATOMLPROGRAM

namespace tosa {

void populateTosaToMLProgramConversionPatterns(RewritePatternSet *patterns);

} // namespace tosa
} // namespace mlir

#endif // MLIR_CONVERSION_TOSATOMLPROGRAM_TOSATOMLPROGRAM_H
1 change: 1 addition & 0 deletions mlir/lib/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ add_subdirectory(TensorToLinalg)
add_subdirectory(TensorToSPIRV)
add_subdirectory(TosaToArith)
add_subdirectory(TosaToLinalg)
add_subdirectory(TosaToMLProgram)
add_subdirectory(TosaToSCF)
add_subdirectory(TosaToTensor)
add_subdirectory(UBToLLVM)
Expand Down
19 changes: 19 additions & 0 deletions mlir/lib/Conversion/TosaToMLProgram/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
add_mlir_conversion_library(MLIRTosaToMLProgram
TosaToMLProgram.cpp
TosaToMLProgramPass.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa
${MLIR_MAIN_INCLUDE_DIR}/mlir/IR

DEPENDS
MLIRConversionPassIncGen

LINK_LIBS PUBLIC
MLIRIR
MLIRMLProgramDialect
MLIRPass
MLIRTosaDialect
MLIRTosaTransforms
MLIRSupport
)
76 changes: 76 additions & 0 deletions mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
//===- TosaToMLProgram.cpp - Lowering Tosa to MLProgram Dialect------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// These rewriters lower from the TOSA dialect to the MLProgram dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/TosaToMLProgram/TosaToMLProgram.h"
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace tosa;
namespace {

class VariableOpConverter : public OpRewritePattern<tosa::VariableOp> {
public:
using OpRewritePattern<tosa::VariableOp>::OpRewritePattern;

LogicalResult matchAndRewrite(tosa::VariableOp op,
PatternRewriter &rewriter) const final {
auto newVariable = rewriter.create<mlir::ml_program::GlobalOp>(
op.getLoc(), op.getName(), op.getType(), /*is_mutable=*/true,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why always mutable?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TOSA::Variable doesn't have the mutable attribute and assumes the variable is mutable by default. https://www.mlplatform.org/tosa/tosa_spec.html#_variable

op.getInitialValueAttr(), /*sym_visibility=*/nullptr);
newVariable.setPrivate();
rewriter.replaceOp(op, newVariable);
return success();
}
};

class VariableWriteOpConverter
: public OpRewritePattern<tosa::VariableWriteOp> {
public:
using OpRewritePattern<tosa::VariableWriteOp>::OpRewritePattern;

LogicalResult matchAndRewrite(tosa::VariableWriteOp op,
PatternRewriter &rewriter) const final {
auto globalSymbolRef =
SymbolRefAttr::get(rewriter.getContext(), op.getName());
auto newVariableWrite = rewriter.create<ml_program::GlobalStoreOp>(
op.getLoc(), globalSymbolRef, op.getValue());
rewriter.replaceOp(op, newVariableWrite);
return success();
}
};

class VariableReadOpConverter : public OpRewritePattern<tosa::VariableReadOp> {
public:
using OpRewritePattern<tosa::VariableReadOp>::OpRewritePattern;

LogicalResult matchAndRewrite(tosa::VariableReadOp op,
PatternRewriter &rewriter) const final {
auto globalSymbolRef =
SymbolRefAttr::get(rewriter.getContext(), op.getName());
auto newVariableRead = rewriter.create<ml_program::GlobalLoadOp>(
op.getLoc(), op.getType(), globalSymbolRef);
rewriter.replaceOp(op, newVariableRead);

return success();
}
};

} // namespace

void mlir::tosa::populateTosaToMLProgramConversionPatterns(
RewritePatternSet *patterns) {
patterns->add<VariableOpConverter, VariableWriteOpConverter,
VariableReadOpConverter>(patterns->getContext());
}
48 changes: 48 additions & 0 deletions mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgramPass.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
//===- TosaToMLProgramPass.cpp - Lowering Tosa to MLProgram Dialect--------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This transformation pass legalizes the TOSA dialect to the MLProgram dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/TosaToMLProgram/TosaToMLProgram.h"
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
#define GEN_PASS_DEF_TOSATOMLPROGRAM
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace tosa;

namespace {
struct TosaToMLProgram : public impl::TosaToMLProgramBase<TosaToMLProgram> {
public:
void runOnOperation() override {
auto *context = &getContext();
auto moduleOp = getOperation();

RewritePatternSet patterns(context);
ConversionTarget target(*context);
target.addIllegalOp<tosa::VariableOp, tosa::VariableReadOp,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think only this setting may be needed and you could remove the other legal and dynamically legal settings as you are.using a partial conversion.

Copy link
Member Author

@Jerry-Ge Jerry-Ge Oct 24, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Back from some experiments, seems we can not remove both target.addLegalDialect<ml_program::MLProgramDialect>(); and target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });

Error:

staging/llvm-project/mlir/test/Conversion/TosaToMLProgram/tosa-to-mlprogram.mlir:4 offset :5:3: error: failed to legalize operation 'tosa.variable' that was explicitly marked illegal

Am I missing something here?

I propose we can only keep the target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll check here, there should be something simpler. But not blocking.

tosa::VariableWriteOp>();
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });

mlir::tosa::populateTosaToMLProgramConversionPatterns(&patterns);

if (failed(applyPartialConversion(moduleOp, target, std::move(patterns))))
signalPassFailure();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should log an error via emitMatchFailure

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry, I couldn't find anything about emitMatchFailure [Link]. When I tried to build it, it also fails. Other passes are also using signalPassFailure()

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually don't think this one can fail ... I think it would only fail if this didn't converge after 10 iterations given partial conversion. But fine as is.

(probably meant notifyMatchFailure)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

actually copied from TosaToLinag here: https://github.com/llvm/llvm-project/blob/d199fd76f7b76d902d8ef210d82689f299934793/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp#L70C3-L70C3.

And all other TosaTo* passes are checking this if condition and do signalPassFailure
If notifyMatchFailure is the correct one to use here, I suggest I can do another patch to update all the TosaTo* passes?

}
};
} // namespace
13 changes: 13 additions & 0 deletions mlir/test/Conversion/TosaToMLProgram/tosa-to-mlprogram.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// RUN: mlir-opt --tosa-to-mlprogram %s -o -| FileCheck %s

module {
// CHECK: ml_program.global private mutable @var_x(dense<7.000000e+00> : tensor<1xf32>) : tensor<1xf32>
tosa.variable @var_x = dense<7.000000e+00> : tensor<1xf32>
func.func @test_stateful_ops(%arg0: tensor<1xf32>) -> (tensor<1xf32>) {
// CHECK: ml_program.global_store @var_x = %arg0 : tensor<1xf32>
tosa.variable.write @var_x, %arg0 : tensor<1xf32>
// CHECK: %[[LOAD:.+]] = ml_program.global_load @var_x : tensor<1xf32>
%0 = tosa.variable.read @var_x : tensor<1xf32>
return %0 : tensor<1xf32>
}
}
25 changes: 25 additions & 0 deletions utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -3730,6 +3730,7 @@ cc_library(
":TensorToSPIRV",
":TosaToArith",
":TosaToLinalg",
":TosaToMLProgram",
":TosaToSCF",
":TosaToTensor",
":UBToLLVM",
Expand Down Expand Up @@ -11054,6 +11055,30 @@ cc_library(
],
)

cc_library(
name = "TosaToMLProgram",
srcs = glob([
"lib/Conversion/TosaToMLProgram/*.cpp",
"lib/Conversion/TosaToMLProgram/*.h",
]),
hdrs = glob([
"include/mlir/Conversion/TosaToMLProgram/*.h",
]),
includes = [
"include",
"lib/Conversion/TosaToMLProgram",
],
deps = [
":ConversionPassIncGen",
":FuncDialect",
":IR",
":Pass",
":MLProgramDialect",
":TosaDialect",
":Transforms",
],
)

cc_library(
name = "TosaToSCF",
srcs = glob([
Expand Down