Skip to content

Commit 3263d29

Browse files
committed
Code Review feedback
Organize imports Code Review feedback
1 parent 47b0ed9 commit 3263d29

File tree

2 files changed

+22
-24
lines changed

2 files changed

+22
-24
lines changed

sagemaker_neo_compilation_jobs/pytorch_torchvision/code/resnet18.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,20 @@
1-
import os
1+
import io
2+
import json
23
import logging
4+
import os
5+
import pickle
6+
7+
import numpy as np
38
import torch
9+
import torchvision.transforms as transforms
10+
from PIL import Image # Training container doesn't have this package
411

512
logger = logging.getLogger(__name__)
613
logger.setLevel(logging.DEBUG)
714

15+
816
def transform_fn(model, payload, request_content_type,
917
response_content_type):
10-
from PIL import Image # Training container doesn't have this package
11-
import logging
12-
import numpy as np
13-
import io
14-
import json
15-
import torchvision.transforms as transforms
16-
1718

1819
logger.info('Invoking user-defined transform function')
1920

@@ -34,7 +35,7 @@ def transform_fn(model, payload, request_content_type,
3435
])
3536
normalized = preprocess(decoded)
3637
batchified = normalized.unsqueeze(0)
37-
38+
3839
# predict
3940
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
4041
batchified = batchified.to(device)
@@ -51,9 +52,7 @@ def transform_fn(model, payload, request_content_type,
5152
return response_body, content_type
5253

5354

54-
5555
def model_fn(model_dir):
56-
import pickle
5756

5857
logger.info('model_fn')
5958
with torch.neo.config(model_dir=model_dir, neo_runtime=True):

sagemaker_neo_compilation_jobs/pytorch_vgg19_bn/code/vgg19_bn.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,24 @@
1-
# ------------------------------------------------------------ #
2-
# Neo host methods #
3-
# ------------------------------------------------------------ #
4-
5-
import os
1+
import io
2+
import json
63
import logging
4+
import os
5+
import pickle
6+
7+
import numpy as np
78
import torch
9+
import torchvision.transforms as transforms
10+
from PIL import Image # Training container doesn't have this package
811

912
logger = logging.getLogger(__name__)
1013
logger.setLevel(logging.DEBUG)
1114

1215

16+
# ------------------------------------------------------------ #
17+
# Neo host methods #
18+
# ------------------------------------------------------------ #
19+
1320
def transform_fn(model, payload, request_content_type,
1421
response_content_type):
15-
from PIL import Image # Training container doesn't have this package
16-
import logging
17-
import numpy as np
18-
import io
19-
import json
20-
import torchvision.transforms as transforms
21-
2222

2323
logger.info('Invoking user-defined transform function')
2424

@@ -62,7 +62,6 @@ def transform_fn(model, payload, request_content_type,
6262

6363

6464
def model_fn(model_dir):
65-
import pickle
6665

6766
logger.info('model_fn')
6867
with torch.neo.config(model_dir=model_dir, neo_runtime=True):

0 commit comments

Comments
 (0)