Skip to content

Commit ddd1507

Browse files
author
Nikhil Kulkarni
committed
Merge branch 'master' into fix_xgboost_churn_neo
2 parents cb38428 + 28b727f commit ddd1507

File tree

9 files changed

+661
-348
lines changed

9 files changed

+661
-348
lines changed

sagemaker-experiments/mnist-handwritten-digits-classification-experiment/mnist-handwritten-digits-classification-experiment.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@
6565
"!{sys.executable} -m pip install torch==1.1.0\n",
6666
"!{sys.executable} -m pip install torchvision==0.3.0\n",
6767
"!{sys.executable} -m pip install pillow==6.2.2 ",
68-
"!{sys.executable} -m pip install --upgrade sagemaker",
68+
"!{sys.executable} -m pip install --upgrade sagemaker"
6969
]
7070
},
7171
{
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
import io
2+
import json
3+
import logging
4+
import os
5+
import pickle
6+
7+
import numpy as np
8+
import torch
9+
import torchvision.transforms as transforms
10+
from PIL import Image # Training container doesn't have this package
11+
12+
logger = logging.getLogger(__name__)
13+
logger.setLevel(logging.DEBUG)
14+
15+
16+
def transform_fn(model, payload, request_content_type,
17+
response_content_type):
18+
19+
logger.info('Invoking user-defined transform function')
20+
21+
if request_content_type != 'application/octet-stream':
22+
raise RuntimeError(
23+
'Content type must be application/octet-stream. Provided: {0}'.format(request_content_type))
24+
25+
# preprocess
26+
decoded = Image.open(io.BytesIO(payload))
27+
preprocess = transforms.Compose([
28+
transforms.Resize(256),
29+
transforms.CenterCrop(224),
30+
transforms.ToTensor(),
31+
transforms.Normalize(
32+
mean=[
33+
0.485, 0.456, 0.406], std=[
34+
0.229, 0.224, 0.225]),
35+
])
36+
normalized = preprocess(decoded)
37+
batchified = normalized.unsqueeze(0)
38+
39+
# predict
40+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
41+
batchified = batchified.to(device)
42+
result = model.forward(batchified)
43+
44+
# Softmax (assumes batch size 1)
45+
result = np.squeeze(result.cpu().numpy())
46+
result_exp = np.exp(result - np.max(result))
47+
result = result_exp / np.sum(result_exp)
48+
49+
response_body = json.dumps(result.tolist())
50+
content_type = 'application/json'
51+
52+
return response_body, content_type
53+
54+
55+
def model_fn(model_dir):
56+
57+
logger.info('model_fn')
58+
with torch.neo.config(model_dir=model_dir, neo_runtime=True):
59+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
60+
# The compiled model is saved as "compiled.pt"
61+
model = torch.jit.load(os.path.join(model_dir, 'compiled.pt'))
62+
model = model.to(device)
63+
64+
# It is recommended to run warm-up inference during model load
65+
sample_input_path = os.path.join(model_dir, 'sample_input.pkl')
66+
with open(sample_input_path, 'rb') as input_file:
67+
model_input = pickle.load(input_file)
68+
if torch.is_tensor(model_input):
69+
model_input = model_input.to(device)
70+
model(model_input)
71+
elif isinstance(model_input, tuple):
72+
model_input = (inp.to(device)
73+
for inp in model_input if torch.is_tensor(inp))
74+
model(*model_input)
75+
else:
76+
print("Only supports a torch tensor or a tuple of torch tensors")
77+
78+
return model

0 commit comments

Comments
 (0)