Skip to content

Commit 1cb97e0

Browse files
Initial Implementation of MediaTek Backend for Executorch
Differential Revision: D60970271 Pull Request resolved: #3571
1 parent c541bc1 commit 1cb97e0

File tree

75 files changed

+425878
-0
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

75 files changed

+425878
-0
lines changed

CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,8 @@ option(EXECUTORCH_BUILD_GTESTS "Build googletest based test binaries" OFF)
179179

180180
option(EXECUTORCH_BUILD_MPS "Build the MPS backend" OFF)
181181

182+
option(EXECUTORCH_BUILD_NEURON "Build the backends/mediatek directory" OFF)
183+
182184
option(EXECUTORCH_BUILD_PYBIND "Build the Python Bindings" OFF)
183185

184186
option(EXECUTORCH_BUILD_QNN "Build the Qualcomm backend" OFF)
@@ -624,6 +626,10 @@ if(EXECUTORCH_BUILD_EXTENSION_MODULE)
624626
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/module)
625627
endif()
626628

629+
if(EXECUTORCH_BUILD_NEURON)
630+
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/mediatek)
631+
endif()
632+
627633
if(EXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL)
628634
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/runner_util)
629635
endif()

LICENSE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ Copyright (c) Meta Platforms, Inc. and affiliates.
66
Copyright 2023 Arm Limited and/or its affiliates.
77
Copyright (c) Qualcomm Innovation Center, Inc.
88
Copyright (c) 2023 Apple Inc.
9+
Copyright (c) 2024 MediaTek Inc.
910

1011
Redistribution and use in source and binary forms, with or without modification,
1112
are permitted provided that the following conditions are met:

backends/mediatek/CMakeLists.txt

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
#[[
2+
/*
3+
* Copyright (c) 2024 MediaTek Inc.
4+
*
5+
* Licensed under the BSD License (the "License"); you may not use this file
6+
* except in compliance with the License. See the license file in the root
7+
* directory of this source tree for more details.
8+
*/
9+
]]
10+
11+
# Let include directory as "executorch/..."
12+
set(_common_include_directories ${CMAKE_CURRENT_SOURCE_DIR}/../../..)
13+
set(NEURON_BUFFER_ALLOCATOR_LIB "" CACHE PATH "Path to Neuron Buffer Allocator library")
14+
message(STATUS "Looking for neuron_buffer_allocator in ${NEURON_BUFFER_ALLOCATOR_LIB}")
15+
16+
include_directories(
17+
BEFORE
18+
${_common_include_directories}
19+
)
20+
21+
# shortcut include directory for neuron headers
22+
include_directories(
23+
BEFORE
24+
${CMAKE_CURRENT_SOURCE_DIR}/runtime/include
25+
)
26+
27+
# targets
28+
add_library(neuron_backend SHARED)
29+
target_link_libraries(neuron_backend
30+
PRIVATE
31+
executorch_no_prim_ops
32+
android
33+
log
34+
${NEURON_BUFFER_ALLOCATOR_LIB}
35+
)
36+
target_sources(neuron_backend
37+
INTERFACE
38+
${CMAKE_CURRENT_LIST_DIR}/runtime/include/NeuronBackend.h
39+
${CMAKE_CURRENT_LIST_DIR}/runtime/include/NeuronBufferAllocator.h
40+
${CMAKE_CURRENT_LIST_DIR}/runtime/include/NeuronExecutor.h
41+
${CMAKE_CURRENT_LIST_DIR}/runtime/include/NeuronLog.h
42+
${CMAKE_CURRENT_LIST_DIR}/runtime/include/api/APUWareUtilsLib.h
43+
${CMAKE_CURRENT_LIST_DIR}/runtime/include/api/NeuronAdapterShim.h
44+
PRIVATE
45+
${CMAKE_CURRENT_LIST_DIR}/runtime/NeuronBackend.cpp
46+
${CMAKE_CURRENT_LIST_DIR}/runtime/NeuronExecutor.cpp
47+
)
48+
target_link_options_shared_lib(neuron_backend)
49+
50+
install(TARGETS neuron_backend DESTINATION lib)

backends/mediatek/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from .partitioner import NeuropilotPartitioner
2+
from .preprocess import NeuropilotBackend
3+
from .quantizer import NeuropilotQuantizer, Precision
4+
5+
__all__ = [NeuropilotBackend, NeuropilotPartitioner, NeuropilotQuantizer, Precision]

