Skip to content

Commit c3f3ef4

Browse files
kirklandsignfacebook-github-bot
authored andcommitted
Add EDSR model (#201)
Summary: Pull Request resolved: #201 Export, runtime, op coverage OK Reviewed By: guangy10 Differential Revision: D48882276 fbshipit-source-id: de96eca3b0e29987a84913461822b0debba0d925
1 parent 5169d8d commit c3f3ef4

File tree

6 files changed

+45
-0
lines changed

6 files changed

+45
-0
lines changed

.ci/docker/requirements-ci.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ ruamel.yaml==0.17.32
55
sympy==1.12
66
timm==0.6.13
77
tomli==2.0.1
8+
torchsr==1.0.4
89
transformers==4.31.0
910
zstd==1.5.5.1
1011
pytest==7.2.0

examples/models/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ python_library(
1010
"//caffe2:torch",
1111
"//executorch/examples/models:model_base", # @manual
1212
"//executorch/examples/models/deeplab_v3:dl3_model", # @manual
13+
"//executorch/examples/models/edsr:edsr_model", # @manual
1314
"//executorch/examples/models/emformer_rnnt:emformer_rnnt_model", # @manual
1415
"//executorch/examples/models/inception_v3:ic3_model", # @manual
1516
"//executorch/examples/models/inception_v4:ic4_model", # @manual

examples/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
"add": ("toy_model", "AddModule"),
1313
"add_mul": ("toy_model", "AddMulModule"),
1414
"dl3": ("deeplab_v3", "DeepLabV3ResNet50Model"),
15+
"edsr": ("edsr", "EdsrModel"),
1516
"emformer_transcribe": ("emformer_rnnt", "EmformerRnntTranscriberModel"),
1617
"emformer_predict": ("emformer_rnnt", "EmformerRnntPredictorModel"),
1718
"emformer_join": ("emformer_rnnt", "EmformerRnntJoinerModel"),

examples/models/edsr/__init__.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from .model import EdsrModel
8+
9+
__all__ = [
10+
EdsrModel,
11+
]

examples/models/edsr/model.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import logging
8+
9+
import torch
10+
11+
from torchsr.models import edsr_r16f64 # @manual
12+
13+
from ..model_base import EagerModelBase
14+
15+
16+
class EdsrModel(EagerModelBase):
17+
def __init__(self):
18+
pass
19+
20+
def get_eager_model(self) -> torch.nn.Module:
21+
logging.info("Loading edsr model")
22+
m = edsr_r16f64(2, True)
23+
logging.info("Loaded edsr model")
24+
return m
25+
26+
def get_example_inputs(self):
27+
tensor_size = (1, 3, 224, 224)
28+
return (torch.randn(tensor_size),)

install_requirements.sh

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,6 @@ pip install --pre timm==${TIMM_VERSION}
3333

3434
TRANSFORMERS_VERSION=4.31.0
3535
pip install --pre transformers==${TRANSFORMERS_VERSION}
36+
37+
TORCHSR_VERSION=1.0.4
38+
pip install --pre torchsr==${TORCHSR_VERSION}

0 commit comments

Comments
 (0)