feat: FastAPI integration for In_Process Mode (2/2) #4808


Merged: 90 commits, Aug 22, 2024

Commits (diff below reflects changes from 55 of the 90 commits)
2cc906b  InferenceSpec support for HF (Jun 26, 2024)
b25295a  Merge branch 'aws:master' into hf-inf-spec-support (bryannahm1, Jun 27, 2024)
fb28458  feat: InferenceSpec support for MMS and testing (Jun 27, 2024)
3576ea9  Introduce changes for InProcess Mode (Jun 29, 2024)
d3b8e9b  mb_inprocess updates (Jul 3, 2024)
68cede1  In_Process mode for TGI transformers, edits (Jul 8, 2024)
02e54ef  Remove InfSpec from branch (Jul 8, 2024)
f39cca6  merge from master for inf spec (Jul 12, 2024)
cc0ca14  changes to support in_process (Jul 13, 2024)
18fc3f2  changes to get pre-checks passing (Jul 15, 2024)
495c7b4  pylint fix (Jul 15, 2024)
1121f47  unit test, test mb (Jul 15, 2024)
b6062a7  period missing, added (Jul 15, 2024)
1ec209c  suggestions and test added (Jul 16, 2024)
ca6c818  pre-push fix (Jul 16, 2024)
cd3dbaa  missing an @ (Jul 16, 2024)
f52f36c  fixes to test, added stubbing (Jul 17, 2024)
1843210  removing for fixes (Jul 17, 2024)
d0fe3ac  variable fixes (Jul 17, 2024)
1b93244  init fix (Jul 17, 2024)
b40f36c  tests for in process mode (Jul 18, 2024)
68000e1  prepush fix (Jul 18, 2024)
499063d  FastAPI with In_Process (Jul 18, 2024)
de054b5  minor fixes (Jul 18, 2024)
511df55  putting 6 args (Jul 19, 2024)
7e17631  server change (Jul 23, 2024)
6f34c61  changes (Jul 23, 2024)
de9a36c  format fix (Jul 23, 2024)
adb9531  fastapi fixes (Jul 24, 2024)
5a371b9  port (Jul 24, 2024)
5ccfadd  start (Jul 24, 2024)
563d397  changing port (Jul 28, 2024)
acc3cbc  prepush (Jul 28, 2024)
cc236e3  import and unused fix (Jul 28, 2024)
b16b227  moving files and transformers (Jul 31, 2024)
faca933  fix imports (Aug 1, 2024)
d7366ab  changes (Aug 1, 2024)
8eda8b5  fixing modules (Aug 1, 2024)
8302586  placement (Aug 1, 2024)
04887c6  dep fixes (Aug 2, 2024)
ff4f62d  Merge branch 'master' into fastapi-inprocess (bryannahm1, Aug 6, 2024)
96ce9d0  minor change (Aug 6, 2024)
7939237  Merge branch 'fastapi-inprocess' of https://github.com/bryannahm1/sag… (Aug 6, 2024)
4e6bd26  fastapi predictor fix (Aug 6, 2024)
282c1ab  minor changes (Aug 7, 2024)
87f9de9  import transformers (Aug 7, 2024)
1962132  pylint comment (Aug 7, 2024)
f3bf3c3  delete local_run.sh (Aug 7, 2024)
d928254  format (Aug 7, 2024)
50db803  fix (Aug 8, 2024)
8d4f06c  uvicorn fix (Aug 8, 2024)
21f99b5  fastapi (Aug 8, 2024)
2e14cd3  try and except (Aug 8, 2024)
00b19e7  Merge branch 'master' into fastapi-inprocess (sage-maker, Aug 8, 2024)
182535a  Merge branch 'master' into fastapi-inprocess (sage-maker, Aug 8, 2024)
a4c5e4f  app (Aug 8, 2024)
8c2f919  merge (Aug 8, 2024)
1783bae  Merge branch 'master' into fastapi-inprocess (sage-maker, Aug 8, 2024)
b414b2b  Merge branch 'master' into fastapi-inprocess (sage-maker, Aug 8, 2024)
5a2328c  Merge branch 'master' into fastapi-inprocess (sage-maker, Aug 9, 2024)
62e89ec  deps comment out (Aug 12, 2024)
ebc465e  Merge branch 'master' into fastapi-inprocess (bryannahm1, Aug 12, 2024)
a356dad  app func fix (Aug 12, 2024)
e2ae197  merge (Aug 12, 2024)
442686f  deps (Aug 12, 2024)
8dd3468  fix (Aug 12, 2024)
31e22ff  test fix (Aug 12, 2024)
890ccd4  non object (Aug 12, 2024)
8237b7b  comment out test for in_process (Aug 12, 2024)
5241022  removing unnecessary loggers (Aug 13, 2024)
6960d5f  fixing UT (Aug 13, 2024)
0821e14  clean up loggers (Aug 14, 2024)
7fb02cf  Update in_process_mode.py (makungaj1, Aug 14, 2024)
e89940e  In Process Mode (Aug 15, 2024)
5570442  workers flags not needed when reload is set to true. (Aug 15, 2024)
442704f  Refactore (Aug 15, 2024)
fea86b9  Merge pull request #1 from makungaj1/patch-1 (makungaj1, Aug 15, 2024)
84f1808  wheel (Aug 15, 2024)
a0f43b3  minor changes (Aug 16, 2024)
1e50b3a  merge (Aug 16, 2024)
009a90e  delete whl (Aug 16, 2024)
ae79df1  Merge branch 'master' into fastapi-inprocess (makungaj1, Aug 16, 2024)
ebfc3e2  ut for app.py (Aug 20, 2024)
c83e0ce  unit test fir app (Aug 21, 2024)
6d023b5  merge (Aug 21, 2024)
a83664b  more unit tests (Aug 21, 2024)
88ee686  py310 and higher (Aug 21, 2024)
69304c2  Merge branch 'master' into fastapi-inprocess (benieric, Aug 22, 2024)
7b38f76  delete whl file (Aug 22, 2024)
b62d097  Merge branch 'fastapi-inprocess' of https://github.com/bryannahm1/sag… (Aug 22, 2024)
3 changes: 3 additions & 0 deletions requirements/extras/huggingface_requirements.txt
@@ -1,2 +1,5 @@
accelerate>=0.24.1,<=0.27.0
sagemaker_schema_inference_artifacts>=0.0.5
+uvicorn>=0.30.1
Review comment (Collaborator):
Are these needed? Doesn't look like we import them anywhere?

Reply (bryannahm1, author, Aug 22, 2024):
They're imported in the app.py file

+fastapi>=0.111.0
+nest-asyncio
3 changes: 3 additions & 0 deletions requirements/extras/test_requirements.txt
@@ -40,3 +40,6 @@ schema==0.7.5
tensorflow>=2.1,<=2.16
mlflow>=2.12.2,<2.13
huggingface_hub>=0.23.4
+uvicorn>=0.30.1
+fastapi>=0.111.0
+nest-asyncio
79 changes: 79 additions & 0 deletions src/sagemaker/serve/app.py
@@ -0,0 +1,79 @@
"""FastAPI requests"""

from __future__ import absolute_import

import logging

logger = logging.getLogger(__name__)

try:
    from fastapi import FastAPI, Request

    app = FastAPI(
        title="Transformers In Process Server",
        version="1.0",
        description="A simple server",
    )

    @app.get("/")
    def read_root():
        """Placeholder docstring"""
        return {"Hello": "World"}

    @app.get("/generate")
    async def generate_text(prompt: Request):
        """Placeholder docstring"""
        logger.info("Generating Text....")
Review comment (Contributor):
nit: To keep the user experience clean, let's reduce logging. Remove these two logs in the async def generate_text(prompt: Request) method.

Reply (bryannahm1, author):
Thank you, I will remove.


        str_prompt = await prompt.json()

        logger.info(str_prompt)

        generated_text = generator(
            str_prompt, max_length=30, num_return_sequences=5, truncation=True
        )
        return generated_text[0]["generated_text"]

except ImportError:
    logger.error("To enable in_process mode for Transformers install fastapi from HuggingFace hub")


try:
    from transformers import pipeline

    generator = pipeline("text-generation", model="gpt2")

except ImportError:
    logger.error(
        "To enable in_process mode for Transformers install transformers from HuggingFace hub"
    )

try:
    import uvicorn

except ImportError:
    logger.error("To enable in_process mode for Transformers install uvicorn from HuggingFace hub")


@app.post("/post")
def post(payload: dict):
    """Placeholder docstring"""
    return payload


async def main():
    """Running server locally with uvicorn"""
    logger.info("Running")
Review comment (Contributor):
nit: From here the application isn't running yet. Can we get rid of this log?

Reply (bryannahm1, author):
Removed

    config = uvicorn.Config(
        "sagemaker.app:app",
        host="127.0.0.1",
        port=9007,
        log_level="info",
        loop="asyncio",
        reload=True,
        workers=3,
        use_colors=True,
    )
    server = uvicorn.Server(config)
    logger.info("Waiting for a connection...")
    await server.serve()
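
For reference, here is a minimal client sketch for exercising the routes above once the server is up. It is not part of the PR; the prompt string is illustrative, and it assumes the server is listening on 127.0.0.1:9007 as configured in main().

    import requests

    # /generate reads the raw JSON body as the prompt, so the prompt is sent
    # as a bare JSON string rather than a {"prompt": ...} object.
    response = requests.get(
        "http://127.0.0.1:9007/generate",
        json="Hello, I'm a language model",
        timeout=60,
    )
    print(response.json())  # text produced by the gpt2 pipeline

    # The echo route accepts any JSON object.
    print(requests.post("http://127.0.0.1:9007/post", json={"ping": "pong"}, timeout=10).json())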
30 changes: 25 additions & 5 deletions src/sagemaker/serve/builder/model_builder.py
@@ -36,6 +36,7 @@
from sagemaker.serve.mode.function_pointers import Mode
from sagemaker.serve.mode.sagemaker_endpoint_mode import SageMakerEndpointMode
from sagemaker.serve.mode.local_container_mode import LocalContainerMode
+from sagemaker.serve.mode.in_process_mode import InProcessMode
from sagemaker.serve.detector.pickler import save_pkl, save_xgboost
from sagemaker.serve.builder.serve_settings import _ServeSettings
from sagemaker.serve.builder.djl_builder import DJL
@@ -410,7 +411,7 @@ def _prepare_for_mode(
            )
            self.env_vars.update(env_vars_sagemaker)
            return self.s3_upload_path, env_vars_sagemaker
-       if self.mode == Mode.LOCAL_CONTAINER:
+       elif self.mode == Mode.LOCAL_CONTAINER:
            # init the LocalContainerMode object
            self.modes[str(Mode.LOCAL_CONTAINER)] = LocalContainerMode(
                inference_spec=self.inference_spec,
@@ -422,9 +423,22 @@
            )
            self.modes[str(Mode.LOCAL_CONTAINER)].prepare()
            return None
+       elif self.mode == Mode.IN_PROCESS:
Review comment (Collaborator):
Do you need the in process changes here?

Reply (bryannahm1, author):
Yes, this is a PR that is going along with #4784, which introduces In_Process mode in a broader scope. This is PR 2/2, which combines In_Process + FastAPI.

+           # init the InProcessMode object
+           self.modes[str(Mode.IN_PROCESS)] = InProcessMode(
+               inference_spec=self.inference_spec,
+               schema_builder=self.schema_builder,
+               session=self.sagemaker_session,
+               model_path=self.model_path,
+               env_vars=self.env_vars,
+               model_server=self.model_server,
+           )
+           self.modes[str(Mode.IN_PROCESS)].prepare()
+           return None

        raise ValueError(
-           "Please specify mode in: %s, %s" % (Mode.LOCAL_CONTAINER, Mode.SAGEMAKER_ENDPOINT)
+           "Please specify mode in: %s, %s, %s"
+           % (Mode.LOCAL_CONTAINER, Mode.SAGEMAKER_ENDPOINT, Mode.IN_PROCESS)
        )

    def _get_client_translators(self):
@@ -606,6 +620,9 @@ def _overwrite_mode_in_deploy(self, overwrite_mode: str):
        elif overwrite_mode == Mode.LOCAL_CONTAINER:
            self.mode = self.pysdk_model.mode = Mode.LOCAL_CONTAINER
            self._prepare_for_mode()
+       elif overwrite_mode == Mode.IN_PROCESS:
+           self.mode = self.pysdk_model.mode = Mode.IN_PROCESS
+           self._prepare_for_mode()
        else:
            raise ValueError("Mode %s is not supported!" % overwrite_mode)

@@ -795,9 +812,10 @@ def _initialize_for_mlflow(self, artifact_path: str) -> None:
        self.dependencies.update({"requirements": mlflow_model_dependency_path})

    # Model Builder is a class to build the model for deployment.
-   # It supports two modes of deployment
+   # It supports three modes of deployment
    # 1/ SageMaker Endpoint
    # 2/ Local launch with container
+   # 3/ In process mode with Transformers server in beta release
    def build(  # pylint: disable=R0911
        self,
        mode: Type[Mode] = None,
@@ -895,8 +913,10 @@

    def _build_validations(self):
        """Validations needed for model server overrides, or auto-detection or fallback"""
-       if self.mode == Mode.IN_PROCESS:
-           raise ValueError("IN_PROCESS mode is not supported yet!")
+       if self.mode == Mode.IN_PROCESS and self.model_server is not ModelServer.MMS:
+           raise ValueError(
+               "IN_PROCESS mode is only supported for MMS/Transformers server in beta release."
+           )

        if self.inference_spec and self.model:
            raise ValueError("Can only set one of the following: model, inference_spec.")
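
To make the new dispatch concrete, here is a sketch of how a caller could opt into the new mode. It is not taken from the PR; the model id and the sample input/output passed to SchemaBuilder are illustrative assumptions.

    from sagemaker.serve.builder.model_builder import ModelBuilder
    from sagemaker.serve.builder.schema_builder import SchemaBuilder
    from sagemaker.serve.mode.function_pointers import Mode

    # Hypothetical text-generation model served in process.
    # SchemaBuilder derives (de)serialization from the sample input/output.
    model_builder = ModelBuilder(
        model="gpt2",
        schema_builder=SchemaBuilder("sample prompt", "sample generated text"),
        mode=Mode.IN_PROCESS,
    )
    model = model_builder.build()  # checked by _build_validations() above
    predictor = model.deploy()     # routes through InProcessMode.prepare()/create_server()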
20 changes: 18 additions & 2 deletions src/sagemaker/serve/builder/transformers_builder.py
@@ -36,7 +36,10 @@
)
from sagemaker.serve.detector.pickler import save_pkl
from sagemaker.serve.utils.optimize_utils import _is_optimized
-from sagemaker.serve.utils.predictors import TransformersLocalModePredictor
+from sagemaker.serve.utils.predictors import (
+    TransformersLocalModePredictor,
+    TransformersInProcessModePredictor,
+)
from sagemaker.serve.utils.types import ModelServer
from sagemaker.serve.mode.function_pointers import Mode
from sagemaker.serve.utils.telemetry_logger import _capture_telemetry
@@ -47,6 +50,7 @@

logger = logging.getLogger(__name__)
DEFAULT_TIMEOUT = 1800
+LOCAL_MODES = [Mode.LOCAL_CONTAINER, Mode.IN_PROCESS]


"""Retrieves images for different libraries - Pytorch, TensorFlow from HuggingFace hub
@@ -228,6 +232,18 @@ def _transformers_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[Pr
            )
            return predictor

+       if self.mode == Mode.IN_PROCESS:
+           timeout = kwargs.get("model_data_download_timeout")
+
+           predictor = TransformersInProcessModePredictor(
+               self.modes[str(Mode.IN_PROCESS)], serializer, deserializer
+           )
+
+           self.modes[str(Mode.IN_PROCESS)].create_server(
+               predictor,
+           )
+           return predictor

        self._set_instance(kwargs)

        if "mode" in kwargs:
@@ -293,7 +309,7 @@ def _build_transformers_env(self):

        self.pysdk_model = self._create_transformers_model()

-       if self.mode == Mode.LOCAL_CONTAINER:
+       if self.mode in LOCAL_MODES:
            self._prepare_for_mode()

        return self.pysdk_model
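
Downstream, the in-process predictor returned by deploy() is used like any other SDK predictor. A hypothetical invocation, with an illustrative prompt:

    # predict() follows the standard PredictorBase interface; serialization is
    # handled by the serializer/deserializer wired up in the wrapper above.
    generated = predictor.predict("Hello, I'm a language model")
    print(generated)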
90 changes: 90 additions & 0 deletions src/sagemaker/serve/mode/in_process_mode.py
@@ -0,0 +1,90 @@
"""Module that defines the InProcessMode class"""

from __future__ import absolute_import
from pathlib import Path
import logging
from typing import Dict, Type
import time
from datetime import datetime, timedelta

from sagemaker.base_predictor import PredictorBase
from sagemaker.serve.spec.inference_spec import InferenceSpec
from sagemaker.serve.builder.schema_builder import SchemaBuilder
from sagemaker.serve.utils.types import ModelServer
from sagemaker.serve.utils.exceptions import LocalDeepPingException
from sagemaker.serve.model_server.multi_model_server.server import InProcessMultiModelServer
from sagemaker.session import Session

logger = logging.getLogger(__name__)

_PING_HEALTH_CHECK_FAIL_MSG = (
    "Ping health check did not pass. "
    + "Please increase container_timeout_seconds or review your inference code."
)


class InProcessMode(
    InProcessMultiModelServer,
):
    """A class that holds methods to deploy model to a container in process environment"""

    def __init__(
        self,
        model_server: ModelServer,
        inference_spec: Type[InferenceSpec],
        schema_builder: Type[SchemaBuilder],
        session: Session,
        model_path: str = None,
        env_vars: Dict = None,
    ):
        # pylint: disable=bad-super-call
        super().__init__()

        self.inference_spec = inference_spec
        self.model_path = model_path
        self.env_vars = env_vars
        self.session = session
        self.schema_builder = schema_builder
        self.model_server = model_server
        self._ping_container = None

    def load(self, model_path: str = None):
        """Loads model path, checks that path exists"""
        path = Path(model_path if model_path else self.model_path)
        if not path.exists():
            raise ValueError("model_path does not exist")
        if not path.is_dir():
            raise ValueError("model_path is not a valid directory")

        return self.inference_spec.load(str(path))

    def prepare(self):
        """Prepares the server"""

    def create_server(
Review comment (Collaborator):
This needs to be

    def create_server(
        self,
        image: str,
        secret_key: str,
        predictor: PredictorBase,
        env_vars: Dict[str, str] = None,
        model_path: str = None,
    ):

Review comment (Collaborator):
Else it will break the build

Reply (bryannahm1, author):
Good catch, I will fix

        self,
        predictor: PredictorBase,
    ):
        """Creating the server and checking ping health."""
        logger.info("Waiting for model server %s to start up...", self.model_server)

        if self.model_server == ModelServer.MMS:
            self._start_serving()
            self._ping_container = self._multi_model_server_deep_ping

        time_limit = datetime.now() + timedelta(seconds=5)
        while self._ping_container is not None:
            final_pull = datetime.now() > time_limit

            if final_pull:
                break

            time.sleep(10)

            healthy, response = self._ping_container(predictor)
            if healthy:
                logger.debug("Ping health check has passed. Returned %s", str(response))
                break

        if not healthy:
            raise LocalDeepPingException(_PING_HEALTH_CHECK_FAIL_MSG)
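
A side note on the loop above: with a 5-second deadline and a 10-second sleep before the first ping, at most one ping attempt is made. The pattern can be distilled into a small reusable helper; this is an illustrative sketch, not code from the PR, and the function name is hypothetical.

    import time
    from datetime import datetime, timedelta

    def wait_until_healthy(ping, deadline_seconds=5, poll_seconds=1):
        """Poll ping() until it reports healthy or the deadline passes."""
        time_limit = datetime.now() + timedelta(seconds=deadline_seconds)
        while datetime.now() <= time_limit:
            healthy, response = ping()
            if healthy:
                return response
            time.sleep(poll_seconds)
        raise TimeoutError("Ping health check did not pass before the deadline.")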
68 changes: 68 additions & 0 deletions src/sagemaker/serve/model_server/multi_model_server/server.py
@@ -2,9 +2,11 @@

from __future__ import absolute_import

+import asyncio
import requests
import logging
import platform
+import time
from pathlib import Path
from sagemaker import Session, fw_utils
from sagemaker.serve.utils.exceptions import LocalModelInvocationException
@@ -13,13 +15,79 @@
from sagemaker.s3 import S3Uploader
from sagemaker.local.utils import get_docker_host
from sagemaker.serve.utils.optimize_utils import _is_s3_uri
+from sagemaker.serve.app import main

MODE_DIR_BINDING = "/opt/ml/model/"
_DEFAULT_ENV_VARS = {}

logger = logging.getLogger(__name__)


+class InProcessMultiModelServer:
+    """In Process Mode Multi Model server instance"""
+
+    def _start_serving(self):
+        """Initializes the start of the server"""
+        background_tasks = set()
+        task = asyncio.create_task(main())
+        background_tasks.add(task)
+        task.add_done_callback(background_tasks.discard)
+
+        time.sleep(10)
+
+    def _invoke_multi_model_server_serving(self, request: object, content_type: str, accept: str):
+        """Placeholder docstring"""
+        background_tasks = set()
+        task = asyncio.create_task(self.generate_connect())
+        background_tasks.add(task)
+        task.add_done_callback(background_tasks.discard)
+
+    def _multi_model_server_deep_ping(self, predictor: PredictorBase):
Reply (bryannahm1, author):
Good suggestion, I will tweak the code

"""Sends a deep ping to ensure prediction"""
background_tasks = set()
task = asyncio.create_task(self.tcp_connect())
background_tasks.add(task)
task.add_done_callback(background_tasks.discard)
response = None
return True, response

async def generate_connect(self):
"""Writes the lines in bytes for server"""
reader, writer = await asyncio.open_connection("127.0.0.1", 9007)
a = (
b"GET /generate HTTP/1.1\r\nHost: 127.0.0.1:9007\r\nUser-Agent: "
b"python-requests/2.31.0\r\nAccept-Encoding: gzip, deflate, br\r\nAccept: */*\r\nConnection: ",
"keep-alive\r\nContent-Length: 33\r\nContent-Type: application/json\r\n\r\n",
)
b = b'"\\"Hello, I\'m a language model\\""'
list = [a, b]
writer.writelines(list)
logger.debug(writer.get_extra_info("peername"))
logger.debug(writer.transport)

data = await reader.read()
logger.info("Response from server")
logger.info(data)
writer.close()
await writer.wait_closed()

async def tcp_connect(self):
"""Writes the lines in bytes for server"""
reader, writer = await asyncio.open_connection("127.0.0.1", 9007)
writer.write(
b"GET / HTTP/1.1\r\nHost: 127.0.0.1:9007\r\nUser-Agent: python-requests/2.32.3\r\nAccept-Encoding: gzip, ",
"deflate, br\r\nAccept: */*\r\nConnection: keep-alive\r\n\r\n",
)
logger.debug(writer.get_extra_info("peername"))
logger.debug(writer.transport)
Review comment (Contributor):
Are these for debugging? If so, do we need them in master?

Reply (bryannahm1, author):
Will remove


+        data = await reader.read()
+        logger.info("Response from server")
+        logger.info(data)
+        writer.close()
+        await writer.wait_closed()


class LocalMultiModelServer:
"""Local Multi Model server instance"""

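
A final note on _start_serving() above: asyncio.create_task() requires an already-running event loop, which is presumably why nest-asyncio appears in the new requirements. A sketch of how the server entry point might be driven from synchronous code under that assumption; this is not part of the PR.

    import asyncio
    import nest_asyncio

    from sagemaker.serve.app import main

    nest_asyncio.apply()  # permit scheduling onto an already-running loop

    loop = asyncio.get_event_loop()
    task = loop.create_task(main())             # schedule the uvicorn server
    loop.run_until_complete(asyncio.sleep(10))  # give it time to bind port 9007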