backends/mediatek/partitioner.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
# Copyright (c) 2024 MediaTek Inc.
2+
#
3+
# Licensed under the BSD License (the "License"); you may not use this file
4+
# except in compliance with the License. See the license file in the root
5+
# directory of this source tree for more details.
6+
7+
from typing import Callable, final, List, Optional, Tuple
8+
9+
import torch
10+
from executorch.backends.mediatek.preprocess import NeuropilotBackend
11+
from executorch.exir.backend.backend_details import CompileSpec
12+
from executorch.exir.backend.partitioner import (
13+
DelegationSpec,
14+
Partitioner,
15+
PartitionResult,
16+
)
17+
from executorch.exir.backend.utils import tag_constant_data
18+
19+
from mtk_converter.python.converters.pytorch import importer_v2
20+
from torch.export.exported_program import ExportedProgram
21+
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
22+
from torch.fx.passes.operator_support import OperatorSupportBase
23+
24+
25+
class NeuropilotOperatorsSupport(OperatorSupportBase):
26+
27+
def __init__(
28+
self,
29+
op_types_to_skip: Optional[set] = None,
30+
op_names_to_skip: Optional[set] = None,
31+
) -> None:
32+
if op_types_to_skip is None:
33+
op_types_to_skip = set()
34+
if op_names_to_skip is None:
35+
op_names_to_skip = set()
36+
37+
self._op_types_to_skip = op_types_to_skip
38+
self._op_names_to_skip = op_names_to_skip
39+
40+
def is_node_supported(self, _, node: torch.fx.Node) -> bool:
41+
# Handle 'call_function' only cause 'placeholder' and 'output' cannot be tagged.
42+
# Ref: https://github.com/pytorch/executorch/pull/1398
43+
if node.op != "call_function":
44+
return False
45+
46+
op_type = node.target.__name__
47+
if op_type in self._op_types_to_skip or node.name in self._op_names_to_skip:
48+
print(
49+
f"[Neuropilot Backend] The {op_type} operator with name '{node.name}' is skipped."
50+
)
51+
return False
52+
53+
return importer_v2.is_fx_node_supported(node)
54+
55+
56+
@final
57+
class NeuropilotPartitioner(Partitioner):
58+
59+
def __init__(
60+
self,
61+
compile_spec: List[CompileSpec],
62+
op_types_to_skip: Optional[set] = None,
63+
op_names_to_skip: Optional[set] = None,
64+
) -> None:
65+
self.delegation_spec = DelegationSpec(NeuropilotBackend.__name__, compile_spec)
66+
self._op_types_to_skip = op_types_to_skip
67+
self._op_names_to_skip = op_names_to_skip
68+
69+
def ops_to_not_decompose(
70+
self,
71+
ep: ExportedProgram,
72+
) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
73+
ops_not_decompose = [
74+
torch.ops.aten.pixel_shuffle.default,
75+
torch.ops.aten.upsample_bilinear2d.default,
76+
torch.ops.aten.upsample_bilinear2d.vec,
77+
torch.ops.aten.upsample_nearest2d.default,
78+
torch.ops.aten.upsample_nearest2d.vec,
79+
]
80+
return (ops_not_decompose, None)
81+
82+
def partition(self, exported_program: ExportedProgram) -> PartitionResult:
83+
capability_partitioner = CapabilityBasedPartitioner(
84+
exported_program.graph_module,
85+
NeuropilotOperatorsSupport(self._op_types_to_skip, self._op_names_to_skip),
86+
allows_single_node_partition=True,
87+
)
88+
partition_list = capability_partitioner.propose_partitions()
89+
90+
partition_tags = {}
91+
for partition in partition_list:
92+
for node in partition.nodes:
93+
tag = f"tag{partition.id}"
94+
node.meta["delegation_tag"] = tag
95+
partition_tags[tag] = self.delegation_spec
96+
97+
tag_constant_data(exported_program)
98+
99+
return PartitionResult(
100+
tagged_exported_program=exported_program, partition_tags=partition_tags
101+
)

backends/mediatek/passes/__init__.py

