Skip to content

Commit 1f903c4

Browse files
kirklandsignfacebook-github-bot
authored andcommitted
Enable ResNet-18 (#107)
Summary: Pull Request resolved: #107 Reviewed By: guangy10 Differential Revision: D48591055 fbshipit-source-id: f815b8b31d19f2b26714d4e0484c0ade7eb61dac
1 parent 2d3685f commit 1f903c4

File tree

6 files changed

+38
-5
lines changed

6 files changed

+38
-5
lines changed

examples/export/test/test_export.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,14 @@ def test_ic3_export_to_executorch(self):
9696
eager_model, example_inputs, self.validate_tensor_allclose
9797
)
9898

99+
def test_resnet18_export_to_executorch(self):
100+
eager_model, example_inputs = MODEL_NAME_TO_MODEL["resnet18"]()
101+
eager_model = eager_model.eval()
102+
103+
self._assert_eager_lowered_same_result(
104+
eager_model, example_inputs, self.validate_tensor_allclose
105+
)
106+
99107
def test_resnet50_export_to_executorch(self):
100108
eager_model, example_inputs = MODEL_NAME_TO_MODEL["resnet50"]()
101109
eager_model = eager_model.eval()

examples/models/TARGETS

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ python_library(
1212
"//executorch/examples/models/inception_v4:ic4_export",
1313
"//executorch/examples/models/mobilenet_v2:mv2_export",
1414
"//executorch/examples/models/mobilenet_v3:mv3_export",
15-
"//executorch/examples/models/resnet50:resnet50_export",
15+
"//executorch/examples/models/resnet:resnet_export",
1616
"//executorch/examples/models/torchvision_vit:vit_export",
1717
"//executorch/examples/models/wav2letter:w2l_export",
1818
"//executorch/exir/backend:compile_spec_schema",

examples/models/models.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,14 @@ def gen_inception_v4_model_and_inputs() -> Tuple[torch.nn.Module, Any]:
114114
return InceptionV4Model.get_model(), InceptionV4Model.get_example_inputs()
115115

116116

117+
def gen_resnet18_model_and_inputs() -> Tuple[torch.nn.Module, Any]:
118+
from ..models.resnet import ResNet18Model
119+
120+
return ResNet18Model.get_model(), ResNet18Model.get_example_inputs()
121+
122+
117123
def gen_resnet50_model_and_inputs() -> Tuple[torch.nn.Module, Any]:
118-
from ..models.resnet50 import ResNet50Model
124+
from ..models.resnet import ResNet50Model
119125

120126
return ResNet50Model.get_model(), ResNet50Model.get_example_inputs()
121127

@@ -131,5 +137,6 @@ def gen_resnet50_model_and_inputs() -> Tuple[torch.nn.Module, Any]:
131137
"w2l": gen_wav2letter_model_and_inputs,
132138
"ic3": gen_inception_v3_model_and_inputs,
133139
"ic4": gen_inception_v4_model_and_inputs,
140+
"resnet18": gen_resnet18_model_and_inputs,
134141
"resnet50": gen_resnet50_model_and_inputs,
135142
}

examples/models/resnet50/TARGETS renamed to examples/models/resnet/TARGETS

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
22

33
python_library(
4-
name = "resnet50_export",
4+
name = "resnet_export",
55
srcs = [
66
"__init__.py",
77
"export.py",
88
],
9-
base_module = "executorch.examples.models.resnet50",
9+
base_module = "executorch.examples.models.resnet",
1010
deps = [
1111
"//caffe2:torch",
1212
"//pytorch/vision:torchvision",

examples/models/resnet50/__init__.py renamed to examples/models/resnet/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
from .export import ResNet50Model
7+
from .export import ResNet18Model, ResNet50Model
88

99
__all__ = [
10+
ResNet18Model,
1011
ResNet50Model,
1112
]

examples/models/resnet50/export.py renamed to examples/models/resnet/export.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,23 @@
1313
logging.basicConfig(format=FORMAT)
1414

1515

16+
class ResNet18Model:
17+
def __init__(self):
18+
pass
19+
20+
@staticmethod
21+
def get_model():
22+
logging.info("loading torchvision resnet18 model")
23+
resnet18 = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
24+
logging.info("loaded torchvision resnet18 model")
25+
return resnet18
26+
27+
@staticmethod
28+
def get_example_inputs():
29+
input_shape = (1, 3, 224, 224)
30+
return (torch.randn(input_shape),)
31+
32+
1633
class ResNet50Model:
1734
def __init__(self):
1835
pass

0 commit comments

Comments
 (0)