-
Notifications
You must be signed in to change notification settings - Fork 1.2k
feat: FastAPI integration for In_Process Mode (2/2) #4808
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 76 commits
2cc906b
b25295a
fb28458
3576ea9
d3b8e9b
68cede1
02e54ef
f39cca6
cc0ca14
18fc3f2
495c7b4
1121f47
b6062a7
1ec209c
ca6c818
cd3dbaa
f52f36c
1843210
d0fe3ac
1b93244
b40f36c
68000e1
499063d
de054b5
511df55
7e17631
6f34c61
de9a36c
adb9531
5a371b9
5ccfadd
563d397
acc3cbc
cc236e3
b16b227
faca933
d7366ab
8eda8b5
8302586
04887c6
ff4f62d
96ce9d0
7939237
4e6bd26
282c1ab
87f9de9
1962132
f3bf3c3
d928254
50db803
8d4f06c
21f99b5
2e14cd3
00b19e7
182535a
a4c5e4f
8c2f919
1783bae
b414b2b
5a2328c
62e89ec
ebc465e
a356dad
e2ae197
442686f
8dd3468
31e22ff
890ccd4
8237b7b
5241022
6960d5f
0821e14
7fb02cf
e89940e
5570442
442704f
fea86b9
84f1808
a0f43b3
1e50b3a
009a90e
ae79df1
ebfc3e2
c83e0ce
6d023b5
a83664b
88ee686
69304c2
7b38f76
b62d097
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,5 @@ | ||
accelerate>=0.24.1,<=0.27.0 | ||
sagemaker_schema_inference_artifacts>=0.0.5 | ||
uvicorn>=0.30.1 | ||
fastapi>=0.111.0 | ||
nest-asyncio |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
"""FastAPI requests""" | ||
|
||
from __future__ import absolute_import | ||
|
||
import asyncio | ||
import logging | ||
import threading | ||
from typing import Optional | ||
|
||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
try: | ||
import uvicorn | ||
except ImportError: | ||
logger.error("Unable to import uvicorn, check if uvicorn is installed.") | ||
|
||
|
||
try: | ||
from transformers import pipeline | ||
except ImportError: | ||
logger.error( | ||
"Unable to import transformers, check if transformers is installed." | ||
) | ||
|
||
|
||
try: | ||
from fastapi import FastAPI, Request, APIRouter | ||
except ImportError: | ||
logger.error("Unable to import fastapi, check if fastapi is installed.") | ||
|
||
|
||
class InProcessServer:
    """In-process FastAPI/uvicorn server exposing a Hugging Face
    ``transformers`` pipeline over HTTP on localhost.

    The uvicorn server is driven by its own asyncio event loop on a daemon
    thread, so starting it never blocks the caller's thread (or the caller's
    event loop, if any).
    """

    def __init__(
        self,
        model_id: Optional[str] = None,
        task: Optional[str] = None,
    ):
        """Initialize the server and eagerly load the model pipeline.

        Args:
            model_id (Optional[str]): Hugging Face model id (or local path)
                passed to ``transformers.pipeline``.
            task (Optional[str]): Pipeline task name; when ``None``,
                ``transformers`` infers the task from the model.
        """
        self._thread = None
        self._loop = None
        self._stop_event = asyncio.Event()
        self._router = APIRouter()
        self._model_id = model_id
        self._task = task
        self.server = None
        self.port = None
        self.host = None
        # TODO: Pick up device automatically.
        self._generator = pipeline(task, model=model_id, device="cpu")

        @self._router.post("/generate")
        async def generate_text(prompt: Request):
            """Generate text for the JSON payload POSTed to /generate."""
            str_prompt = await prompt.json()
            # Accept both {"inputs": "..."} payloads and a bare prompt body.
            str_prompt = str_prompt["inputs"] if "inputs" in str_prompt else str_prompt

            generated_text = self._generator(
                str_prompt, max_length=30, num_return_sequences=1, truncation=True
            )
            return generated_text

        self._create_server()

    def _create_server(self):
        """Build the FastAPI app and the uvicorn server configuration."""
        _app = FastAPI()
        _app.include_router(self._router)

        # NOTE: ``reload=True`` is not supported here: uvicorn ignores reload
        # for app *instances* (it requires an import string) and the reloader
        # cannot run off the main thread, so it is deliberately left disabled.
        config = uvicorn.Config(
            _app,
            host="127.0.0.1",
            # TODO: pick a free port dynamically instead of hard-coding one.
            port=9007,
            log_level="info",
            loop="asyncio",
            use_colors=True,
        )

        self.server = uvicorn.Server(config)
        self.host = config.host
        self.port = config.port

    def start_server(self):
        """Starts the uvicorn server on a daemon thread (idempotent)."""
        if not (self._thread and self._thread.is_alive()):
            logger.info("Waiting for a connection...")
            self._thread = threading.Thread(target=self._start_run_async_in_thread, daemon=True)
            self._thread.start()

    def stop_server(self):
        """Destroys the uvicorn server.

        Signals the server loop to exit via ``Server.should_exit`` — the
        supported programmatic shutdown hook — then waits briefly for the
        serving thread to finish.
        """
        if self.server is not None and self._thread and self._thread.is_alive():
            self.server.should_exit = True
            self._thread.join(timeout=10)

    def _start_run_async_in_thread(self):
        """Run the serve coroutine on a fresh event loop owned by this thread."""
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        self._loop = loop
        loop.run_until_complete(self._serve())

    async def _serve(self):
        """Coroutine wrapper so ``run_until_complete`` can drive the server."""
        await self.server.serve()
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
"""Module that defines the InProcessMode class""" | ||
|
||
from __future__ import absolute_import | ||
|
||
from pathlib import Path | ||
import logging | ||
from typing import Dict, Type | ||
|
@@ -11,7 +12,7 @@ | |
from sagemaker.serve.spec.inference_spec import InferenceSpec | ||
from sagemaker.serve.builder.schema_builder import SchemaBuilder | ||
from sagemaker.serve.utils.types import ModelServer | ||
from sagemaker.serve.utils.exceptions import LocalDeepPingException | ||
from sagemaker.serve.utils.exceptions import InProcessDeepPingException | ||
from sagemaker.serve.model_server.multi_model_server.server import InProcessMultiModelServer | ||
from sagemaker.session import Session | ||
|
||
|
@@ -46,7 +47,7 @@ def __init__( | |
self.session = session | ||
self.schema_builder = schema_builder | ||
self.model_server = model_server | ||
self._ping_container = None | ||
self._ping_local_server = None | ||
|
||
def load(self, model_path: str = None): | ||
"""Loads model path, checks that path exists""" | ||
|
@@ -69,21 +70,30 @@ def create_server( | |
logger.info("Waiting for model server %s to start up...", self.model_server) | ||
|
||
if self.model_server == ModelServer.MMS: | ||
self._ping_container = self._multi_model_server_deep_ping | ||
self._ping_local_server = self._multi_model_server_deep_ping | ||
self._start_serving() | ||
|
||
time_limit = datetime.now() + timedelta(seconds=5) | ||
while self._ping_container is not None: | ||
final_pull = datetime.now() > time_limit | ||
# allow some time for server to be ready. | ||
time.sleep(1) | ||
|
||
count = 1 | ||
time_limit = datetime.now() + timedelta(seconds=20) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Will fix, thank you |
||
healthy = True | ||
while True: | ||
final_pull = datetime.now() > time_limit | ||
if final_pull: | ||
break | ||
|
||
time.sleep(10) | ||
|
||
healthy, response = self._ping_container(predictor) | ||
healthy, response = self._ping_local_server(predictor) | ||
count += 1 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: this was for debugging. we can remove it. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Removed |
||
if healthy: | ||
logger.debug("Ping health check has passed. Returned %s", str(response)) | ||
break | ||
|
||
time.sleep(1) | ||
|
||
if not healthy: | ||
raise LocalDeepPingException(_PING_HEALTH_CHECK_FAIL_MSG) | ||
raise InProcessDeepPingException(_PING_HEALTH_CHECK_FAIL_MSG) | ||
|
||
def destroy_server(self): | ||
self._stop_serving() |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,12 +2,15 @@ | |
|
||
from __future__ import absolute_import | ||
|
||
import json | ||
|
||
import requests | ||
import logging | ||
import platform | ||
from pathlib import Path | ||
|
||
from sagemaker import Session, fw_utils | ||
from sagemaker.serve.utils.exceptions import LocalModelInvocationException | ||
from sagemaker.serve.utils.exceptions import InProcessDeepPingException | ||
from sagemaker.base_predictor import PredictorBase | ||
from sagemaker.s3_utils import determine_bucket_and_prefix, parse_s3_url, s3_path_join | ||
from sagemaker.s3 import S3Uploader | ||
|
@@ -25,16 +28,50 @@ class InProcessMultiModelServer: | |
|
||
def _start_serving(self):
    """Initializes the start of the server.

    Creates an ``InProcessServer`` for the model resolved from
    ``inference_spec`` (when present) and starts it on a background thread.
    """
    # Imported lazily to avoid pulling FastAPI/uvicorn in at module import.
    from sagemaker.serve.app import InProcessServer

    model_id = self.inference_spec.get_model() if hasattr(self, "inference_spec") else None
    self.server = InProcessServer(model_id=model_id)
    self.server.start_server()
|
||
def _stop_serving(self):
    """Stops the server by delegating to the in-process server's shutdown."""
    self.server.stop_server()
|
||
def _invoke_multi_model_server_serving(self, request: bytes, content_type: str, accept: str):
    """Invoke the in-process server's ``/generate`` endpoint via HTTP POST.

    Args:
        request (bytes): Serialized request payload forwarded as the POST body.
        content_type (str): Value for the ``Content-Type`` header.
        accept (str): Value for the ``Accept`` header.

    Returns:
        The JSON-decoded response body when the server returns bytes,
        otherwise the raw response content.

    Raises:
        Exception: When the request fails; a connection-refused failure is
            reported with a distinct message.
    """
    try:
        response = requests.post(
            f"http://{self.server.host}:{self.server.port}/generate",
            data=request,
            headers={"Content-Type": content_type, "Accept": accept},
            timeout=600,
        )
        response.raise_for_status()
        if isinstance(response.content, bytes):
            return json.loads(response.content.decode('utf-8'))
        return response.content
    except Exception as e:
        # BUGFIX: the original condition was inverted (`if not "Connection
        # refused" in str(e)`), attaching the connection-refused message to
        # every *other* error. Report connection-refused only when it is one.
        if "Connection refused" in str(e):
            raise Exception("Unable to send request to the local server: Connection refused.") from e
        raise Exception("Unable to send request to the local server.") from e
|
||
def _multi_model_server_deep_ping(self, predictor: PredictorBase):
    """Sends a deep ping to ensure prediction.

    Issues a real prediction with the schema builder's sample input and
    treats any non-``None`` response as healthy.

    Args:
        predictor (PredictorBase): Predictor used to issue the test request.

    Returns:
        tuple: ``(healthy, response)`` where ``healthy`` is a bool and
        ``response`` is the prediction result (or ``None`` on failure).

    Raises:
        InProcessDeepPingException: When the server rejects the payload with
            a 422 Unprocessable Entity error (schema mismatch).
    """
    healthy = False
    response = None
    try:
        response = predictor.predict(self.schema_builder.sample_input)
        healthy = response is not None
    # pylint: disable=broad-except
    except Exception as e:
        # A 422 means the sample input does not match the server's schema —
        # surface that as a distinct, chained exception; any other failure
        # is treated as "not healthy yet" so the caller can keep polling.
        if "422 Client Error: Unprocessable Entity for url" in str(e):
            raise InProcessDeepPingException(str(e)) from e

    return healthy, response
|
||
|
||
class LocalMultiModelServer: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are these needed? Doesn't look like we import them anywhere?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
They're imported in the app.py file