Whitespace-only changes.
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# Copyright (c) 2024 MediaTek Inc.
2+
#
3+
# Licensed under the BSD License (the "License"); you may not use this file
4+
# except in compliance with the License. See the license file in the root
5+
# directory of this source tree for more details.
6+
7+
import torch
8+
9+
from executorch.exir.pass_base import ExportPass, PassResult
10+
from torch._decomp import get_decompositions
11+
from torch.fx import Graph
12+
from torch.fx.experimental.proxy_tensor import make_fx
13+
14+
15+
def _get_input_node_names(graph: Graph):
16+
input_names = []
17+
for node in graph.nodes:
18+
if node.op == "placeholder":
19+
input_names.append(node.name)
20+
return input_names
21+
22+
23+
class DecomposeScaledDotProductAttention(ExportPass):
24+
"""Decompose the single SDPA operator."""
25+
26+
def call(self, graph_module: torch.fx.GraphModule):
27+
graph = graph_module.graph
28+
for node in graph.nodes:
29+
if node.target != torch.ops.aten.scaled_dot_product_attention.default:
30+
continue
31+
32+
decom_mappings = get_decompositions(
33+
[torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default]
34+
)
35+
input_tensors = (arg.meta["val"] for arg in node.args)
36+
decomposed_module = make_fx(node.target, decom_mappings, "fake", True)(
37+
*input_tensors
38+
)
39+
decomposed_input_names = _get_input_node_names(decomposed_module.graph)
40+
with graph.inserting_before(node):
41+
name_to_input_tensor_map = {}
42+
for idx, arg in enumerate(node.args):
43+
name_to_input_tensor_map[decomposed_input_names[idx]] = arg
44+
45+
decomposed_node_to_subgraph_node = {}
46+
for decomposed_node in decomposed_module.graph.nodes:
47+
if decomposed_node.op == "placeholder":
48+
decomposed_node_to_subgraph_node[decomposed_node] = (
49+
name_to_input_tensor_map[decomposed_node.name]
50+
)
51+
52+
# Copy node from decompose graph module
53+
for decomposed_node in decomposed_module.graph.nodes:
54+
if decomposed_node.op == "placeholder":
55+
continue
56+
if decomposed_node.op == "output":
57+
for user in node.users.copy():
58+
new_node = decomposed_node_to_subgraph_node[
59+
decomposed_node.args[0]
60+
]
61+
user.replace_input_with(node, new_node)
62+
continue
63+
64+
subgraph_node = graph.node_copy(
65+
decomposed_node,
66+
arg_transform=lambda x, d=decomposed_node_to_subgraph_node: d[
67+
x
68+
],
69+
)
70+
subgraph_node.meta["source_fn_stack"] = [
71+
(subgraph_node, subgraph_node.target)
72+
]
73+
decomposed_node_to_subgraph_node[decomposed_node] = subgraph_node
74+
75+
graph.erase_node(node)
76+
77+
graph.eliminate_dead_code()
78+
graph_module.recompile()
79+
return PassResult(graph_module, True)

backends/mediatek/preprocess.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# Copyright (c) 2024 MediaTek Inc.
2+
#
3+
# Licensed under the BSD License (the "License"); you may not use this file
4+
# except in compliance with the License. See the license file in the root
5+
# directory of this source tree for more details.
6+
7+
import contextlib
8+
import struct
9+
10+
from typing import final, List
11+
12+
import mtk_converter
13+
import mtk_neuron
14+
import torch
15+
from executorch.exir.backend.backend_details import (
16+
BackendDetails,
17+
ExportedProgram,
18+
PreprocessResult,
19+
)
20+
from executorch.exir.backend.compile_spec_schema import CompileSpec
21+
22+
SKIP_COMPILE_SPEC_KEYS = {"ImportForever"}
23+
24+
25+
@final
26+
class NeuropilotBackend(BackendDetails):
27+
28+
@classmethod
29+
def preprocess(
30+
cls, edge_program: ExportedProgram, module_compile_spec: List[CompileSpec]
31+
) -> PreprocessResult:
32+
33+
name_to_node_mappings = {node.name: node for node in edge_program.graph.nodes}
34+
input_names = edge_program.graph_signature.user_inputs
35+
output_names = edge_program.graph_signature.user_outputs
36+
fp_input_indices = [
37+
idx
38+
for idx, name in enumerate(input_names)
39+
if name_to_node_mappings[name].meta["val"].dtype == torch.float32
40+
]
41+
fp_output_indices = [
42+
idx
43+
for idx, name in enumerate(output_names)
44+
if name_to_node_mappings[name].meta["val"].dtype == torch.float32
45+
]
46+
47+
# This default compile options are only for mt6989 SOC
48+
compile_options = ["--arch=mdla5.1,edpa1.0", "--relax-fp32", "--opt=3"]
49+
for spec in module_compile_spec:
50+
if spec.key in SKIP_COMPILE_SPEC_KEYS:
51+
continue
52+
if spec.value == b"":
53+
compile_options.append(f"--{spec.key}")
54+
else:
55+
value = spec.value.decode("utf-8")
56+
compile_options.append(f"--{spec.key}={value}")
57+
58+
converter = mtk_converter.PyTorchV2Converter.from_exported_program(edge_program)
59+
converter.quantize = True
60+
converter.input_quantization_bitwidths = None
61+
converter.allow_missing_quantization_ranges = True
62+
converter.prepend_input_quantize_ops = True
63+
converter.prepend_input_quantize_ops_indices = fp_input_indices
64+
converter.append_output_dequantize_ops = True
65+
converter.append_output_dequantize_ops_indices = fp_output_indices
66+
with contextlib.redirect_stdout(None):
67+
mlir_str = converter.convert_to_mlir()
68+
model_bytes = mtk_neuron.compile(mlir_str, " ".join(compile_options))
69+
70+
num_inputs = len(input_names)
71+
num_outputs = len(output_names)
72+
header = struct.pack("<BIII", 1, num_inputs, num_outputs, len(model_bytes))
73+
return PreprocessResult(processed_bytes=bytes(header + model_bytes))
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from .qconfig import Precision
2+
from .quantizer import NeuropilotQuantizer
3+
4+
__all__ = [NeuropilotQuantizer, Precision]

0 commit comments

Comments
 (0)