Skip to content

Commit 0edbdad

Browse files
committed
[TOSA] Add a pass to convert TOSA Variable Ops to MLProgram Global Ops
The TOSA variable ops and ml_program ops offer similar functionality. The tosa-to-mlprogram pass defines legalizations from the TOSA op to the MLProgram equivalent op. tosa.variable maps to ml_program.global. tosa.variable_read maps to ml_program.global_load tosa.varaible_write maps to ml_program.global_store Signed-off-by: Jerry Ge <[email protected]>
1 parent a432358 commit 0edbdad

File tree

9 files changed

+230
-4
lines changed

9 files changed

+230
-4
lines changed

mlir/include/mlir/Conversion/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
#include "mlir/Conversion/TensorToSPIRV/TensorToSPIRVPass.h"
6060
#include "mlir/Conversion/TosaToArith/TosaToArith.h"
6161
#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
62+
#include "mlir/Conversion/TosaToMLProgram/TosaToMLProgram.h"
6263
#include "mlir/Conversion/TosaToSCF/TosaToSCF.h"
6364
#include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
6465
#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"

mlir/include/mlir/Conversion/Passes.td

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -388,8 +388,8 @@ def ConvertFuncToLLVMPass : Pass<"convert-func-to-llvm", "ModuleOp"> {
388388
already present in the IR will be kept as is.
389389

390390
An LLVM datalayout string can be attached as an attribute to the module on
391-
which the pass anchors. Such an attribute is attached by calling the
392-
set-module-datalayout pass. If present, an llvm::DataLayout object is
391+
which the pass anchors. Such an attribute is attached by calling the
392+
set-module-datalayout pass. If present, an llvm::DataLayout object is
393393
created from this attribute and used in the conversion to LLVM.
394394

395395
#### Output IR
@@ -794,12 +794,12 @@ def ConvertMemRefToSPIRV : Pass<"convert-memref-to-spirv"> {
794794
def ConvertNVVMToLLVMPass : Pass<"convert-nvvm-to-llvm"> {
795795
let summary = "Convert NVVM to PTX with Inline Assembly in LLVM dialect";
796796
let description = [{
797-
This pass generates PTX instructions using inline assembly for NVVM
797+
This pass generates PTX instructions using inline assembly for NVVM
798798
operations implements `BasicPtxBuilderInterface`.
799799
}];
800800
let dependentDialects = [
801801
"NVVM::NVVMDialect",
802-
];
802+
];
803803
}
804804

805805
//===----------------------------------------------------------------------===//
@@ -1107,6 +1107,19 @@ def TosaToLinalgNamed
11071107
let constructor = "tosa::createTosaToLinalgNamed()";
11081108
}
11091109

1110+
//===----------------------------------------------------------------------===//
1111+
// TosaToMLProgram
1112+
//===----------------------------------------------------------------------===//
1113+
1114+
def TosaToMLProgram : Pass<"tosa-to-mlprogram", "ModuleOp"> {
1115+
let summary = "Lower TOSA to the MLProgram dialect";
1116+
let dependentDialects = ["ml_program::MLProgramDialect"];
1117+
let description = [{
1118+
Pass that converts TOSA's variable operator operations to the equivalent
1119+
MLProgram operations.
1120+
}];
1121+
}
1122+
11101123
//===----------------------------------------------------------------------===//
11111124
// TosaToSCF
11121125
//===----------------------------------------------------------------------===//
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
//===-- TosaToMLProgram.h - TOSA to MLProgram dialect lowerings-*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file declares the passes for the TOSA to MLProgram Dialect conversion.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef MLIR_CONVERSION_TOSATOMLPROGRAM_TOSATOMLPROGRAM_H
14+
#define MLIR_CONVERSION_TOSATOMLPROGRAM_TOSATOMLPROGRAM_H
15+
16+
#include "mlir/Pass/Pass.h"
17+
#include "mlir/Transforms/DialectConversion.h"
18+
19+
namespace mlir {
20+
21+
#define GEN_PASS_DECL_TOSATOMLPROGRAM
22+
23+
namespace tosa {
24+
25+
void populateTosaToMLProgramConversionPatterns(RewritePatternSet *patterns);
26+
27+
} // namespace tosa
28+
} // namespace mlir
29+
30+
#endif // MLIR_CONVERSION_TOSATOMLPROGRAM_TOSATOMLPROGRAM_H

