
Commit 6b94d87

Support more open-source models
1. Add export scripts for new models
2. Fix build scripts
1 parent 012f571 commit 6b94d87

19 files changed (+638, -16 lines)


backends/mediatek/partitioner.py

Lines changed: 1 addition & 0 deletions
@@ -81,6 +81,7 @@ def ops_to_not_decompose(
             torch.ops.aten.upsample_bilinear2d.vec,
             torch.ops.aten.upsample_nearest2d.default,
             torch.ops.aten.upsample_nearest2d.vec,
+            torch.ops.aten._safe_softmax.default,
         ]
         return (ops_not_decompose, None)
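
Context for the new entry: on recent PyTorch releases, softmax in an exported graph can surface as aten._safe_softmax, so listing it in ops_to_not_decompose keeps the op intact for the MediaTek backend instead of letting export decompose it. A minimal sketch of how the op can show up (the toy module and shapes below are illustrative, not part of this commit):

import torch


class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.softmax(x, dim=-1)


# Export a module containing softmax; on PyTorch versions that lower softmax
# to its "safe" variant, the resulting graph may reference
# torch.ops.aten._safe_softmax.default rather than aten.softmax.
ep = torch.export.export(Toy(), (torch.randn(2, 8),))
print(ep.graph)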

backends/mediatek/scripts/mtk_build.sh

Lines changed: 1 addition & 0 deletions
@@ -33,6 +33,7 @@ rm -rf cmake-android-out && mkdir cmake-android-out && cd cmake-android-out
 cmake -DBUCK2="$BUCK_PATH" \
     -DCMAKE_TOOLCHAIN_FILE="$ANDROID_NDK/build/cmake/android.toolchain.cmake" \
     -DANDROID_ABI=arm64-v8a \
+    -DANDROID_PLATFORM=android-26 \
     -DEXECUTORCH_BUILD_NEURON=ON \
     -DNEURON_BUFFER_ALLOCATOR_LIB="$NEURON_BUFFER_ALLOCATOR_LIB" \
     ..

examples/mediatek/aot_utils/oss_utils/utils.py

Lines changed: 6 additions & 6 deletions
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import os
-from typing import Optional
+from typing import Optional, Dict
 
 import torch
 from executorch import exir

@@ -24,6 +24,8 @@ def build_executorch_binary(
     file_name,
     dataset,
     quant_dtype: Optional[Precision] = None,
+    skip_op_name: Optional[set] = None,
+    skip_op_type: Optional[set] = None
 ):
     if quant_dtype is not None:
         quantizer = NeuropilotQuantizer()

@@ -47,14 +49,12 @@ def build_executorch_binary(
     from executorch.exir.program._program import to_edge_transform_and_lower
 
     edge_compile_config = exir.EdgeCompileConfig(_check_ir_validity=False)
-    # skipped op names are used for deeplabV3 model
     neuro_partitioner = NeuropilotPartitioner(
         [CompileSpec("platform-config", b"mt6989")],
-        op_names_to_skip={
-            "aten_convolution_default_106",
-            "aten_convolution_default_107",
-        },
+        op_types_to_skip=skip_op_type,
+        op_names_to_skip=skip_op_name,
     )
+
     edge_prog = to_edge_transform_and_lower(
         aten_dialect,
         compile_config=edge_compile_config,
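
With the deeplabV3-specific skip list removed from the helper, callers now pass their own skip sets. A hedged usage sketch of the reworked signature (the tiny model, output path, and skipped node name are illustrative; deeplab_v3.py passes aten_convolution_default_106/107, as shown further down):

import torch

from executorch.backends.mediatek import Precision
from aot_utils.oss_utils.utils import build_executorch_binary  # assumes cwd is examples/mediatek


class TinyNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, 3, padding=1)

    def forward(self, x):
        return torch.nn.functional.relu(self.conv(x))


example_input = (torch.randn(1, 3, 32, 32),)

build_executorch_binary(
    TinyNet().eval(),
    example_input,
    "./artifacts/tiny_mtk",             # output artifact path (illustrative)
    [example_input],                    # calibration dataset for quantization
    quant_dtype=Precision.A8W8,
    skip_op_name={"aten_convolution_default_0"},  # hypothetical exported node name
)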
Lines changed: 102 additions & 0 deletions
@@ -0,0 +1,102 @@
# Copyright (c) MediaTek Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import os
import sys

if os.getcwd() not in sys.path:
    sys.path.append(os.getcwd())
import json
import os
import numpy as np
import argparse

import torch
import dcgan_main
from executorch.backends.mediatek import Precision
from aot_utils.oss_utils.utils import (
    build_executorch_binary,
    make_output_dir,
)


class NhwcWrappedModel(torch.nn.Module):
    def __init__(self, is_gen=True):
        super(NhwcWrappedModel, self).__init__()
        if is_gen:
            self.dcgan = dcgan_main.Generator()
        else:
            self.dcgan = dcgan_main.Discriminator()

    def forward(self, input1):
        nchw_input1 = input1.permute(0, 3, 1, 2)
        output = self.dcgan(nchw_input1)
        return output


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "-a",
        "--artifact",
        help="path for storing generated artifacts by this example. "
        "Default ./dcgan",
        default="./dcgan",
        type=str,
    )

    args = parser.parse_args()

    # ensure the working directory exist.
    os.makedirs(args.artifact, exist_ok=True)

    # prepare dummy data
    inputG = torch.randn(1, 1, 1, 100)
    inputD = torch.randn(1, 64, 64, 3)

    # build Generator
    netG_instance = NhwcWrappedModel(True)
    netG_pte_filename = "dcgan_netG_mtk"
    build_executorch_binary(
        netG_instance.eval(),
        (torch.randn(1, 1, 1, 100),),
        f"{args.artifact}/{netG_pte_filename}",
        [(inputG,)],
        quant_dtype=Precision.A8W8,
    )

    # build Discriminator
    netD_instance = NhwcWrappedModel(False)
    netD_pte_filename = "dcgan_netD_mtk"
    build_executorch_binary(
        netD_instance.eval(),
        (torch.randn(1, 64, 64, 3),),
        f"{args.artifact}/{netD_pte_filename}",
        [(inputD,)],
        quant_dtype=Precision.A8W8,
    )

    # save data to inference on device
    input_list_file = f"{args.artifact}/input_list_G.txt"
    with open(input_list_file, "w") as f:
        f.write("inputG_0_0.bin")
        f.flush()
    file_name = f"{args.artifact}/inputG_0_0.bin"
    inputG.detach().numpy().tofile(file_name)
    file_name = f"{args.artifact}/goldenG_0_0.bin"
    goldenG = netG_instance(inputG)
    goldenG.detach().numpy().tofile(file_name)

    input_list_file = f"{args.artifact}/input_list_D.txt"
    with open(input_list_file, "w") as f:
        f.write("inputD_0_0.bin")
        f.flush()
    file_name = f"{args.artifact}/inputD_0_0.bin"
    inputD.detach().numpy().tofile(file_name)
    file_name = f"{args.artifact}/goldenD_0_0.bin"
    goldenD = netD_instance(inputD)
    goldenD.detach().numpy().tofile(file_name)
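
A small layout sanity check for the NHWC wrapper in the new DCGAN export script above: the script prepares channel-last dummy tensors, and the permute inside NhwcWrappedModel restores the NCHW layout the DCGAN networks expect. Pure-torch sketch with the shapes taken from the script:

import torch

nhwc_latent = torch.randn(1, 1, 1, 100)       # generator input as prepared above
print(nhwc_latent.permute(0, 3, 1, 2).shape)  # torch.Size([1, 100, 1, 1])

nhwc_image = torch.randn(1, 64, 64, 3)        # discriminator input as prepared above
print(nhwc_image.permute(0, 3, 1, 2).shape)   # torch.Size([1, 3, 64, 64])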
Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
"""Ref https://github.com/pytorch/examples/blob/main/dcgan/main.py"""

import torch.nn as nn


class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(100, 64 * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(64 * 8),
            nn.ReLU(True),
            # state size. (64*8) x 4 x 4
            nn.ConvTranspose2d(64 * 8, 64 * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 4),
            nn.ReLU(True),
            # state size. (64*4) x 8 x 8
            nn.ConvTranspose2d(64 * 4, 64 * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 2),
            nn.ReLU(True),
            # state size. (64*2) x 16 x 16
            nn.ConvTranspose2d(64 * 2, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            # state size. (64) x 32 x 32
            nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (3) x 64 x 64
        )

    def forward(self, input):
        output = self.main(input)
        return output

# main_netG_input_shape = [1, 100, 1, 1]
# model = Generator()


class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.main = nn.Sequential(
            # input is (3) x 64 x 64
            nn.Conv2d(3, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (64) x 32 x 32
            nn.Conv2d(64, 64 * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (64*2) x 16 x 16
            nn.Conv2d(64 * 2, 64 * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (64*4) x 8 x 8
            nn.Conv2d(64 * 4, 64 * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (64*8) x 4 x 4
            nn.Conv2d(64 * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        output = self.main(input)

        return output.view(-1, 1).squeeze(1)

# main_netD_input_shape = [1, 3, 64, 64]
# model = Discriminator()
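
For reference, a quick shape walk-through of the two networks defined above, matching the state-size comments in the code. It assumes dcgan_main.py is importable from the working directory, which is how the export script arranges its imports:

import torch

import dcgan_main

netG = dcgan_main.Generator().eval()
netD = dcgan_main.Discriminator().eval()

z = torch.randn(1, 100, 1, 1)    # latent vector in NCHW
img = netG(z)                    # -> torch.Size([1, 3, 64, 64])
score = netD(img)                # -> torch.Size([1]); per-sample probability after the Sigmoid
print(img.shape, score.shape)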

examples/mediatek/model_export_scripts/deeplab_v3.py

Lines changed: 10 additions & 3 deletions
@@ -3,16 +3,19 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+import os
+import sys
 
+if os.getcwd() not in sys.path:
+    sys.path.append(os.getcwd())
 import argparse
-import os
 import random
 
 import numpy as np
 
 import torch
 from executorch.backends.mediatek import Precision
-from executorch.examples.mediatek.aot_utils.oss_utils.utils import (
+from aot_utils.oss_utils.utils import (
     build_executorch_binary,
 )
 from executorch.examples.models.deeplab_v3 import DeepLabV3ResNet101Model

@@ -26,7 +29,7 @@ def __init__(self):
     def forward(self, input1):
         nchw_input1 = input1.permute(0, 3, 1, 2)
         nchw_output = self.deeplabv3(nchw_input1)
-        return nchw_output.permute(0, 2, 3, 1)
+        return nchw_output
 
 
 def get_dataset(data_size, dataset_dir, download):

@@ -121,4 +124,8 @@ def get_dataset(data_size, dataset_dir, download):
        f"{args.artifact}/{pte_filename}",
        inputs,
        quant_dtype=Precision.A8W8,
+       skip_op_name = {
+           "aten_convolution_default_106",
+           "aten_convolution_default_107",
+       },
     )

examples/mediatek/model_export_scripts/edsr.py

Lines changed: 4 additions & 1 deletion
@@ -6,12 +6,15 @@
 
 import argparse
 import os
+import sys
 
+if os.getcwd() not in sys.path:
+    sys.path.append(os.getcwd())
 import numpy as np
 
 import torch
 from executorch.backends.mediatek import Precision
-from executorch.examples.mediatek.aot_utils.oss_utils.utils import (
+from aot_utils.oss_utils.utils import (
     build_executorch_binary,
 )
 from executorch.examples.models.edsr import EdsrModel
