feat: FastAPI integration for In_Process Mode (2/2) #4808
Changes from 55 commits
@@ -1,2 +1,5 @@
 accelerate>=0.24.1,<=0.27.0
 sagemaker_schema_inference_artifacts>=0.0.5
+uvicorn>=0.30.1
+fastapi>=0.111.0
+nest-asyncio
@@ -0,0 +1,79 @@
"""FastAPI requests"""

from __future__ import absolute_import

import logging

logger = logging.getLogger(__name__)

try:
    from fastapi import FastAPI, Request

    app = FastAPI(
        title="Transformers In Process Server",
        version="1.0",
        description="A simple server",
    )

    @app.get("/")
    def read_root():
        """Placeholder docstring"""
        return {"Hello": "World"}

    @app.get("/generate")
    async def generate_text(prompt: Request):
        """Placeholder docstring"""
        logger.info("Generating Text....")

Review comment: nit: to keep the user experience clean, let's reduce logging. Let's remove these two logs.
Reply: Thank you, I will remove.

        str_prompt = await prompt.json()

        logger.info(str_prompt)

        generated_text = generator(
            str_prompt, max_length=30, num_return_sequences=5, truncation=True
        )
        return generated_text[0]["generated_text"]

except ImportError:
    logger.error("To enable in_process mode for Transformers install fastapi from HuggingFace hub")


try:
    from transformers import pipeline

    generator = pipeline("text-generation", model="gpt2")

except ImportError:
    logger.error(
        "To enable in_process mode for Transformers install transformers from HuggingFace hub"
    )

try:
    import uvicorn

except ImportError:
    logger.error("To enable in_process mode for Transformers install uvicorn from HuggingFace hub")


@app.post("/post")
def post(payload: dict):
    """Placeholder docstring"""
    return payload


async def main():
    """Running server locally with uvicorn"""
    logger.info("Running")

Review comment: nit: From here the application isn't running yet. Can we get rid of this log?
Reply: Removed

    config = uvicorn.Config(
        "sagemaker.app:app",
        host="127.0.0.1",
        port=9007,
        log_level="info",
        loop="asyncio",
        reload=True,
        workers=3,
        use_colors=True,
    )
    server = uvicorn.Server(config)
    logger.info("Waiting for a connection...")
    await server.serve()
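
For reference, a minimal client sketch against this server (illustrative, not part of the PR), assuming it is running locally on 127.0.0.1:9007 as configured in main(). Note that /generate is registered as a GET route that reads a JSON body via Request.json(); requests will still send a body on GET when given the json= argument:

import requests

BASE = "http://127.0.0.1:9007"  # host/port from the uvicorn.Config in main()

# Root route returns a static payload
print(requests.get(f"{BASE}/").json())  # {"Hello": "World"}

# /generate reads the JSON body as the prompt string
resp = requests.get(f"{BASE}/generate", json="Hello, I'm a language model")
print(resp.json())

# /post simply echoes the payload back
print(requests.post(f"{BASE}/post", json={"ping": "pong"}).json())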
@@ -36,6 +36,7 @@
 from sagemaker.serve.mode.function_pointers import Mode
 from sagemaker.serve.mode.sagemaker_endpoint_mode import SageMakerEndpointMode
 from sagemaker.serve.mode.local_container_mode import LocalContainerMode
+from sagemaker.serve.mode.in_process_mode import InProcessMode
 from sagemaker.serve.detector.pickler import save_pkl, save_xgboost
 from sagemaker.serve.builder.serve_settings import _ServeSettings
 from sagemaker.serve.builder.djl_builder import DJL

@@ -410,7 +411,7 @@ def _prepare_for_mode(
             )
             self.env_vars.update(env_vars_sagemaker)
             return self.s3_upload_path, env_vars_sagemaker
-        if self.mode == Mode.LOCAL_CONTAINER:
+        elif self.mode == Mode.LOCAL_CONTAINER:
             # init the LocalContainerMode object
             self.modes[str(Mode.LOCAL_CONTAINER)] = LocalContainerMode(
                 inference_spec=self.inference_spec,

@@ -422,9 +423,22 @@ def _prepare_for_mode(
             )
             self.modes[str(Mode.LOCAL_CONTAINER)].prepare()
             return None
+        elif self.mode == Mode.IN_PROCESS:

Review comment: Do you need the in process changes here?
Reply: Yes, this PR goes along with #4784, which introduces In_Process mode in a broader scope. This is PR 2/2, which combines In_Process + FastAPI.

+            # init the InProcessMode object
+            self.modes[str(Mode.IN_PROCESS)] = InProcessMode(
+                inference_spec=self.inference_spec,
+                schema_builder=self.schema_builder,
+                session=self.sagemaker_session,
+                model_path=self.model_path,
+                env_vars=self.env_vars,
+                model_server=self.model_server,
+            )
+            self.modes[str(Mode.IN_PROCESS)].prepare()
+            return None

         raise ValueError(
-            "Please specify mode in: %s, %s" % (Mode.LOCAL_CONTAINER, Mode.SAGEMAKER_ENDPOINT)
+            "Please specify mode in: %s, %s, %s"
+            % (Mode.LOCAL_CONTAINER, Mode.SAGEMAKER_ENDPOINT, Mode.IN_PROCESS)
         )

     def _get_client_translators(self):

@@ -606,6 +620,9 @@ def _overwrite_mode_in_deploy(self, overwrite_mode: str):
         elif overwrite_mode == Mode.LOCAL_CONTAINER:
             self.mode = self.pysdk_model.mode = Mode.LOCAL_CONTAINER
             self._prepare_for_mode()
+        elif overwrite_mode == Mode.IN_PROCESS:
+            self.mode = self.pysdk_model.mode = Mode.IN_PROCESS
+            self._prepare_for_mode()
         else:
             raise ValueError("Mode %s is not supported!" % overwrite_mode)

@@ -795,9 +812,10 @@ def _initialize_for_mlflow(self, artifact_path: str) -> None:
         self.dependencies.update({"requirements": mlflow_model_dependency_path})

     # Model Builder is a class to build the model for deployment.
-    # It supports two modes of deployment
+    # It supports three modes of deployment
     # 1/ SageMaker Endpoint
     # 2/ Local launch with container
+    # 3/ In process mode with Transformers server in beta release
     def build(  # pylint: disable=R0911
         self,
         mode: Type[Mode] = None,

@@ -895,8 +913,10 @@ def build(  # pylint: disable=R0911

     def _build_validations(self):
         """Validations needed for model server overrides, or auto-detection or fallback"""
-        if self.mode == Mode.IN_PROCESS:
-            raise ValueError("IN_PROCESS mode is not supported yet!")
+        if self.mode == Mode.IN_PROCESS and self.model_server is not ModelServer.MMS:
+            raise ValueError(
+                "IN_PROCESS mode is only supported for MMS/Transformers server in beta release."
+            )

         if self.inference_spec and self.model:
             raise ValueError("Can only set one of the following: model, inference_spec.")
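
Taken together, these changes let IN_PROCESS be selected like the other two modes. A hedged sketch of what that might look like from user code (illustrative, not from this PR; the constructor arguments follow the existing ModelBuilder/SchemaBuilder API, and "gpt2" is just a placeholder model id):

from sagemaker.serve.builder.model_builder import ModelBuilder
from sagemaker.serve.builder.schema_builder import SchemaBuilder
from sagemaker.serve.mode.function_pointers import Mode
from sagemaker.serve.utils.types import ModelServer

# IN_PROCESS is validated to require the MMS/Transformers server in this
# beta release (see _build_validations above).
model_builder = ModelBuilder(
    model="gpt2",  # placeholder model id
    schema_builder=SchemaBuilder(
        sample_input="Hello, I'm a language model",
        sample_output="Hello, I'm a language model trained on ...",
    ),
    mode=Mode.IN_PROCESS,
    model_server=ModelServer.MMS,
)
model = model_builder.build()
predictor = model.deploy()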
@@ -0,0 +1,90 @@
"""Module that defines the InProcessMode class"""

from __future__ import absolute_import
from pathlib import Path
import logging
from typing import Dict, Type
import time
from datetime import datetime, timedelta

from sagemaker.base_predictor import PredictorBase
from sagemaker.serve.spec.inference_spec import InferenceSpec
from sagemaker.serve.builder.schema_builder import SchemaBuilder
from sagemaker.serve.utils.types import ModelServer
from sagemaker.serve.utils.exceptions import LocalDeepPingException
from sagemaker.serve.model_server.multi_model_server.server import InProcessMultiModelServer
from sagemaker.session import Session

logger = logging.getLogger(__name__)

_PING_HEALTH_CHECK_FAIL_MSG = (
    "Ping health check did not pass. "
    + "Please increase container_timeout_seconds or review your inference code."
)


class InProcessMode(
    InProcessMultiModelServer,
):
    """A class that holds methods to deploy model to a container in process environment"""

    def __init__(
        self,
        model_server: ModelServer,
        inference_spec: Type[InferenceSpec],
        schema_builder: Type[SchemaBuilder],
        session: Session,
        model_path: str = None,
        env_vars: Dict = None,
    ):
        # pylint: disable=bad-super-call
        super().__init__()

        self.inference_spec = inference_spec
        self.model_path = model_path
        self.env_vars = env_vars
        self.session = session
        self.schema_builder = schema_builder
        self.model_server = model_server
        self._ping_container = None

    def load(self, model_path: str = None):
        """Loads model path, checks that path exists"""
        path = Path(model_path if model_path else self.model_path)
        if not path.exists():
            raise ValueError("model_path does not exist")
        if not path.is_dir():
            raise ValueError("model_path is not a valid directory")

        return self.inference_spec.load(str(path))

    def prepare(self):
        """Prepares the server"""

Review comment (on the create_server signature): This needs to be […], else it will break the build.
Reply: Good catch, I will fix.

    def create_server(
        self,
        predictor: PredictorBase,
    ):
        """Creating the server and checking ping health."""
        logger.info("Waiting for model server %s to start up...", self.model_server)

        if self.model_server == ModelServer.MMS:
            self._start_serving()
            self._ping_container = self._multi_model_server_deep_ping

        time_limit = datetime.now() + timedelta(seconds=5)
        while self._ping_container is not None:
            final_pull = datetime.now() > time_limit

            if final_pull:
                break

            time.sleep(10)

            healthy, response = self._ping_container(predictor)
            if healthy:
                logger.debug("Ping health check has passed. Returned %s", str(response))
                break

        if not healthy:
            raise LocalDeepPingException(_PING_HEALTH_CHECK_FAIL_MSG)
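
For context, InProcessMode is constructed by ModelBuilder._prepare_for_mode() (see the diff above), and create_server() is invoked at deploy time. A rough standalone sketch of that lifecycle, with hypothetical stand-ins for the collaborators ModelBuilder normally supplies:

from sagemaker.serve.mode.in_process_mode import InProcessMode
from sagemaker.serve.utils.types import ModelServer
from sagemaker.session import Session

# All arguments below are illustrative stand-ins; in practice ModelBuilder
# passes its own inference_spec, schema_builder, session, and env_vars.
in_process = InProcessMode(
    model_server=ModelServer.MMS,
    inference_spec=None,
    schema_builder=None,
    session=Session(),
    model_path="/tmp/model",
    env_vars={},
)
in_process.prepare()  # currently a placeholder no-op
# create_server() starts the FastAPI app via _start_serving() and then polls
# the deep-ping callable; note the loop sleeps 10 seconds against a 5-second
# deadline, so at most one ping attempt happens before the deadline check.
in_process.create_server(predictor=None)  # predictor is unused by the current ping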
@@ -2,9 +2,11 @@

from __future__ import absolute_import

import asyncio
import requests
import logging
import platform
import time
from pathlib import Path
from sagemaker import Session, fw_utils
from sagemaker.serve.utils.exceptions import LocalModelInvocationException

@@ -13,13 +15,79 @@
from sagemaker.s3 import S3Uploader
from sagemaker.local.utils import get_docker_host
from sagemaker.serve.utils.optimize_utils import _is_s3_uri
from sagemaker.serve.app import main

MODE_DIR_BINDING = "/opt/ml/model/"
_DEFAULT_ENV_VARS = {}

logger = logging.getLogger(__name__)


class InProcessMultiModelServer:
    """In Process Mode Multi Model server instance"""

    def _start_serving(self):
        """Initializes the start of the server"""
        background_tasks = set()
        task = asyncio.create_task(main())
        background_tasks.add(task)
        task.add_done_callback(background_tasks.discard)

        time.sleep(10)

    def _invoke_multi_model_server_serving(self, request: object, content_type: str, accept: str):
        """Placeholder docstring"""
        background_tasks = set()
        task = asyncio.create_task(self.generate_connect())
        background_tasks.add(task)
        task.add_done_callback(background_tasks.discard)

Review comment: Shouldn't this actually ping the server similar to https://github.com/aws/sagemaker-python-sdk/blob/master/src/sagemaker/serve/model_server/tei/server.py#L88?
Reply: Good suggestion, I will tweak the code.

    def _multi_model_server_deep_ping(self, predictor: PredictorBase):
        """Sends a deep ping to ensure prediction"""
        background_tasks = set()
        task = asyncio.create_task(self.tcp_connect())
        background_tasks.add(task)
        task.add_done_callback(background_tasks.discard)
        response = None
        return True, response

    async def generate_connect(self):
        """Writes the lines in bytes for server"""
        reader, writer = await asyncio.open_connection("127.0.0.1", 9007)
        request_headers = (
            b"GET /generate HTTP/1.1\r\nHost: 127.0.0.1:9007\r\nUser-Agent: "
            b"python-requests/2.31.0\r\nAccept-Encoding: gzip, deflate, br\r\nAccept: */*\r\nConnection: "
            b"keep-alive\r\nContent-Length: 33\r\nContent-Type: application/json\r\n\r\n"
        )
        body = b'"\\"Hello, I\'m a language model\\""'
        writer.writelines([request_headers, body])
        logger.debug(writer.get_extra_info("peername"))
        logger.debug(writer.transport)

        data = await reader.read()
        logger.info("Response from server")
        logger.info(data)
        writer.close()
        await writer.wait_closed()

    async def tcp_connect(self):
        """Writes the lines in bytes for server"""
        reader, writer = await asyncio.open_connection("127.0.0.1", 9007)
        writer.write(
            b"GET / HTTP/1.1\r\nHost: 127.0.0.1:9007\r\nUser-Agent: python-requests/2.32.3\r\n"
            b"Accept-Encoding: gzip, deflate, br\r\nAccept: */*\r\nConnection: keep-alive\r\n\r\n"
        )
        logger.debug(writer.get_extra_info("peername"))
        logger.debug(writer.transport)

Review comment (on the logger.debug calls): Are these for debugging? If so, do we need them in master?
Reply: Will remove.

        data = await reader.read()
        logger.info("Response from server")
        logger.info(data)
        writer.close()
        await writer.wait_closed()


class LocalMultiModelServer:
    """Local Multi Model server instance"""
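
Following up on the review suggestion above, a hedged sketch of what an HTTP-based deep ping could look like, loosely modeled on the tei/server.py pattern the reviewer links to. The helper name, endpoint, and payload here are assumptions for illustration, not code from this PR:

import logging

import requests

logger = logging.getLogger(__name__)


def multi_model_server_deep_ping(base_url: str = "http://127.0.0.1:9007"):
    """Hypothetical HTTP deep ping against the /generate route from app.py."""
    try:
        response = requests.get(
            f"{base_url}/generate",
            json="Hello, I'm a language model",  # sample prompt mirroring app.py
            timeout=10,
        )
        response.raise_for_status()
        return True, response.json()
    except requests.exceptions.RequestException as err:
        logger.debug("Deep ping failed: %s", err)
        return False, None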

Review comment (on the new dependencies): Are these needed? Doesn't look like we import them anywhere?
Reply: They're imported in the app.py file.