Skip to content

Commit 43b7a58

Browse files
authored
FIx traffic splitter example (#1597)
1 parent a1253c2 commit 43b7a58

File tree

4 files changed

+119
-22
lines changed

4 files changed

+119
-22
lines changed

examples/traffic-splitter/cortex.yaml

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,21 @@
11
# WARNING: you are on the master branch; please refer to examples on the branch corresponding to your `cortex version` (e.g. for version 0.22.*, run `git checkout -b 0.22` or switch to the `0.22` branch on GitHub)
22

3-
- name: iris-classifier-onnx
3+
- name: iris-classifier-pytorch
44
kind: RealtimeAPI
55
predictor:
6-
type: onnx
7-
path: onnx_predictor.py
8-
model_path: s3://cortex-examples/onnx/iris-classifier/
6+
type: python
7+
path: pytorch_predictor.py
8+
config:
9+
model: s3://cortex-examples/pytorch/iris-classifier/weights.pth
910
monitoring:
1011
model_type: classification
1112

12-
- name: iris-classifier-tf
13+
- name: iris-classifier-onnx
1314
kind: RealtimeAPI
1415
predictor:
15-
type: tensorflow
16-
path: tensorflow_predictor.py
17-
model_path: s3://cortex-examples/tensorflow/iris-classifier/nn/
16+
type: onnx
17+
path: onnx_predictor.py
18+
model_path: s3://cortex-examples/onnx/iris-classifier/
1819
monitoring:
1920
model_type: classification
2021

@@ -23,5 +24,5 @@
2324
apis:
2425
- name: iris-classifier-onnx
2526
weight: 30
26-
- name: iris-classifier-tf
27+
- name: iris-classifier-pytorch
2728
weight: 70

examples/traffic-splitter/model.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# WARNING: you are on the master branch; please refer to examples on the branch corresponding to your `cortex version` (e.g. for version 0.22.*, run `git checkout -b 0.22` or switch to the `0.22` branch on GitHub)
2+
3+
import torch
4+
import torch.nn as nn
5+
import torch.nn.functional as F
6+
from torch.autograd import Variable
7+
from sklearn.datasets import load_iris
8+
from sklearn.model_selection import train_test_split
9+
from sklearn.metrics import accuracy_score
10+
11+
12+
class IrisNet(nn.Module):
13+
def __init__(self):
14+
super(IrisNet, self).__init__()
15+
self.fc1 = nn.Linear(4, 100)
16+
self.fc2 = nn.Linear(100, 100)
17+
self.fc3 = nn.Linear(100, 3)
18+
self.softmax = nn.Softmax(dim=1)
19+
20+
def forward(self, X):
21+
X = F.relu(self.fc1(X))
22+
X = self.fc2(X)
23+
X = self.fc3(X)
24+
X = self.softmax(X)
25+
return X
26+
27+
28+
if __name__ == "__main__":
29+
iris = load_iris()
30+
X, y = iris.data, iris.target
31+
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.8, random_state=42)
32+
33+
train_X = Variable(torch.Tensor(X_train).float())
34+
test_X = Variable(torch.Tensor(X_test).float())
35+
train_y = Variable(torch.Tensor(y_train).long())
36+
test_y = Variable(torch.Tensor(y_test).long())
37+
38+
model = IrisNet()
39+
40+
criterion = nn.CrossEntropyLoss()
41+
42+
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
43+
44+
for epoch in range(1000):
45+
optimizer.zero_grad()
46+
out = model(train_X)
47+
loss = criterion(out, train_y)
48+
loss.backward()
49+
optimizer.step()
50+
51+
if epoch % 100 == 0:
52+
print("number of epoch {} loss {}".format(epoch, loss))
53+
54+
predict_out = model(test_X)
55+
_, predict_y = torch.max(predict_out, 1)
56+
57+
print("prediction accuracy {}".format(accuracy_score(test_y.data, predict_y.data)))
58+
59+
torch.save(model.state_dict(), "weights.pth")
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# WARNING: you are on the master branch; please refer to examples on the branch corresponding to your `cortex version` (e.g. for version 0.22.*, run `git checkout -b 0.22` or switch to the `0.22` branch on GitHub)
2+
3+
import re
4+
import torch
5+
import os
6+
import boto3
7+
from botocore import UNSIGNED
8+
from botocore.client import Config
9+
from model import IrisNet
10+
11+
labels = ["setosa", "versicolor", "virginica"]
12+
13+
14+
class PythonPredictor:
15+
def __init__(self, config):
16+
# download the model
17+
bucket, key = re.match("s3://(.+?)/(.+)", config["model"]).groups()
18+
19+
if os.environ.get("AWS_ACCESS_KEY_ID"):
20+
s3 = boto3.client("s3") # client will use your credentials if available
21+
else:
22+
s3 = boto3.client("s3", config=Config(signature_version=UNSIGNED)) # anonymous client
23+
24+
s3.download_file(bucket, key, "/tmp/model.pth")
25+
26+
# initialize the model
27+
model = IrisNet()
28+
model.load_state_dict(torch.load("/tmp/model.pth"))
29+
model.eval()
30+
31+
self.model = model
32+
33+
def predict(self, payload):
34+
# Convert the request to a tensor and pass it into the model
35+
input_tensor = torch.FloatTensor(
36+
[
37+
[
38+
payload["sepal_length"],
39+
payload["sepal_width"],
40+
payload["petal_length"],
41+
payload["petal_width"],
42+
]
43+
]
44+
)
45+
46+
# Run the prediction
47+
output = self.model(input_tensor)
48+
49+
# Translate the model output to the corresponding label string
50+
return labels[torch.argmax(output[0])]

examples/traffic-splitter/tensorflow_predictor.py

Lines changed: 0 additions & 13 deletions
This file was deleted.

0 commit comments

Comments
 (0)