Skip to content

Commit 73c6cb3

Browse files
kirklandsignfacebook-github-bot
authored andcommitted
Add Inception V4 model to examples (#86)
Summary: Pull Request resolved: #86 The model is from pypi/timm. Reviewed By: guangy10 Differential Revision: D48524577 fbshipit-source-id: 8d8eb302a9ca5cc66ba8056391a1826cbb2b7a61
1 parent 4885608 commit 73c6cb3

File tree

5 files changed

+63
-0
lines changed

5 files changed

+63
-0
lines changed

examples/models/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ python_library(
99
deps = [
1010
"//caffe2:torch",
1111
"//executorch/examples/models/inception_v3:ic3_export",
12+
"//executorch/examples/models/inception_v4:ic4_export",
1213
"//executorch/examples/models/mobilenet_v2:mv2_export",
1314
"//executorch/examples/models/mobilenet_v3:mv3_export",
1415
"//executorch/examples/models/torchvision_vit:vit_export",

examples/models/inception_v4/TARGETS

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
2+
3+
python_library(
4+
name = "ic4_export",
5+
srcs = [
6+
"__init__.py",
7+
"export.py",
8+
],
9+
base_module = "executorch.examples.models.inception_v4",
10+
deps = [
11+
"fbsource//third-party/pypi/timm:timm",
12+
"//caffe2:torch",
13+
],
14+
)
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from .export import InceptionV4Model
8+
9+
__all__ = [
10+
InceptionV4Model,
11+
]
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import logging
8+
9+
import torch
10+
from timm.models import inception_v4
11+
12+
FORMAT = "[%(filename)s:%(lineno)s] %(message)s"
13+
logging.basicConfig(format=FORMAT)
14+
15+
16+
# will refactor this in a separate file.
17+
class InceptionV4Model:
18+
def __init__(self):
19+
pass
20+
21+
@staticmethod
22+
def get_model():
23+
logging.info("loading inception_v4 model")
24+
m = inception_v4(pretrained=True)
25+
logging.info("loaded inception_v4 model")
26+
return m
27+
28+
@staticmethod
29+
def get_example_inputs():
30+
return (torch.randn(3, 299, 299).unsqueeze(0),)

examples/models/models.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,12 @@ def gen_inception_v3_model_and_inputs() -> Tuple[torch.nn.Module, Any]:
108108
return InceptionV3Model.get_model(), InceptionV3Model.get_example_inputs()
109109

110110

111+
def gen_inception_v4_model_and_inputs() -> Tuple[torch.nn.Module, Any]:
112+
from ..models.inception_v4 import InceptionV4Model
113+
114+
return InceptionV4Model.get_model(), InceptionV4Model.get_example_inputs()
115+
116+
111117
MODEL_NAME_TO_MODEL = {
112118
"mul": lambda: (MulModule(), MulModule.get_example_inputs()),
113119
"linear": lambda: (LinearModule(), LinearModule.get_example_inputs()),
@@ -118,4 +124,5 @@ def gen_inception_v3_model_and_inputs() -> Tuple[torch.nn.Module, Any]:
118124
"vit": gen_torchvision_vit_model_and_inputs,
119125
"w2l": gen_wav2letter_model_and_inputs,
120126
"ic3": gen_inception_v3_model_and_inputs,
127+
"ic4": gen_inception_v4_model_and_inputs,
121128
}

0 commit comments

Comments
 (0)