Skip to content

Commit b0e49b3

Browse files
guangy10facebook-github-bot
authored andcommitted
Add Inception V3 model to examples (#74)
Summary: Pull Request resolved: #74 Add Inception V3 model to executorch/examples. Info about Inception V3: https://pytorch.org/hub/pytorch_vision_inception_v3/ Differential Revision: D48439838 fbshipit-source-id: c2e5e3d8cbc637a99a50fec3fcba05aa89663d9c
1 parent c0ef2cc commit b0e49b3

File tree

6 files changed

+71
-0
lines changed

6 files changed

+71
-0
lines changed

examples/export/test/test_export.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,3 +87,11 @@ def test_w2l_export_to_executorch(self):
8787
self._assert_eager_lowered_same_result(
8888
eager_model, example_inputs, self.validate_tensor_allclose
8989
)
90+
91+
def test_ic3_export_to_executorch(self):
92+
eager_model, example_inputs = MODEL_NAME_TO_MODEL["ic3"]()
93+
eager_model = eager_model.eval()
94+
95+
self._assert_eager_lowered_same_result(
96+
eager_model, example_inputs, self.validate_tensor_allclose
97+
)

examples/models/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ python_library(
88
],
99
deps = [
1010
"//caffe2:torch",
11+
"//executorch/examples/models/inception_v3:ic3_export",
1112
"//executorch/examples/models/mobilenet_v2:mv2_export",
1213
"//executorch/examples/models/mobilenet_v3:mv3_export",
1314
"//executorch/examples/models/torchvision_vit:vit_export",

examples/models/inception_v3/TARGETS

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
2+
3+
python_library(
4+
name = "ic3_export",
5+
srcs = [
6+
"__init__.py",
7+
"export.py",
8+
],
9+
base_module = "executorch.examples.models.inception_v3",
10+
deps = [
11+
"//caffe2:torch",
12+
"//pytorch/vision:torchvision",
13+
],
14+
)
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from .export import InceptionV3Model
8+
9+
__all__ = [
10+
InceptionV3Model,
11+
]
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import logging
8+
9+
import torch
10+
from torchvision import models
11+
12+
FORMAT = "[%(filename)s:%(lineno)s] %(message)s"
13+
logging.basicConfig(format=FORMAT)
14+
15+
16+
class InceptionV3Model:
17+
def __init__(self):
18+
pass
19+
20+
@staticmethod
21+
def get_model():
22+
logging.info("loading torchvision inception_v3 model")
23+
inception_v3 = models.inception_v3(weights="IMAGENET1K_V1")
24+
logging.info("loaded torchvision inception_v3 model")
25+
return inception_v3
26+
27+
@staticmethod
28+
def get_example_inputs():
29+
input_shape = (1, 3, 224, 224)
30+
return (torch.randn(input_shape),)

examples/models/models.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,12 @@ def gen_wav2letter_model_and_inputs() -> Tuple[torch.nn.Module, Any]:
102102
return model.get_model(), model.get_example_inputs()
103103

104104

105+
def gen_inception_v3_model_and_inputs() -> Tuple[torch.nn.Module, Any]:
106+
from ..models.inception_v3 import InceptionV3Model
107+
108+
return InceptionV3Model.get_model(), InceptionV3Model.get_example_inputs()
109+
110+
105111
MODEL_NAME_TO_MODEL = {
106112
"mul": lambda: (MulModule(), MulModule.get_example_inputs()),
107113
"linear": lambda: (LinearModule(), LinearModule.get_example_inputs()),
@@ -111,4 +117,5 @@ def gen_wav2letter_model_and_inputs() -> Tuple[torch.nn.Module, Any]:
111117
"mv3": gen_mobilenet_v3_model_inputs,
112118
"vit": gen_torchvision_vit_model_and_inputs,
113119
"w2l": gen_wav2letter_model_and_inputs,
120+
"ic3": gen_inception_v3_model_and_inputs,
114121
}

0 commit comments

Comments
 (0)