Initial Implementation of MediaTek Backend for Executorch #3571

Merged: 42 commits (Aug 14, 2024)

Commits (42):
cce2ef1 - MediaTek Neuron ExecuTorch Backend (neuropilot-captain, May 8, 2024)
c2fefd0 - Set NeuronDelegateSetting from compile specs (neuropilot-captain, May 10, 2024)
eac37ac - Add MediaTek Llama runner (neuropilot-captain, May 10, 2024)
345181c - Merge branch 'pytorch:main' into main (neuropilot-captain, May 17, 2024)
baa359b - Fix typos and correct variable references in MediaTek backend (neuropilot-captain, May 22, 2024)
a2d505c - Update README.md (neuropilot-captain, May 22, 2024)
04f14b9 - Update README.md (neuropilot-captain, May 22, 2024)
d671511 - Update README.md (neuropilot-captain, May 23, 2024)
dc59f82 - Update README.md (neuropilot-captain, May 23, 2024)
3fb9461 - Update README.md (neuropilot-captain, May 23, 2024)
2524c11 - Merge branch 'pytorch:main' into main (neuropilot-captain, May 23, 2024)
22d0e16 - Merge branch 'pytorch:main' into main (neuropilot-captain, May 29, 2024)
151e6fc - Refactor NeuronBackend constants and registration (neuropilot-captain, May 29, 2024)
216c680 - Fix comment (neuropilot-captain, Jun 4, 2024)
e8e1f52 - Add MTK AoT backend and AoT Flow with llama as example (neuropilot-captain, Jul 5, 2024)
1b074ce - Merge branch 'pytorch:main' into main (neuropilot-captain, Jul 9, 2024)
527971e - Fix 1t model export bug (neuropilot-captain, Jul 14, 2024)
a56aeef - Move AOT code to `backends/mediatek` and add skip ops mechanism (neuropilot-captain, Jul 16, 2024)
7db85b8 - Add calibration flow and update compile options (neuropilot-captain, Jul 22, 2024)
af61e4f - Update llama runner (neuropilot-captain, Jul 23, 2024)
09aea5d - Update llama runner sample run script (neuropilot-captain, Jul 24, 2024)
ec9f7a2 - Add embedding bin dumping for cmdline (neuropilot-captain, Jul 24, 2024)
5c81d7f - Upload llama3 8B instruct model config and tokenizer config files for… (neuropilot-captain, Jul 25, 2024)
863d698 - Update README.md (neuropilot-captain, Jul 26, 2024)
5ecc00b - Add `op_names_to_skip` argument to NeuropilotPartitioner (neuropilot-captain, Jul 29, 2024)
9c53f4b - Add script to build MediaTek examples. (Aug 5, 2024)
b2d116e - Leverage temp allocator for memory management (Aug 6, 2024)
208bfce - Update README and remove mtk_neuron and mtk_converter from requiremen… (neuropilot-captain, Aug 6, 2024)
6c1a0a9 - Respect the model input/output data types (quantized or float) in Neu… (neuropilot-captain, Aug 6, 2024)
5de3a4a - Update README.md (neuropilot-captain, Aug 7, 2024)
263a92a - Replace delete with destructor in backend destroy (Aug 6, 2024)
5dcdb83 - Fix neuron backend linked library (neuropilot-captain, Aug 8, 2024)
b2c155c - Use temp allocator in neuron backend (neuropilot-captain, Aug 8, 2024)
04fb790 - Add missing module import for `FakeQuantize` class (neuropilot-captain, Aug 9, 2024)
68f72a5 - Fix Python linter error (neuropilot-captain, Aug 9, 2024)
37a56bf - Update mask_builder.h (neuropilot-captain, Aug 9, 2024)
58a75f8 - Fix lint errors of newlines (Aug 9, 2024)
9f97b3a - fix tokenizer lint error (neuropilot-captain, Aug 12, 2024)
2bc3dee - Apply lintrunner patches (neuropilot-captain, Aug 12, 2024)
1f9b610 - fix argument mismatch (neuropilot-captain, Aug 12, 2024)
8079d71 - Merge branch 'pytorch:main' into main (neuropilot-captain, Aug 13, 2024)
7786412 - Merge branch 'pytorch:main' into main (neuropilot-captain, Aug 14, 2024)
6 changes: 6 additions & 0 deletions CMakeLists.txt
@@ -179,6 +179,8 @@ option(EXECUTORCH_BUILD_GTESTS "Build googletest based test binaries" OFF)

option(EXECUTORCH_BUILD_MPS "Build the MPS backend" OFF)

option(EXECUTORCH_BUILD_NEURON "Build the backends/mediatek directory" OFF)

option(EXECUTORCH_BUILD_PYBIND "Build the Python Bindings" OFF)

option(EXECUTORCH_BUILD_QNN "Build the Qualcomm backend" OFF)
@@ -624,6 +626,10 @@ if(EXECUTORCH_BUILD_EXTENSION_MODULE)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/module)
endif()

if(EXECUTORCH_BUILD_NEURON)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/mediatek)
endif()

if(EXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/runner_util)
endif()
1 change: 1 addition & 0 deletions LICENSE
@@ -6,6 +6,7 @@ Copyright (c) Meta Platforms, Inc. and affiliates.
Copyright 2023 Arm Limited and/or its affiliates.
Copyright (c) Qualcomm Innovation Center, Inc.
Copyright (c) 2023 Apple Inc.
Copyright (c) 2024 MediaTek Inc.

Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:
50 changes: 50 additions & 0 deletions backends/mediatek/CMakeLists.txt
@@ -0,0 +1,50 @@
#[[
/*
* Copyright (c) 2024 MediaTek Inc.
*
* Licensed under the BSD License (the "License"); you may not use this file
* except in compliance with the License. See the license file in the root
* directory of this source tree for more details.
*/
]]

# Let include paths be referenced as "executorch/..."
set(_common_include_directories ${CMAKE_CURRENT_SOURCE_DIR}/../../..)
set(NEURON_BUFFER_ALLOCATOR_LIB "" CACHE PATH "Path to Neuron Buffer Allocator library")
message(STATUS "Looking for neuron_buffer_allocator in ${NEURON_BUFFER_ALLOCATOR_LIB}")

include_directories(
BEFORE
${_common_include_directories}
)

# shortcut include directory for neuron headers
include_directories(
BEFORE
${CMAKE_CURRENT_SOURCE_DIR}/runtime/include
)

# targets
add_library(neuron_backend SHARED)
target_link_libraries(neuron_backend
PRIVATE
executorch_no_prim_ops
android
log
${NEURON_BUFFER_ALLOCATOR_LIB}
)
target_sources(neuron_backend
INTERFACE
${CMAKE_CURRENT_LIST_DIR}/runtime/include/NeuronBackend.h
${CMAKE_CURRENT_LIST_DIR}/runtime/include/NeuronBufferAllocator.h
${CMAKE_CURRENT_LIST_DIR}/runtime/include/NeuronExecutor.h
${CMAKE_CURRENT_LIST_DIR}/runtime/include/NeuronLog.h
${CMAKE_CURRENT_LIST_DIR}/runtime/include/api/APUWareUtilsLib.h
${CMAKE_CURRENT_LIST_DIR}/runtime/include/api/NeuronAdapterShim.h
PRIVATE
${CMAKE_CURRENT_LIST_DIR}/runtime/NeuronBackend.cpp
${CMAKE_CURRENT_LIST_DIR}/runtime/NeuronExecutor.cpp
)
target_link_options_shared_lib(neuron_backend)

install(TARGETS neuron_backend DESTINATION lib)
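
Note: to pull this backend into the top-level build, the new EXECUTORCH_BUILD_NEURON option is turned on at configure time (for example, -DEXECUTORCH_BUILD_NEURON=ON), and NEURON_BUFFER_ALLOCATOR_LIB should point at the Neuron buffer allocator shared library shipped with the MediaTek NeuroPilot SDK. Both names come from the diffs above; the exact library filename depends on the SDK installation.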
5 changes: 5 additions & 0 deletions backends/mediatek/__init__.py
@@ -0,0 +1,5 @@
from .partitioner import NeuropilotPartitioner
from .preprocess import NeuropilotBackend
from .quantizer import NeuropilotQuantizer, Precision

__all__ = ["NeuropilotBackend", "NeuropilotPartitioner", "NeuropilotQuantizer", "Precision"]
101 changes: 101 additions & 0 deletions backends/mediatek/partitioner.py
@@ -0,0 +1,101 @@
# Copyright (c) 2024 MediaTek Inc.
#
# Licensed under the BSD License (the "License"); you may not use this file
# except in compliance with the License. See the license file in the root
# directory of this source tree for more details.

from typing import Callable, final, List, Optional, Tuple

import torch
from executorch.backends.mediatek.preprocess import NeuropilotBackend
from executorch.exir.backend.backend_details import CompileSpec
from executorch.exir.backend.partitioner import (
DelegationSpec,
Partitioner,
PartitionResult,
)
from executorch.exir.backend.utils import tag_constant_data

from mtk_converter.python.converters.pytorch import importer_v2
from torch.export.exported_program import ExportedProgram
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupportBase


class NeuropilotOperatorsSupport(OperatorSupportBase):

def __init__(
self,
op_types_to_skip: Optional[set] = None,
op_names_to_skip: Optional[set] = None,
) -> None:
if op_types_to_skip is None:
op_types_to_skip = set()
if op_names_to_skip is None:
op_names_to_skip = set()

self._op_types_to_skip = op_types_to_skip
self._op_names_to_skip = op_names_to_skip

def is_node_supported(self, _, node: torch.fx.Node) -> bool:
# Handle 'call_function' only, because 'placeholder' and 'output' nodes cannot be tagged.
# Ref: https://github.com/pytorch/executorch/pull/1398
if node.op != "call_function":
return False

op_type = node.target.__name__
if op_type in self._op_types_to_skip or node.name in self._op_names_to_skip:
print(
f"[Neuropilot Backend] The {op_type} operator with name '{node.name}' is skipped."
)
return False

return importer_v2.is_fx_node_supported(node)


@final
class NeuropilotPartitioner(Partitioner):

def __init__(
self,
compile_spec: List[CompileSpec],
op_types_to_skip: Optional[set] = None,
op_names_to_skip: Optional[set] = None,
) -> None:
self.delegation_spec = DelegationSpec(NeuropilotBackend.__name__, compile_spec)
self._op_types_to_skip = op_types_to_skip
self._op_names_to_skip = op_names_to_skip

def ops_to_not_decompose(
self,
ep: ExportedProgram,
) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
ops_not_decompose = [
torch.ops.aten.pixel_shuffle.default,
torch.ops.aten.upsample_bilinear2d.default,
torch.ops.aten.upsample_bilinear2d.vec,
torch.ops.aten.upsample_nearest2d.default,
torch.ops.aten.upsample_nearest2d.vec,
]
return (ops_not_decompose, None)

def partition(self, exported_program: ExportedProgram) -> PartitionResult:
capability_partitioner = CapabilityBasedPartitioner(
exported_program.graph_module,
NeuropilotOperatorsSupport(self._op_types_to_skip, self._op_names_to_skip),
allows_single_node_partition=True,
)
partition_list = capability_partitioner.propose_partitions()

partition_tags = {}
for partition in partition_list:
for node in partition.nodes:
tag = f"tag{partition.id}"
node.meta["delegation_tag"] = tag
partition_tags[tag] = self.delegation_spec

tag_constant_data(exported_program)

return PartitionResult(
tagged_exported_program=exported_program, partition_tags=partition_tags
)
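
For orientation, here is a minimal sketch of how this partitioner slots into the usual ExecuTorch lowering flow. It assumes the MediaTek NeuroPilot SDK (mtk_converter) is installed; the empty compile-spec list and the skipped node name below are placeholders, not values taken from this PR.

import torch
from executorch.backends.mediatek import NeuropilotPartitioner
from executorch.exir import to_edge


class AddRelu(torch.nn.Module):
    def forward(self, x, y):
        return torch.nn.functional.relu(x + y)


example_inputs = (torch.randn(1, 8), torch.randn(1, 8))
exported = torch.export.export(AddRelu(), example_inputs)

# Lower to the Edge dialect, then delegate every subgraph the Neuropilot
# backend supports, skipping any node whose name is listed explicitly.
edge = to_edge(exported)
edge = edge.to_backend(
    NeuropilotPartitioner(compile_spec=[], op_names_to_skip={"aten_add_tensor"})
)
executorch_program = edge.to_executorch()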
Empty file.
@@ -0,0 +1,79 @@
# Copyright (c) 2024 MediaTek Inc.
#
# Licensed under the BSD License (the "License"); you may not use this file
# except in compliance with the License. See the license file in the root
# directory of this source tree for more details.

import torch

from executorch.exir.pass_base import ExportPass, PassResult
from torch._decomp import get_decompositions
from torch.fx import Graph
from torch.fx.experimental.proxy_tensor import make_fx


def _get_input_node_names(graph: Graph):
input_names = []
for node in graph.nodes:
if node.op == "placeholder":
input_names.append(node.name)
return input_names


class DecomposeScaledDotProductAttention(ExportPass):
"""Decompose the single SDPA operator."""

def call(self, graph_module: torch.fx.GraphModule):
graph = graph_module.graph
for node in graph.nodes:
if node.target != torch.ops.aten.scaled_dot_product_attention.default:
continue

decom_mappings = get_decompositions(
[torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default]
)
input_tensors = (arg.meta["val"] for arg in node.args)
decomposed_module = make_fx(node.target, decom_mappings, "fake", True)(
*input_tensors
)
decomposed_input_names = _get_input_node_names(decomposed_module.graph)
with graph.inserting_before(node):
name_to_input_tensor_map = {}
for idx, arg in enumerate(node.args):
name_to_input_tensor_map[decomposed_input_names[idx]] = arg

decomposed_node_to_subgraph_node = {}
for decomposed_node in decomposed_module.graph.nodes:
if decomposed_node.op == "placeholder":
decomposed_node_to_subgraph_node[decomposed_node] = (
name_to_input_tensor_map[decomposed_node.name]
)

# Copy node from decompose graph module
for decomposed_node in decomposed_module.graph.nodes:
if decomposed_node.op == "placeholder":
continue
if decomposed_node.op == "output":
for user in node.users.copy():
new_node = decomposed_node_to_subgraph_node[
decomposed_node.args[0]
]
user.replace_input_with(node, new_node)
continue

subgraph_node = graph.node_copy(
decomposed_node,
arg_transform=lambda x, d=decomposed_node_to_subgraph_node: d[
x
],
)
subgraph_node.meta["source_fn_stack"] = [
(subgraph_node, subgraph_node.target)
]
decomposed_node_to_subgraph_node[decomposed_node] = subgraph_node

graph.erase_node(node)

graph.eliminate_dead_code()
graph_module.recompile()
return PassResult(graph_module, True)
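
As a usage sketch, an ExportPass like this one is callable and returns a PassResult carrying the rewritten graph. The import path below is hypothetical, since this excerpt does not show the pass's filename.

import torch

# NOTE: hypothetical import path; the diff above does not name the module file.
from executorch.backends.mediatek.passes import DecomposeScaledDotProductAttention


def decompose_sdpa(graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # Calling the pass instance runs `call()` above and wraps the rewritten
    # graph in a PassResult.
    result = DecomposeScaledDotProductAttention()(graph_module)
    return result.graph_module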
73 changes: 73 additions & 0 deletions backends/mediatek/preprocess.py
@@ -0,0 +1,73 @@
# Copyright (c) 2024 MediaTek Inc.
#
# Licensed under the BSD License (the "License"); you may not use this file
# except in compliance with the License. See the license file in the root
# directory of this source tree for more details.

import contextlib
import struct

from typing import final, List

import mtk_converter
import mtk_neuron
import torch
from executorch.exir.backend.backend_details import (
BackendDetails,
ExportedProgram,
PreprocessResult,
)
from executorch.exir.backend.compile_spec_schema import CompileSpec

SKIP_COMPILE_SPEC_KEYS = {"ImportForever"}


@final
class NeuropilotBackend(BackendDetails):

@classmethod
def preprocess(
cls, edge_program: ExportedProgram, module_compile_spec: List[CompileSpec]
) -> PreprocessResult:

name_to_node_mappings = {node.name: node for node in edge_program.graph.nodes}
input_names = edge_program.graph_signature.user_inputs
output_names = edge_program.graph_signature.user_outputs
fp_input_indices = [
idx
for idx, name in enumerate(input_names)
if name_to_node_mappings[name].meta["val"].dtype == torch.float32
]
fp_output_indices = [
idx
for idx, name in enumerate(output_names)
if name_to_node_mappings[name].meta["val"].dtype == torch.float32
]

# These default compile options are only for the MT6989 SoC
compile_options = ["--arch=mdla5.1,edpa1.0", "--relax-fp32", "--opt=3"]
for spec in module_compile_spec:
if spec.key in SKIP_COMPILE_SPEC_KEYS:
continue
if spec.value == b"":
compile_options.append(f"--{spec.key}")
else:
value = spec.value.decode("utf-8")
compile_options.append(f"--{spec.key}={value}")

converter = mtk_converter.PyTorchV2Converter.from_exported_program(edge_program)
converter.quantize = True
converter.input_quantization_bitwidths = None
converter.allow_missing_quantization_ranges = True
converter.prepend_input_quantize_ops = True
converter.prepend_input_quantize_ops_indices = fp_input_indices
converter.append_output_dequantize_ops = True
converter.append_output_dequantize_ops_indices = fp_output_indices
with contextlib.redirect_stdout(None):
mlir_str = converter.convert_to_mlir()
model_bytes = mtk_neuron.compile(mlir_str, " ".join(compile_options))

num_inputs = len(input_names)
num_outputs = len(output_names)
header = struct.pack("<BIII", 1, num_inputs, num_outputs, len(model_bytes))
return PreprocessResult(processed_bytes=bytes(header + model_bytes))
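
The struct.pack("<BIII", ...) call above fixes the payload layout: one version byte, then three little-endian uint32 fields (input count, output count, compiled model size), followed by the compiled model itself. Below is a small sketch of the inverse, useful when sanity-checking a lowered payload; the helper name is ours, not part of this PR.

import struct

HEADER_FORMAT = "<BIII"  # version, num_inputs, num_outputs, model byte size
HEADER_SIZE = struct.calcsize(HEADER_FORMAT)  # 13 bytes, no padding


def parse_neuron_payload(payload: bytes):
    """Split a NeuropilotBackend blob back into its header fields and model."""
    version, num_inputs, num_outputs, model_size = struct.unpack(
        HEADER_FORMAT, payload[:HEADER_SIZE]
    )
    model_bytes = payload[HEADER_SIZE : HEADER_SIZE + model_size]
    return version, num_inputs, num_outputs, model_bytes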
4 changes: 4 additions & 0 deletions backends/mediatek/quantizer/__init__.py
@@ -0,0 +1,4 @@
from .qconfig import Precision
from .quantizer import NeuropilotQuantizer

__all__ = ["NeuropilotQuantizer", "Precision"]
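
For completeness, a rough sketch of driving this quantizer through the standard PT2E flow. The setup_precision call and the Precision.A8W8 value are assumptions about the quantizer's configuration surface, since its implementation is not part of this excerpt.

import torch
from executorch.backends.mediatek import NeuropilotQuantizer, Precision
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

model = torch.nn.Linear(8, 8).eval()
example_inputs = (torch.randn(1, 8),)

quantizer = NeuropilotQuantizer()
quantizer.setup_precision(Precision.A8W8)  # assumed configuration hook

# Capture, insert observers, calibrate with representative data, then convert.
graph = capture_pre_autograd_graph(model, example_inputs)
prepared = prepare_pt2e(graph, quantizer)
prepared(*example_inputs)  # calibration pass
quantized = convert_pt2e(prepared, fold_quantize=False)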