feat: FastAPI integration for In_Process Mode (2/2) #4808

Merged: 90 commits, Aug 22, 2024

Changes from 86 commits

Commits (90):
2cc906b
InferenceSpec support for HF
Jun 26, 2024
b25295a
Merge branch 'aws:master' into hf-inf-spec-support
bryannahm1 Jun 27, 2024
fb28458
feat: InferenceSpec support for MMS and testing
Jun 27, 2024
3576ea9
Introduce changes for InProcess Mode
Jun 29, 2024
d3b8e9b
mb_inprocess updates
Jul 3, 2024
68cede1
In_Process mode for TGI transformers, edits
Jul 8, 2024
02e54ef
Remove InfSpec from branch
Jul 8, 2024
f39cca6
merge from master for inf spec
Jul 12, 2024
cc0ca14
changes to support in_process
Jul 13, 2024
18fc3f2
changes to get pre-checks passing
Jul 15, 2024
495c7b4
pylint fix
Jul 15, 2024
1121f47
unit test, test mb
Jul 15, 2024
b6062a7
period missing, added
Jul 15, 2024
1ec209c
suggestions and test added
Jul 16, 2024
ca6c818
pre-push fix
Jul 16, 2024
cd3dbaa
missing an @
Jul 16, 2024
f52f36c
fixes to test, added stubbing
Jul 17, 2024
1843210
removing for fixes
Jul 17, 2024
d0fe3ac
variable fixes
Jul 17, 2024
1b93244
init fix
Jul 17, 2024
b40f36c
tests for in process mode
Jul 18, 2024
68000e1
prepush fix
Jul 18, 2024
499063d
FastAPI with In_Process
Jul 18, 2024
de054b5
minor fixes
Jul 18, 2024
511df55
putting 6 args
Jul 19, 2024
7e17631
server change
Jul 23, 2024
6f34c61
changes
Jul 23, 2024
de9a36c
format fix
Jul 23, 2024
adb9531
fastapi fixes
Jul 24, 2024
5a371b9
port
Jul 24, 2024
5ccfadd
start
Jul 24, 2024
563d397
changing port
Jul 28, 2024
acc3cbc
prepush
Jul 28, 2024
cc236e3
import and unused fix
Jul 28, 2024
b16b227
moving files and transformers
Jul 31, 2024
faca933
fix imports
Aug 1, 2024
d7366ab
changes
Aug 1, 2024
8eda8b5
fixing modules
Aug 1, 2024
8302586
placement
Aug 1, 2024
04887c6
dep fixes
Aug 2, 2024
ff4f62d
Merge branch 'master' into fastapi-inprocess
bryannahm1 Aug 6, 2024
96ce9d0
minor change
Aug 6, 2024
7939237
Merge branch 'fastapi-inprocess' of https://github.com/bryannahm1/sag…
Aug 6, 2024
4e6bd26
fastapi predictor fix
Aug 6, 2024
282c1ab
minor changes
Aug 7, 2024
87f9de9
import transformers
Aug 7, 2024
1962132
pylint comment
Aug 7, 2024
f3bf3c3
delete local_run.sh
Aug 7, 2024
d928254
format
Aug 7, 2024
50db803
fix
Aug 8, 2024
8d4f06c
uvicorn fix
Aug 8, 2024
21f99b5
fastapi
Aug 8, 2024
2e14cd3
try and except
Aug 8, 2024
00b19e7
Merge branch 'master' into fastapi-inprocess
sage-maker Aug 8, 2024
182535a
Merge branch 'master' into fastapi-inprocess
sage-maker Aug 8, 2024
a4c5e4f
app
Aug 8, 2024
8c2f919
merge
Aug 8, 2024
1783bae
Merge branch 'master' into fastapi-inprocess
sage-maker Aug 8, 2024
b414b2b
Merge branch 'master' into fastapi-inprocess
sage-maker Aug 8, 2024
5a2328c
Merge branch 'master' into fastapi-inprocess
sage-maker Aug 9, 2024
62e89ec
deps comment out
Aug 12, 2024
ebc465e
Merge branch 'master' into fastapi-inprocess
bryannahm1 Aug 12, 2024
a356dad
app func fix
Aug 12, 2024
e2ae197
merge
Aug 12, 2024
442686f
deps
Aug 12, 2024
8dd3468
fix
Aug 12, 2024
31e22ff
test fix
Aug 12, 2024
890ccd4
non object
Aug 12, 2024
8237b7b
comment out test for in_process
Aug 12, 2024
5241022
removing unnecessary loggers
Aug 13, 2024
6960d5f
fixing UT
Aug 13, 2024
0821e14
clean up loggers
Aug 14, 2024
7fb02cf
Update in_process_mode.py
makungaj1 Aug 14, 2024
e89940e
In Process Mode
Aug 15, 2024
5570442
workers flags not needed when reload is set to true.
Aug 15, 2024
442704f
Refactore
Aug 15, 2024
fea86b9
Merge pull request #1 from makungaj1/patch-1
makungaj1 Aug 15, 2024
84f1808
wheel
Aug 15, 2024
a0f43b3
minor changes
Aug 16, 2024
1e50b3a
merge
Aug 16, 2024
009a90e
delete whl
Aug 16, 2024
ae79df1
Merge branch 'master' into fastapi-inprocess
makungaj1 Aug 16, 2024
ebfc3e2
ut for app.py
Aug 20, 2024
c83e0ce
unit test fir app
Aug 21, 2024
6d023b5
merge
Aug 21, 2024
a83664b
more unit tests
Aug 21, 2024
88ee686
py310 and higher
Aug 21, 2024
69304c2
Merge branch 'master' into fastapi-inprocess
benieric Aug 22, 2024
7b38f76
delete whl file
Aug 22, 2024
b62d097
Merge branch 'fastapi-inprocess' of https://github.com/bryannahm1/sag…
Aug 22, 2024
3 changes: 3 additions & 0 deletions requirements/extras/huggingface_requirements.txt
@@ -1,2 +1,5 @@
accelerate>=0.24.1,<=0.27.0
sagemaker_schema_inference_artifacts>=0.0.5
uvicorn>=0.30.1
Collaborator:
Are these needed? Doesn't look like we import them anywhere?

bryannahm1 (Contributor, Author), Aug 22, 2024:
They're imported in the app.py file.

fastapi>=0.111.0
nest-asyncio
3 changes: 3 additions & 0 deletions requirements/extras/test_requirements.txt
@@ -40,3 +40,6 @@ schema==0.7.5
tensorflow>=2.1,<=2.16
mlflow>=2.12.2,<2.13
huggingface_hub>=0.23.4
uvicorn>=0.30.1
fastapi>=0.111.0
nest-asyncio
Binary file added sagemaker-2.228.1.dev0-py3-none-any.whl
Binary file not shown.
100 changes: 100 additions & 0 deletions src/sagemaker/serve/app.py
@@ -0,0 +1,100 @@
"""FastAPI requests"""

from __future__ import absolute_import

import asyncio
import logging
import threading
from typing import Optional


logger = logging.getLogger(__name__)


try:
import uvicorn
except ImportError:
logger.error("Unable to import uvicorn, check if uvicorn is installed.")


try:
from transformers import pipeline
except ImportError:
logger.error("Unable to import transformers, check if transformers is installed.")


try:
from fastapi import FastAPI, Request, APIRouter
except ImportError:
logger.error("Unable to import fastapi, check if fastapi is installed.")


class InProcessServer:
"""Placeholder docstring"""

def __init__(self, model_id: Optional[str] = None, task: Optional[str] = None):
self._thread = None
self._loop = None
self._stop_event = asyncio.Event()
self._router = APIRouter()
self._model_id = model_id
self._task = task
self.server = None
self.port = None
self.host = None
# TODO: Pick up device automatically.
self._generator = pipeline(task, model=model_id, device="cpu")

# pylint: disable=unused-variable
@self._router.post("/generate")
async def generate_text(prompt: Request):
"""Placeholder docstring"""
str_prompt = await prompt.json()
str_prompt = str_prompt["inputs"] if "inputs" in str_prompt else str_prompt

generated_text = self._generator(
str_prompt, max_length=30, num_return_sequences=1, truncation=True
)
return generated_text

self._create_server()

def _create_server(self):
"""Placeholder docstring"""
app = FastAPI()
app.include_router(self._router)

config = uvicorn.Config(
app,
host="127.0.0.1",
port=9007,
Contributor:
q: can we let port be dynamic?

bryannahm1 (Contributor, Author):
We are researching this, will be an action item.
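As a hedged sketch of the dynamic-port idea (not part of this PR; it assumes only the standard library socket module), the server could ask the OS for a free ephemeral port and pass it to uvicorn.Config in place of the hardcoded 9007. Note the small race window between releasing the probe socket and uvicorn re-binding the port:

    import socket

    def find_free_port(host: str = "127.0.0.1") -> int:
        """Bind port 0 so the OS assigns a free ephemeral port, then release it."""
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.bind((host, 0))
            return sock.getsockname()[1]

    # e.g.: uvicorn.Config(app, host="127.0.0.1", port=find_free_port(), ...)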

log_level="info",
loop="asyncio",
reload=True,
use_colors=True,
)

self.server = uvicorn.Server(config)
self.host = config.host
self.port = config.port

def start_server(self):
"""Starts the uvicorn server."""
if not (self._thread and self._thread.is_alive()):
logger.info("Waiting for a connection...")
self._thread = threading.Thread(target=self._start_run_async_in_thread, daemon=True)
self._thread.start()

def stop_server(self):
"""Destroys the uvicorn server."""
# TODO: Implement me.

def _start_run_async_in_thread(self):
"""Placeholder docstring"""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(self._serve())

async def _serve(self):
"""Placeholder docstring"""
await self.server.serve()
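For orientation, a minimal usage sketch of the class above. This is hedged: "gpt2" and "text-generation" are placeholder pipeline arguments, the fixed sleep stands in for a real readiness check, and it assumes transformers, fastapi, and uvicorn are installed:

    import time
    import requests
    from sagemaker.serve.app import InProcessServer

    # Wrap a Hugging Face pipeline in the in-process FastAPI server.
    server = InProcessServer(model_id="gpt2", task="text-generation")
    server.start_server()          # uvicorn runs in a daemon thread
    time.sleep(2)                  # crude wait for the server to come up

    # Exercise the /generate route registered in __init__.
    resp = requests.post(
        f"http://{server.host}:{server.port}/generate",
        json={"inputs": "Hello, world"},
        timeout=60,
    )
    print(resp.json())             # pipeline output, e.g. [{"generated_text": "..."}]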
2 changes: 1 addition & 1 deletion src/sagemaker/serve/builder/model_builder.py
@@ -812,7 +812,7 @@ def _initialize_for_mlflow(self, artifact_path: str) -> None:
self.dependencies.update({"requirements": mlflow_model_dependency_path})

# Model Builder is a class to build the model for deployment.
# It supports two* modes of deployment
# It supports three modes of deployment
# 1/ SageMaker Endpoint
# 2/ Local launch with container
# 3/ In process mode with Transformers server in beta release
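As context for the updated comment above, a hedged end-to-end sketch of the third mode. Import paths follow the SDK's layout in this PR; the model id and sample payloads are placeholders, and the exact output shape is assumed from the transformers pipeline:

    from sagemaker.serve.builder.model_builder import ModelBuilder
    from sagemaker.serve.builder.schema_builder import SchemaBuilder
    from sagemaker.serve.mode.function_pointers import Mode

    # Placeholder sample payloads; SchemaBuilder infers (de)serialization from them.
    schema = SchemaBuilder(
        sample_input="Hello",
        sample_output=[{"generated_text": "Hello there"}],
    )

    model_builder = ModelBuilder(
        model="gpt2",              # placeholder Hugging Face model id
        schema_builder=schema,
        mode=Mode.IN_PROCESS,      # 3/ in-process mode (beta)
    )
    model = model_builder.build()
    predictor = model.deploy()
    print(predictor.predict("Hello"))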
2 changes: 1 addition & 1 deletion src/sagemaker/serve/builder/requirements_manager.py
@@ -36,7 +36,7 @@ def capture_and_install_dependencies(self, dependencies: Optional[str] = None) -
Returns:
file path of the existing or generated dependencies file
"""
_dependencies = dependencies or self._detect_conda_env_and_local_dependencies()
_dependencies = dependencies or self._detect_conda_env_and_local_dependencies

# Dependencies specified as either req.txt or conda_env.yml
if _dependencies.endswith(".txt"):
2 changes: 1 addition & 1 deletion src/sagemaker/serve/builder/transformers_builder.py
@@ -421,6 +421,6 @@ def _create_conda_env(self):
"""Creating conda environment by running commands"""

try:
RequirementsManager().capture_and_install_dependencies(self)
RequirementsManager().capture_and_install_dependencies
except subprocess.CalledProcessError:
print("Failed to create and activate conda environment.")
27 changes: 18 additions & 9 deletions src/sagemaker/serve/mode/in_process_mode.py
@@ -1,6 +1,7 @@
"""Module that defines the InProcessMode class"""

from __future__ import absolute_import

from pathlib import Path
import logging
from typing import Dict, Type
@@ -11,7 +12,7 @@
from sagemaker.serve.spec.inference_spec import InferenceSpec
from sagemaker.serve.builder.schema_builder import SchemaBuilder
from sagemaker.serve.utils.types import ModelServer
from sagemaker.serve.utils.exceptions import LocalDeepPingException
from sagemaker.serve.utils.exceptions import InProcessDeepPingException
from sagemaker.serve.model_server.multi_model_server.server import InProcessMultiModelServer
from sagemaker.session import Session

@@ -46,7 +47,7 @@ def __init__(
self.session = session
self.schema_builder = schema_builder
self.model_server = model_server
self._ping_container = None
self._ping_local_server = None

def load(self, model_path: str = None):
"""Loads model path, checks that path exists"""
@@ -69,21 +70,29 @@ def create_server(
logger.info("Waiting for model server %s to start up...", self.model_server)

if self.model_server == ModelServer.MMS:
self._ping_container = self._multi_model_server_deep_ping
self._ping_local_server = self._multi_model_server_deep_ping
self._start_serving()

# allow some time for server to be ready.
time.sleep(1)

time_limit = datetime.now() + timedelta(seconds=5)
while self._ping_container is not None:
healthy = True
while True:
final_pull = datetime.now() > time_limit

if final_pull:
break

time.sleep(10)

healthy, response = self._ping_container(predictor)
healthy, response = self._ping_local_server(predictor)
if healthy:
logger.debug("Ping health check has passed. Returned %s", str(response))
break

time.sleep(1)

if not healthy:
raise LocalDeepPingException(_PING_HEALTH_CHECK_FAIL_MSG)
raise InProcessDeepPingException(_PING_HEALTH_CHECK_FAIL_MSG)

def destroy_server(self):
"""Placeholder docstring"""
self._stop_serving()
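The reworked health check above is a bounded-retry pattern; isolated as a sketch (timings taken from the diff: a roughly five-second deadline with one-second pauses between pings):

    import time
    from datetime import datetime, timedelta

    def wait_until_healthy(ping, deadline_seconds=5, interval_seconds=1):
        """Poll `ping` until it reports healthy or the deadline passes."""
        time_limit = datetime.now() + timedelta(seconds=deadline_seconds)
        healthy, response = False, None
        while datetime.now() <= time_limit:
            healthy, response = ping()
            if healthy:
                break
            time.sleep(interval_seconds)
        return healthy, response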
53 changes: 48 additions & 5 deletions src/sagemaker/serve/model_server/multi_model_server/server.py
@@ -2,12 +2,16 @@

from __future__ import absolute_import

import json

import requests
import logging
import platform
from pathlib import Path

from sagemaker import Session, fw_utils
from sagemaker.serve.utils.exceptions import LocalModelInvocationException
from sagemaker.serve.utils.exceptions import InProcessDeepPingException
from sagemaker.base_predictor import PredictorBase
from sagemaker.s3_utils import determine_bucket_and_prefix, parse_s3_url, s3_path_join
from sagemaker.s3 import S3Uploader
@@ -25,16 +29,55 @@ class InProcessMultiModelServer:

def _start_serving(self):
"""Initializes the start of the server"""
return Exception("Not implemented")
from sagemaker.serve.app import InProcessServer

def _invoke_multi_model_server_serving(self, request: object, content_type: str, accept: str):
"""Invokes the MMS server by sending POST request"""
return Exception("Not implemented")
if hasattr(self, "inference_spec"):
model_id = self.inference_spec.get_model()
if not model_id:
raise ValueError("Model id was not provided in Inference Spec.")
else:
model_id = None
self.server = InProcessServer(model_id=model_id)

self.server.start_server()

def _stop_serving(self):
"""Stops the server"""
self.server.stop_server()

def _invoke_multi_model_server_serving(self, request: bytes, content_type: str, accept: str):
"""Placeholder docstring"""
try:
response = requests.post(
f"http://{self.server.host}:{self.server.port}/generate",
data=request,
headers={"Content-Type": content_type, "Accept": accept},
timeout=600,
)
response.raise_for_status()
if isinstance(response.content, bytes):
return json.loads(response.content.decode("utf-8"))
return response.content
except Exception as e:
if "Connection refused" in str(e):
raise Exception(
"Unable to send request to the local server: Connection refused."
) from e
raise Exception("Unable to send request to the local server.") from e

def _multi_model_server_deep_ping(self, predictor: PredictorBase):
bryannahm1 (Contributor, Author):
Good suggestion, I will tweak the code.

"""Sends a deep ping to ensure prediction"""
healthy = False
response = None
return (True, response)
try:
response = predictor.predict(self.schema_builder.sample_input)
healthy = response is not None
# pylint: disable=broad-except
except Exception as e:
if "422 Client Error: Unprocessable Entity for url" in str(e):
raise InProcessDeepPingException(str(e))

return healthy, response


class LocalMultiModelServer:
23 changes: 9 additions & 14 deletions src/sagemaker/serve/utils/predictors.py
@@ -3,7 +3,7 @@
from __future__ import absolute_import
import io
from typing import Type

import logging
from sagemaker import Session
from sagemaker.serve.mode.local_container_mode import LocalContainerMode
from sagemaker.serve.mode.in_process_mode import InProcessMode
@@ -16,6 +16,8 @@

APPLICATION_X_NPY = "application/x-npy"

logger = logging.getLogger(__name__)


class TorchServeLocalPredictor(PredictorBase):
"""Lightweight predictor for local deployment in IN_PROCESS and LOCAL_CONTAINER modes"""
@@ -211,7 +213,7 @@ def delete_predictor(self):


class TransformersInProcessModePredictor(PredictorBase):
"""Lightweight Transformers predictor for local deployment"""
"""Lightweight Transformers predictor for in process mode deployment"""

def __init__(
self,
@@ -225,18 +227,11 @@ def __init__(

def predict(self, data):
"""Placeholder docstring"""
return [
self.deserializer.deserialize(
io.BytesIO(
self._mode_obj._invoke_multi_model_server_serving(
self.serializer.serialize(data),
self.content_type,
self.deserializer.ACCEPT[0],
)
),
self.content_type,
)
]
return self._mode_obj._invoke_multi_model_server_serving(
self.serializer.serialize(data),
self.content_type,
self.deserializer.ACCEPT[0],
)

@property
def content_type(self):
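After this simplification, predict() returns the server's parsed JSON directly rather than a list wrapping a deserialized byte stream. A hedged usage sketch (predictor construction elided; output shape assumed from the transformers pipeline):

    # `predictor` is a TransformersInProcessModePredictor, e.g. obtained
    # from model.deploy() with mode=Mode.IN_PROCESS.
    result = predictor.predict("Hello, world")
    print(result)   # parsed /generate response, e.g. [{"generated_text": "..."}]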
@@ -29,7 +29,7 @@ class TestRequirementsManager(unittest.TestCase):
@patch(
"sagemaker.serve.builder.requirements_manager.RequirementsManager._detect_conda_env_and_local_dependencies"
)
def test_capture_and_install_dependencies(
def test_capture_and_install_dependencies_txt(
self,
mock_detect_conda_env_and_local_dependencies,
mock_install_requirements_txt,
@@ -40,8 +40,7 @@ def test_capture_and_install_dependencies(
RequirementsManager().capture_and_install_dependencies()
mock_install_requirements_txt.assert_called_once()

mock_detect_conda_env_and_local_dependencies.side_effect = lambda: ".yml"
RequirementsManager().capture_and_install_dependencies()
RequirementsManager().capture_and_install_dependencies("conda.yml")
mock_update_conda_env_in_path.assert_called_once()

@patch(