mlir/lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ add_subdirectory(TensorToLinalg)
4949
add_subdirectory(TensorToSPIRV)
5050
add_subdirectory(TosaToArith)
5151
add_subdirectory(TosaToLinalg)
52+
add_subdirectory(TosaToMLProgram)
5253
add_subdirectory(TosaToSCF)
5354
add_subdirectory(TosaToTensor)
5455
add_subdirectory(UBToLLVM)
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
add_mlir_conversion_library(MLIRTosaToMLProgram
2+
TosaToMLProgram.cpp
3+
TosaToMLProgramPass.cpp
4+
5+
ADDITIONAL_HEADER_DIRS
6+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa
7+
${MLIR_MAIN_INCLUDE_DIR}/mlir/IR
8+
9+
DEPENDS
10+
MLIRConversionPassIncGen
11+
12+
LINK_LIBS PUBLIC
13+
MLIRIR
14+
MLIRMLProgramDialect
15+
MLIRPass
16+
MLIRTosaDialect
17+
MLIRTosaTransforms
18+
MLIRSupport
19+
)
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
//===- TosaToMLProgram.cpp - Lowering Tosa to MLProgram Dialect------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// These rewriters lower from the TOSA dialect to the MLProgram dialect.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Conversion/TosaToMLProgram/TosaToMLProgram.h"
14+
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
15+
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
16+
#include "mlir/IR/IRMapping.h"
17+
#include "mlir/IR/PatternMatch.h"
18+
19+
using namespace mlir;
20+
using namespace tosa;
21+
namespace {
22+
23+
class VariableOpConverter : public OpRewritePattern<tosa::VariableOp> {
24+
public:
25+
using OpRewritePattern<tosa::VariableOp>::OpRewritePattern;
26+
27+
LogicalResult matchAndRewrite(tosa::VariableOp op,
28+
PatternRewriter &rewriter) const final {
29+
auto newVariable = rewriter.create<mlir::ml_program::GlobalOp>(
30+
op.getLoc(), op.getName(), op.getType(), /*is_mutable=*/true,
31+
op.getInitialValueAttr(), /*sym_visibility=*/nullptr);
32+
newVariable.setPrivate();
33+
rewriter.replaceOp(op, newVariable);
34+
return success();
35+
}
36+
};
37+
38+
class VariableWriteOpConverter
39+
: public OpRewritePattern<tosa::VariableWriteOp> {
40+
public:
41+
using OpRewritePattern<tosa::VariableWriteOp>::OpRewritePattern;
42+
43+
LogicalResult matchAndRewrite(tosa::VariableWriteOp op,
44+
PatternRewriter &rewriter) const final {
45+
auto globalSymbolRef =
46+
SymbolRefAttr::get(rewriter.getContext(), op.getName());
47+
auto newVariableWrite = rewriter.create<ml_program::GlobalStoreOp>(
48+
op.getLoc(), globalSymbolRef, op.getValue());
49+
rewriter.replaceOp(op, newVariableWrite);
50+
return success();
51+
}
52+
};
53+
54+
class VariableReadOpConverter : public OpRewritePattern<tosa::VariableReadOp> {
55+
public:
56+
using OpRewritePattern<tosa::VariableReadOp>::OpRewritePattern;
57+
58+
LogicalResult matchAndRewrite(tosa::VariableReadOp op,
59+
PatternRewriter &rewriter) const final {
60+
auto globalSymbolRef =
61+
SymbolRefAttr::get(rewriter.getContext(), op.getName());
62+
auto newVariableRead = rewriter.create<ml_program::GlobalLoadOp>(
63+
op.getLoc(), op.getType(), globalSymbolRef);
64+
rewriter.replaceOp(op, newVariableRead);
65+
66+
return success();
67+
}
68+
};
69+
70+
} // namespace
71+
72+
void mlir::tosa::populateTosaToMLProgramConversionPatterns(
73+
RewritePatternSet *patterns) {
74+
patterns->add<VariableOpConverter, VariableWriteOpConverter,
75+
VariableReadOpConverter>(patterns->getContext());
76+
}
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
//===- TosaToMLProgramPass.cpp - Lowering Tosa to MLProgram Dialect--------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This transformation pass legalizes the TOSA dialect to the MLProgram dialect.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Conversion/TosaToMLProgram/TosaToMLProgram.h"
14+
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
15+
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
16+
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
17+
#include "mlir/IR/PatternMatch.h"
18+
#include "mlir/Pass/PassManager.h"
19+
#include "mlir/Transforms/DialectConversion.h"
20+
21+
namespace mlir {
22+
#define GEN_PASS_DEF_TOSATOMLPROGRAM
23+
#include "mlir/Conversion/Passes.h.inc"
24+
} // namespace mlir
25+
26+
using namespace mlir;
27+
using namespace tosa;
28+
29+
namespace {
30+
struct TosaToMLProgram : public impl::TosaToMLProgramBase<TosaToMLProgram> {
31+
public:
32+
void runOnOperation() override {
33+
auto *context = &getContext();
34+
auto moduleOp = getOperation();
35+
36+
RewritePatternSet patterns(context);
37+
ConversionTarget target(*context);
38+
target.addIllegalOp<tosa::VariableOp, tosa::VariableReadOp,
39+
tosa::VariableWriteOp>();
40+
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
41+
42+
mlir::tosa::populateTosaToMLProgramConversionPatterns(&patterns);
43+
44+
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns))))
45+
signalPassFailure();
46+
}
47+
};
48+
} // namespace
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
// RUN: mlir-opt --tosa-to-mlprogram %s -o -| FileCheck %s
2+
3+
module {
4+
// CHECK: ml_program.global private mutable @var_x(dense<7.000000e+00> : tensor<1xf32>) : tensor<1xf32>
5+
tosa.variable @var_x = dense<7.000000e+00> : tensor<1xf32>
6+
func.func @test_stateful_ops(%arg0: tensor<1xf32>) -> (tensor<1xf32>) {
7+
// CHECK: ml_program.global_store @var_x = %arg0 : tensor<1xf32>
8+
tosa.variable.write @var_x, %arg0 : tensor<1xf32>
9+
// CHECK: %[[LOAD:.+]] = ml_program.global_load @var_x : tensor<1xf32>
10+
%0 = tosa.variable.read @var_x : tensor<1xf32>
11+
return %0 : tensor<1xf32>
12+
}
13+
}

utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3730,6 +3730,7 @@ cc_library(
37303730
":TensorToSPIRV",
37313731
":TosaToArith",
37323732
":TosaToLinalg",
3733+
":TosaToMLProgram",
37333734
":TosaToSCF",
37343735
":TosaToTensor",
37353736
":UBToLLVM",
@@ -11054,6 +11055,30 @@ cc_library(
1105411055
],
1105511056
)
1105611057

11058+
cc_library(
11059+
name = "TosaToMLProgram",
11060+
srcs = glob([
11061+
"lib/Conversion/TosaToMLProgram/*.cpp",
11062+
"lib/Conversion/TosaToMLProgram/*.h",
11063+
]),
11064+
hdrs = glob([
11065+
"include/mlir/Conversion/TosaToMLProgram/*.h",
11066+
]),
11067+
includes = [
11068+
"include",
11069+
"lib/Conversion/TosaToMLProgram",
11070+
],
11071+
deps = [
11072+
":ConversionPassIncGen",
11073+
":FuncDialect",
11074+
":IR",
11075+
":Pass",
11076+
":MLProgramDialect",
11077+
":TosaDialect",
11078+
":Transforms",
11079+
],
11080+
)
11081+
1105711082
cc_library(
1105811083
name = "TosaToSCF",
1105911084
srcs = glob([

0 commit comments

Comments
 (0)