Commit c44f877

Feat: TorchServe SageMaker MME GPU example (#4116)
* init torchserve-mme-gpu
* add test data
* add mme notebook
* Update torchserve_multi_model_endpoint.ipynb
* add dockerfile
* fmt notebook
* fmt
* chmod 755 for build_and_push.sh
* fix typo
* model config
* fix grammar
1 parent 324525d commit c44f877

18 files changed: +1672 -0 lines

inference/torchserve/mme-gpu/torchserve_multi_model_endpoint.ipynb

Lines changed: 870 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
segment-anything-py==1.0
opencv-python-headless==4.7.0.68
transformers==4.28.1
ftfy
diffusers
xformers
tqdm
#easydict==1.9.0
#scikit-image==0.17.2
#scikit-learn==0.24.2
easydict
scikit-image
tensorflow
joblib
matplotlib
albumentations==0.5.2
hydra-core==1.1.0
pytorch-lightning
tabulate
kornia==0.5.0
webdataset
omegaconf==2.1.2
Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
ARG BASE_IMAGE

FROM $BASE_IMAGE

# Install the additional libraries needed by the model handlers.
RUN pip install segment-anything-py==1.0
RUN pip install opencv-python-headless==4.7.0.68
RUN pip install matplotlib==3.6.3
RUN pip install diffusers
RUN pip install tqdm
RUN pip install easydict
RUN pip install scikit-image
RUN pip install xformers
RUN pip install tensorflow
RUN pip install joblib
RUN pip install albumentations==0.5.2
RUN pip install hydra-core==1.1.0
RUN pip install pytorch-lightning
RUN pip install tabulate
RUN pip install kornia==0.5.0
RUN pip install webdataset
RUN pip install omegaconf==2.1.2
RUN pip install transformers==4.28.1
RUN pip install accelerate
RUN pip install ftfy
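The Dockerfile extends a base image supplied at build time; build_and_push.sh below passes it in with --build-arg BASE_IMAGE and handles the ECR login, build, tag, and push steps.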
Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
#!/usr/bin/env bash

# This script shows how to build the Docker image and push it to ECR to be ready for use
# by SageMaker.

# The arguments are the repository name, version tag, base image, region, and account.
# The repository name is used as the image name on the local machine and combined with
# the account and region to form the repository URI for ECR.
reponame=$1
versiontag=$2
baseimage=$3
regionname=$4
account=$5

if [ "$reponame" == "" ] || [ "$versiontag" == "" ] || [ "$baseimage" == "" ] || [ "$regionname" == "" ] || [ "$account" == "" ]
then
    echo "Usage: $0 <repo-name> <version-tag> <base-image> <region> <account>"
    exit 1
fi

fullname="${account}.dkr.ecr.${regionname}.amazonaws.com/${reponame}:${versiontag}"

# If the repository doesn't exist in ECR, create it.
aws ecr describe-repositories --repository-names "${reponame}" > /dev/null 2>&1

if [ $? -ne 0 ]
then
    aws ecr create-repository --repository-name "${reponame}" > /dev/null
fi

# Log in to the registry that hosts the base image so it can be pulled.
aws ecr get-login-password --region "${regionname}" | docker login --username AWS --password-stdin "${baseimage}"

# Log in to the account's own registry so the built image can be pushed.
aws ecr get-login-password --region "${regionname}" | docker login --username AWS --password-stdin "${account}.dkr.ecr.${regionname}.amazonaws.com"

# Build the Docker image locally with the image name and then push it to ECR
# with the full name.
docker build -t "${reponame}" . --build-arg BASE_IMAGE="${baseimage}"
docker tag "${reponame}" "${fullname}"

docker push "${fullname}"
echo "${fullname}"
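Example invocation (all values illustrative, not taken from this commit): ./build_and_push.sh torchserve-mme-demo v1 763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-inference:2.0.0-gpu-py310 us-east-1 111122223333. On success the script echoes the fully qualified image URI, which can then be supplied as the inference image when creating the SageMaker model.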
Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
import base64
import io
import httpx
from PIL import Image

def encode_image(img):
    # Convert the image to base64-encoded JPEG bytes.
    with io.BytesIO() as output:
        img.save(output, format="JPEG")
        img_bytes = output.getvalue()

    return base64.b64encode(img_bytes).decode('utf8')

img_file = 'sample1.png'
with Image.open(img_file) as f:
    img_bytes = encode_image(f)

mask_file = 'sample1_mask.jpg'
with Image.open(mask_file) as f:
    mask_bytes = encode_image(f)

payload = {
    "image": img_bytes,
    "mask_image": mask_bytes,
}

url = "http://127.0.0.1:8080/predictions/lama"
response = httpx.post(url, json=payload, timeout=None)
encoded_masks_string = response.json()['generated_image']
base64_bytes_masks = base64.b64decode(encoded_masks_string)
with Image.open(io.BytesIO(base64_bytes_masks)) as f:
    generated_image_rgb = f.convert("RGB")
    generated_image_rgb.show()
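This test client assumes TorchServe is already serving the LaMa archive locally under the model name "lama" (hence the /predictions/lama path); it round-trips a base64-encoded image and mask and displays the inpainted result.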
Lines changed: 160 additions & 0 deletions
@@ -0,0 +1,160 @@
import os
import base64
import json
import cv2
import numpy as np
import torch
import yaml
from PIL import Image
from omegaconf import OmegaConf

from abc import ABC
from io import BytesIO
from ts.context import Context
from ts.torch_handler.base_handler import BaseHandler

from saicinpainting.evaluation.utils import move_to_device
from saicinpainting.training.trainers import load_checkpoint
from saicinpainting.evaluation.data import pad_tensor_to_modulo

# Limit intra-op threading in the numerical libraries; the model already runs
# several TorchServe worker processes.
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['OPENBLAS_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'
os.environ['VECLIB_MAXIMUM_THREADS'] = '1'
os.environ['NUMEXPR_NUM_THREADS'] = '1'

class LamaHandler(BaseHandler, ABC):

    def __init__(self):
        super().__init__()
        self.initialized = False

    def initialize(self, ctx: Context):
        properties = ctx.system_properties
        model_dir = properties.get("model_dir")

        # Bind this worker to the GPU that TorchServe assigned it, if any.
        self.device = torch.device(
            "cuda:" + str(properties.get("gpu_id"))
            if torch.cuda.is_available() and properties.get("gpu_id") is not None
            else "cpu"
        )

        predict_config = OmegaConf.load(f'{model_dir}/configs/prediction/default.yaml')
        predict_config.model.path = f'{model_dir}/big-lama'
        with open(f'{predict_config.model.path}/config.yaml', 'r') as f:
            train_config = OmegaConf.create(yaml.safe_load(f))

        train_config.training_model.predict_only = True
        train_config.visualizer.kind = 'noop'

        checkpoint_path = os.path.join(
            predict_config.model.path, 'models',
            predict_config.model.checkpoint
        )
        self.model = load_checkpoint(
            train_config,
            checkpoint_path,
            strict=False,
            map_location='cpu')
        self.model.freeze()
        self.model.to(self.device)

        self.initialized = True

    def preprocess(self, data):
        # Decode each request's base64-encoded image and mask into numpy arrays.
        requests = []
        for row in data:
            request = row.get("data") or row.get("body")

            if isinstance(request, (bytearray, bytes)):
                request = json.loads(request.decode('utf-8'))

            if isinstance(request, dict) and \
                    "image" in request and \
                    "mask_image" in request:
                img = request["image"]
                if isinstance(img, str):
                    img = base64.b64decode(img)

                with Image.open(BytesIO(img)) as f:
                    img_rgb = f.convert("RGB")
                    request["image"] = np.array(img_rgb)

                mask_img = request["mask_image"]
                if isinstance(mask_img, str):
                    mask_img = base64.b64decode(mask_img)

                with Image.open(BytesIO(mask_img)) as f:
                    mask_img_gray = f.convert("L")
                    request["mask_image"] = np.array(mask_img_gray)

                requests.append(request)
            else:
                raise RuntimeError("Dict request must include image and mask_image")

        return requests

    def inference(self, data):
        responses = []
        for request in data:
            # LaMa expects spatial dimensions that are a multiple of 8, so pad
            # both tensors and remember the original size for cropping later.
            mod = 8
            img = torch.from_numpy(request["image"]).float().div(255.)
            mask = torch.from_numpy(request["mask_image"]).float()

            batch = {}
            batch['image'] = img.permute(2, 0, 1).unsqueeze(0)
            batch['mask'] = mask[None, None]
            unpad_to_size = [batch['image'].shape[2], batch['image'].shape[3]]
            batch['image'] = pad_tensor_to_modulo(batch['image'], mod)
            batch['mask'] = pad_tensor_to_modulo(batch['mask'], mod)
            batch = move_to_device(batch, self.device)
            batch['mask'] = (batch['mask'] > 0) * 1

            batch = self.model(batch)
            cur_res = batch['inpainted'][0].permute(1, 2, 0)
            cur_res = cur_res.detach().cpu().numpy()

            # Crop the padding off to restore the original resolution.
            if unpad_to_size is not None:
                orig_height, orig_width = unpad_to_size
                cur_res = cur_res[:orig_height, :orig_width]

            cur_res = np.clip(cur_res * 255, 0, 255).astype('uint8')

            output_img = Image.fromarray(cur_res)
            responses.append({"generated_image": self.encode_image(output_img)})

        return responses

    def handle(self, data, context):
        requests = self.preprocess(data)
        responses = self.inference(requests)

        return responses

    def dilate_mask(self, mask, dilate_factor=15):
        # Grow the mask with a square kernel so the inpainted region fully
        # covers the object boundary.
        mask = mask.astype(np.uint8)
        mask = cv2.dilate(
            mask,
            np.ones((dilate_factor, dilate_factor), np.uint8),
            iterations=1
        )
        return mask

    def encode_image(self, img):
        # Convert the image to base64-encoded JPEG bytes.
        with BytesIO() as output:
            img.save(output, format="JPEG")
            img_bytes = output.getvalue()

        return base64.b64encode(img_bytes).decode("utf-8")
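Note that handle() bypasses the default BaseHandler pipeline and calls preprocess() and inference() directly, with no separate postprocess() step; the returned list must contain one response per request so TorchServe can map results back to callers. For a quick local check (paths illustrative, assuming the archive was built as lama.mar under model_store/), start TorchServe with torchserve --start --ncs --model-store model_store --models lama=lama.mar and then run the test client above.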
Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
minWorkers: 4
maxWorkers: 4
batchSize: 1
maxBatchDelay: 200
responseTimeout: 300
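These are per-model TorchServe settings: minWorkers/maxWorkers pin four worker processes to the model, batchSize 1 disables server-side batching, maxBatchDelay is the batch-wait time in milliseconds, and responseTimeout is in seconds. A file like this is typically passed to torch-model-archiver via its --config-file option when building the .mar archive, so the settings travel with the model.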
Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
opencv-python
pyyaml
tqdm
numpy
easydict==1.9.0
scikit-image==0.17.2
scikit-learn==0.24.2
tensorflow
joblib
matplotlib
pandas
albumentations==0.5.2
hydra-core==1.1.0
#pytorch-lightning==1.2.9
pytorch-lightning
tabulate
kornia==0.5.0
webdataset
packaging
omegaconf==2.1.2
Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
import base64
import json
import io
import httpx
from PIL import Image

def encode_image(img):
    # Convert the image to base64-encoded JPEG bytes.
    with io.BytesIO() as output:
        img.save(output, format="JPEG")
        img_bytes = output.getvalue()

    return base64.b64encode(img_bytes).decode('utf8')

img_file = 'sample1.png'
with Image.open(img_file) as f:
    img_bytes = encode_image(f)

# Point prompt and mask-dilation settings passed through to the SAM handler.
gen_args = json.dumps(dict(point_coords=[750, 500], point_labels=1, dilate_kernel_size=15))

payload = {
    "image": img_bytes,
    "gen_args": gen_args
}

url = "http://127.0.0.1:8080/predictions/sam"
response = httpx.post(url, json=payload, timeout=None)
encoded_masks_string = response.json()['generated_image']
base64_bytes_masks = base64.b64decode(encoded_masks_string)
with Image.open(io.BytesIO(base64_bytes_masks)) as f:
    generated_image_rgb = f.convert("RGB")
    generated_image_rgb.show()
