
Commit fea86b9

Merge pull request #1 from makungaj1/patch-1
Update in_process_mode.py
2 parents 6960d5f + 442704f commit fea86b9

3 files changed, +134 -117 lines changed

src/sagemaker/serve/app.py

Lines changed: 74 additions & 48 deletions
@@ -1,75 +1,101 @@
 """FastAPI requests"""
 
 from __future__ import absolute_import
+
+import asyncio
 import logging
+import threading
+from typing import Optional
 
 
 logger = logging.getLogger(__name__)
 
 
 try:
     import uvicorn
-
 except ImportError:
-    logger.error("To enable in_process mode for Transformers install uvicorn from HuggingFace hub")
+    logger.error("Unable to import uvicorn, check if uvicorn is installed.")
 
 
 try:
     from transformers import pipeline
-
-    generator = pipeline("text-generation", model="gpt2")
-
 except ImportError:
     logger.error(
-        "To enable in_process mode for Transformers install transformers from HuggingFace hub"
+        "Unable to import transformers, check if transformers is installed."
     )
 
 
 try:
-    from fastapi import FastAPI, Request
-
-    app = FastAPI(
-        title="Transformers In Process Server",
-        version="1.0",
-        description="A simple server",
-    )
+    from fastapi import FastAPI, Request, APIRouter
+except ImportError:
+    logger.error("Unable to import fastapi, check if fastapi is installed.")
+
+
+class InProcessServer:
+
+    def __init__(
+        self,
+        model_id: Optional[str] = None,
+        task: Optional[str] = None
+    ):
+        self._thread = None
+        self._loop = None
+        self._stop_event = asyncio.Event()
+        self._router = APIRouter()
+        self._model_id = model_id
+        self._task = task
+        self.server = None
+        self.port = None
+        self.host = None
+        # TODO: Pick up device automatically.
+        self._generator = pipeline(task, model=model_id, device="cpu")
+
+        @self._router.post("/generate")
+        async def generate_text(prompt: Request):
+            """Placeholder docstring"""
+            str_prompt = await prompt.json()
+            str_prompt = str_prompt["inputs"] if "inputs" in str_prompt else str_prompt
+
+            generated_text = self._generator(
+                str_prompt, max_length=30, num_return_sequences=1, truncation=True
+            )
+            return generated_text
+
+        self._create_server()
+
+    def _create_server(self):
+        _app = FastAPI()
+        _app.include_router(self._router)
+
+        config = uvicorn.Config(
+            _app,
+            host="127.0.0.1",
+            port=9007,
+            log_level="info",
+            loop="asyncio",
+            reload=True,
+            use_colors=True,
+        )
 
-    @app.get("/")
-    def read_root():
-        """Placeholder docstring"""
-        return {"Hello": "World"}
+        self.server = uvicorn.Server(config)
+        self.host = config.host
+        self.port = config.port
 
-    @app.get("/generate")
-    async def generate_text(prompt: Request):
-        """Placeholder docstring"""
-        str_prompt = await prompt.json()
+    def start_server(self):
+        """Starts the uvicorn server."""
+        if not (self._thread and self._thread.is_alive()):
+            logger.info("Waiting for a connection...")
+            self._thread = threading.Thread(target=self._start_run_async_in_thread, daemon=True)
+            self._thread.start()
 
-        generated_text = generator(
-            str_prompt, max_length=30, num_return_sequences=5, truncation=True
-        )
-        return generated_text[0]["generated_text"]
+    def stop_server(self):
+        """Destroys the uvicorn server."""
+        # TODO: Implement me.
 
-    @app.post("/post")
-    def post(payload: dict):
-        """Placeholder docstring"""
-        return payload
+    def _start_run_async_in_thread(self):
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+        loop.run_until_complete(self._serve())
 
-except ImportError:
-    logger.error("To enable in_process mode for Transformers install fastapi from HuggingFace hub")
-
-
-async def main():
-    """Running server locally with uvicorn"""
-    config = uvicorn.Config(
-        "sagemaker.serve.app:app",
-        host="127.0.0.1",
-        port=9007,
-        log_level="info",
-        loop="asyncio",
-        reload=True,
-        workers=3,
-        use_colors=True,
-    )
-    server = uvicorn.Server(config)
-    logger.info("Waiting for a connection...")
-    await server.serve()
+    async def _serve(self):
+        await self.server.serve()
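For a quick sanity check of the new server, here is a minimal sketch (not part of the commit): it assumes fastapi, uvicorn, transformers, and requests are installed, and the model id, task, and prompt are illustrative placeholders.

    # Illustrative sketch only, not from this commit: start the new
    # InProcessServer and post a prompt to its /generate route.
    import time

    import requests

    from sagemaker.serve.app import InProcessServer

    server = InProcessServer(model_id="gpt2", task="text-generation")  # placeholder model
    server.start_server()  # uvicorn serves on a background daemon thread
    time.sleep(2)  # crude wait; the SDK itself polls readiness with a deep ping

    response = requests.post(
        f"http://{server.host}:{server.port}/generate",
        json={"inputs": "Hello, I'm a language model"},
        timeout=60,
    )
    print(response.json())  # one generated sequence, per num_return_sequences=1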

src/sagemaker/serve/mode/in_process_mode.py

Lines changed: 19 additions & 10 deletions
@@ -1,6 +1,7 @@
 """Module that defines the InProcessMode class"""
 
 from __future__ import absolute_import
+
 from pathlib import Path
 import logging
 from typing import Dict, Type
@@ -11,7 +12,7 @@
 from sagemaker.serve.spec.inference_spec import InferenceSpec
 from sagemaker.serve.builder.schema_builder import SchemaBuilder
 from sagemaker.serve.utils.types import ModelServer
-from sagemaker.serve.utils.exceptions import LocalDeepPingException
+from sagemaker.serve.utils.exceptions import InProcessDeepPingException
 from sagemaker.serve.model_server.multi_model_server.server import InProcessMultiModelServer
 from sagemaker.session import Session
 
@@ -46,7 +47,7 @@ def __init__(
         self.session = session
         self.schema_builder = schema_builder
         self.model_server = model_server
-        self._ping_container = None
+        self._ping_local_server = None
 
     def load(self, model_path: str = None):
         """Loads model path, checks that path exists"""
@@ -69,22 +70,30 @@ def create_server(
         logger.info("Waiting for model server %s to start up...", self.model_server)
 
         if self.model_server == ModelServer.MMS:
-            self._ping_container = self._multi_model_server_deep_ping
+            self._ping_local_server = self._multi_model_server_deep_ping
             self._start_serving()
 
-        time_limit = datetime.now() + timedelta(seconds=5)
-        while self._ping_container is not None:
-            final_pull = datetime.now() > time_limit
+            # allow some time for server to be ready.
+            time.sleep(1)
 
+        count = 1
+        time_limit = datetime.now() + timedelta(seconds=20)
+        healthy = True
+        while True:
+            final_pull = datetime.now() > time_limit
             if final_pull:
                 break
 
-            time.sleep(10)
-
-            healthy, response = self._ping_container(predictor)
+            healthy, response = self._ping_local_server(predictor)
+            count += 1
             if healthy:
                 logger.debug("Ping health check has passed. Returned %s", str(response))
                 break
 
+            time.sleep(1)
+
         if not healthy:
-            raise LocalDeepPingException(_PING_HEALTH_CHECK_FAIL_MSG)
+            raise InProcessDeepPingException(_PING_HEALTH_CHECK_FAIL_MSG)
+
+    def destroy_server(self):
+        self._stop_serving()
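The rewritten health check changes the timing: the old loop slept a fixed 10 seconds before the first ping and its 5-second deadline expired after a single attempt, while the new loop polls roughly once per second for up to 20 seconds. A standalone sketch of that pattern follows; the names wait_until_healthy and ping are hypothetical, not part of the commit.

    # Standalone sketch of the deadline-based polling used above; the names
    # wait_until_healthy and ping are hypothetical, not part of the commit.
    import time
    from datetime import datetime, timedelta

    def wait_until_healthy(ping, deadline_seconds=20, interval_seconds=1):
        """Poll `ping` until it reports healthy or the deadline passes."""
        time_limit = datetime.now() + timedelta(seconds=deadline_seconds)
        healthy, response = False, None
        while datetime.now() <= time_limit:
            healthy, response = ping()
            if healthy:
                break
            time.sleep(interval_seconds)
        return healthy, response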

src/sagemaker/serve/model_server/multi_model_server/server.py

Lines changed: 41 additions & 59 deletions
@@ -2,14 +2,15 @@
 
 from __future__ import absolute_import
 
-import asyncio
+import json
+
 import requests
 import logging
 import platform
-import time
 from pathlib import Path
+
 from sagemaker import Session, fw_utils
-from sagemaker.serve.utils.exceptions import LocalModelInvocationException
+from sagemaker.serve.utils.exceptions import InProcessDeepPingException
 from sagemaker.base_predictor import PredictorBase
 from sagemaker.s3_utils import determine_bucket_and_prefix, parse_s3_url, s3_path_join
 from sagemaker.s3 import S3Uploader
@@ -25,71 +26,52 @@
 class InProcessMultiModelServer:
     """In Process Mode Multi Model server instance"""
 
-    def __init__(self):
-        from sagemaker.serve.app import main
-
-        self._main = main
-
     def _start_serving(self):
         """Initializes the start of the server"""
-        background_tasks = set()
-        task = asyncio.create_task(self._main())
-        background_tasks.add(task)
-        task.add_done_callback(background_tasks.discard)
+        from sagemaker.serve.app import InProcessServer
+        if hasattr(self, "inference_spec"):
+            model_id = self.inference_spec.get_model()
+        else:
+            model_id = None
+        self.server = InProcessServer(model_id=model_id)
 
-        time.sleep(10)
+        self.server.start_server()
 
-    def _invoke_multi_model_server_serving(self, request: object, content_type: str, accept: str):
+    def _stop_serving(self):
+        """Stops the server"""
+        self.server.stop_server()
+
+    def _invoke_multi_model_server_serving(self, request: bytes, content_type: str, accept: str):
         """Placeholder docstring"""
-        background_tasks = set()
-        task = asyncio.create_task(self.generate_connect())
-        background_tasks.add(task)
-        task.add_done_callback(background_tasks.discard)
+        try:
+            response = requests.post(
+                f"http://{self.server.host}:{self.server.port}/generate",
+                data=request,
+                headers={"Content-Type": content_type, "Accept": accept},
+                timeout=600,
+            )
+            response.raise_for_status()
+            if isinstance(response.content, bytes):
+                return json.loads(response.content.decode("utf-8"))
+            return response.content
+        except Exception as e:
+            if "Connection refused" in str(e):
+                raise Exception("Unable to send request to the local server: Connection refused.") from e
+            raise Exception("Unable to send request to the local server.") from e
 
     def _multi_model_server_deep_ping(self, predictor: PredictorBase):
         """Sends a deep ping to ensure prediction"""
-        background_tasks = set()
-        task = asyncio.create_task(self.tcp_connect())
-        background_tasks.add(task)
-        task.add_done_callback(background_tasks.discard)
+        healthy = False
         response = None
-        return True, response
-
-    async def generate_connect(self):
-        """Writes the lines in bytes for server"""
-        reader, writer = await asyncio.open_connection("127.0.0.1", 9007)
-        a = (
-            b"GET /generate HTTP/1.1\r\nHost: 127.0.0.1:9007\r\nUser-Agent: "
-            b"python-requests/2.31.0\r\nAccept-Encoding: gzip, deflate, br\r\nAccept: */*\r\nConnection: ",
-            "keep-alive\r\nContent-Length: 33\r\nContent-Type: application/json\r\n\r\n",
-        )
-        b = b'"\\"Hello, I\'m a language model\\""'
-        list = [a, b]
-        writer.writelines(list)
-        logger.debug(writer.get_extra_info("peername"))
-        logger.debug(writer.transport)
-
-        data = await reader.read()
-        logger.info("Response from server")
-        logger.info(data)
-        writer.close()
-        await writer.wait_closed()
-
-    async def tcp_connect(self):
-        """Writes the lines in bytes for server"""
-        reader, writer = await asyncio.open_connection("127.0.0.1", 9007)
-        writer.write(
-            b"GET / HTTP/1.1\r\nHost: 127.0.0.1:9007\r\nUser-Agent: python-requests/2.32.3\r\nAccept-Encoding: gzip, "
-            "deflate, br\r\nAccept: */*\r\nConnection: keep-alive\r\n\r\n",
-        )
-        logger.debug(writer.get_extra_info("peername"))
-        logger.debug(writer.transport)
-
-        data = await reader.read()
-        logger.info("Response from server")
-        logger.info(data)
-        writer.close()
-        await writer.wait_closed()
+        try:
+            response = predictor.predict(self.schema_builder.sample_input)
+            healthy = response is not None
+        # pylint: disable=broad-except
+        except Exception as e:
+            if "422 Client Error: Unprocessable Entity for url" in str(e):
+                raise InProcessDeepPingException(str(e))
+
+        return healthy, response
 
 
 class LocalMultiModelServer:
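The deep ping now goes through the real request path: it calls predictor.predict with the schema builder's sample input, counts any non-None response as healthy, and surfaces a 422 from the server as InProcessDeepPingException instead of retrying. A small sketch of that contract, using a hypothetical FakePredictor that is not part of the commit:

    # Hypothetical predictor, for illustration only; mirrors the deep-ping
    # contract: healthy means predict() returned a non-None response.
    class FakePredictor:
        def predict(self, sample_input):
            return [{"generated_text": "Hello, I'm a language model that"}]

    def deep_ping(predictor, sample_input):
        healthy, response = False, None
        try:
            response = predictor.predict(sample_input)
            healthy = response is not None
        except Exception as e:  # the SDK surfaces 422s as InProcessDeepPingException
            if "422 Client Error" in str(e):
                raise
        return healthy, response

    print(deep_ping(FakePredictor(), {"inputs": "ping"}))  # (True, [{...}])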
