Skip to content

Commit da348e8

Browse files
Authored by: bryannahm1 (Bryannah Hernandez), sage-maker, and makungaj1 (Jonathan Makunga)
feat: FastAPI integration for In_Process Mode (2/2) (#4808)
* InferenceSpec support for HF
* feat: InferenceSpec support for MMS and testing
* Introduce changes for InProcess Mode
* mb_inprocess updates
* In_Process mode for TGI transformers, edits
* Remove InfSpec from branch
* changes to support in_process
* changes to get pre-checks passing
* pylint fix
* unit test, test mb
* period missing, added
* suggestions and test added
* pre-push fix
* missing an @
* fixes to test, added stubbing
* removing for fixes
* variable fixes
* init fix
* tests for in process mode
* prepush fix
* FastAPI with In_Process
* minor fixes
* putting 6 args
* server change
* changes
* format fix
* fastapi fixes
* port
* start
* changing port
* prepush
* import and unused fix
* moving files and transformers
* fix imports
* changes
* fixing modules
* placement
* dep fixes
* minor change
* fastapi predictor fix
* minor changes
* import transformers
* pylint comment
* delete local_run.sh
* format
* fix
* uvicorn fix
* fastapi
* try and except
* app
* deps comment out
* app func fix
* deps
* fix
* test fix
* non object
* comment out test for in_process
* removing unnecessary loggers
* fixing UT
* clean up loggers
* Update in_process_mode.py
* In Process Mode
* workers flags not needed when reload is set to true.
* Refactore
* wheel
* minor changes
* delete whl
* ut for app.py
* unit test fir app
* more unit tests
* py310 and higher
* delete whl file

---------

Co-authored-by: Bryannah Hernandez <[email protected]>
Co-authored-by: sage-maker <[email protected]>
Co-authored-by: Jonathan Makunga <[email protected]>
Co-authored-by: Jonathan Makunga <[email protected]>
Co-authored-by: Erick Benitez-Ramos <[email protected]>
1 parent e240518 commit da348e8

File tree

12 files changed

+316
-36
lines changed

12 files changed

+316
-36
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,5 @@
11
accelerate>=0.24.1,<=0.27.0
22
sagemaker_schema_inference_artifacts>=0.0.5
3+
uvicorn>=0.30.1
4+
fastapi>=0.111.0
5+
nest-asyncio

requirements/extras/test_requirements.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,6 @@ schema==0.7.5
4040
tensorflow>=2.1,<=2.16
4141
mlflow>=2.12.2,<2.13
4242
huggingface_hub>=0.23.4
43+
uvicorn>=0.30.1
44+
fastapi>=0.111.0
45+
nest-asyncio

src/sagemaker/serve/app.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
"""FastAPI requests"""
2+
3+
from __future__ import absolute_import
4+
5+
import asyncio
6+
import logging
7+
import threading
8+
from typing import Optional
9+
10+
11+
logger = logging.getLogger(__name__)
12+
13+
14+
try:
15+
import uvicorn
16+
except ImportError:
17+
logger.error("Unable to import uvicorn, check if uvicorn is installed.")
18+
19+
20+
try:
21+
from transformers import pipeline
22+
except ImportError:
23+
logger.error("Unable to import transformers, check if transformers is installed.")
24+
25+
26+
try:
27+
from fastapi import FastAPI, Request, APIRouter
28+
except ImportError:
29+
logger.error("Unable to import fastapi, check if fastapi is installed.")
30+
31+
32+
class InProcessServer:
33+
"""Placeholder docstring"""
34+
35+
def __init__(self, model_id: Optional[str] = None, task: Optional[str] = None):
36+
self._thread = None
37+
self._loop = None
38+
self._stop_event = asyncio.Event()
39+
self._router = APIRouter()
40+
self._model_id = model_id
41+
self._task = task
42+
self.server = None
43+
self.port = None
44+
self.host = None
45+
# TODO: Pick up device automatically.
46+
self._generator = pipeline(task, model=model_id, device="cpu")
47+
48+
# pylint: disable=unused-variable
49+
@self._router.post("/generate")
50+
async def generate_text(prompt: Request):
51+
"""Placeholder docstring"""
52+
str_prompt = await prompt.json()
53+
str_prompt = str_prompt["inputs"] if "inputs" in str_prompt else str_prompt
54+
55+
generated_text = self._generator(
56+
str_prompt, max_length=30, num_return_sequences=1, truncation=True
57+
)
58+
return generated_text
59+
60+
self._create_server()
61+
62+
def _create_server(self):
63+
"""Placeholder docstring"""
64+
app = FastAPI()
65+
app.include_router(self._router)
66+
67+
config = uvicorn.Config(
68+
app,
69+
host="127.0.0.1",
70+
port=9007,
71+
log_level="info",
72+
loop="asyncio",
73+
reload=True,
74+
use_colors=True,
75+
)
76+
77+
self.server = uvicorn.Server(config)
78+
self.host = config.host
79+
self.port = config.port
80+
81+
def start_server(self):
82+
"""Starts the uvicorn server."""
83+
if not (self._thread and self._thread.is_alive()):
84+
logger.info("Waiting for a connection...")
85+
self._thread = threading.Thread(target=self._start_run_async_in_thread, daemon=True)
86+
self._thread.start()
87+
88+
def stop_server(self):
89+
"""Destroys the uvicorn server."""
90+
# TODO: Implement me.
91+
92+
def _start_run_async_in_thread(self):
93+
"""Placeholder docstring"""
94+
loop = asyncio.new_event_loop()
95+
asyncio.set_event_loop(loop)
96+
loop.run_until_complete(self._serve())
97+
98+
async def _serve(self):
99+
"""Placeholder docstring"""
100+
await self.server.serve()

src/sagemaker/serve/builder/model_builder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -812,7 +812,7 @@ def _initialize_for_mlflow(self, artifact_path: str) -> None:
812812
self.dependencies.update({"requirements": mlflow_model_dependency_path})
813813

814814
# Model Builder is a class to build the model for deployment.
815-
# It supports two* modes of deployment
815+
# It supports three modes of deployment
816816
# 1/ SageMaker Endpoint
817817
# 2/ Local launch with container
818818
# 3/ In process mode with Transformers server in beta release

src/sagemaker/serve/builder/requirements_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def capture_and_install_dependencies(self, dependencies: Optional[str] = None) -
3636
Returns:
3737
file path of the existing or generated dependencies file
3838
"""
39-
_dependencies = dependencies or self._detect_conda_env_and_local_dependencies()
39+
_dependencies = dependencies or self._detect_conda_env_and_local_dependencies
4040

4141
# Dependencies specified as either req.txt or conda_env.yml
4242
if _dependencies.endswith(".txt"):

src/sagemaker/serve/builder/transformers_builder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,6 @@ def _create_conda_env(self):
421421
"""Creating conda environment by running commands"""
422422

423423
try:
424-
RequirementsManager().capture_and_install_dependencies(self)
424+
RequirementsManager().capture_and_install_dependencies
425425
except subprocess.CalledProcessError:
426426
print("Failed to create and activate conda environment.")

src/sagemaker/serve/mode/in_process_mode.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Module that defines the InProcessMode class"""
22

33
from __future__ import absolute_import
4+
45
from pathlib import Path
56
import logging
67
from typing import Dict, Type
@@ -11,7 +12,7 @@
1112
from sagemaker.serve.spec.inference_spec import InferenceSpec
1213
from sagemaker.serve.builder.schema_builder import SchemaBuilder
1314
from sagemaker.serve.utils.types import ModelServer
14-
from sagemaker.serve.utils.exceptions import LocalDeepPingException
15+
from sagemaker.serve.utils.exceptions import InProcessDeepPingException
1516
from sagemaker.serve.model_server.multi_model_server.server import InProcessMultiModelServer
1617
from sagemaker.session import Session
1718

@@ -46,7 +47,7 @@ def __init__(
4647
self.session = session
4748
self.schema_builder = schema_builder
4849
self.model_server = model_server
49-
self._ping_container = None
50+
self._ping_local_server = None
5051

5152
def load(self, model_path: str = None):
5253
"""Loads model path, checks that path exists"""
@@ -69,21 +70,29 @@ def create_server(
6970
logger.info("Waiting for model server %s to start up...", self.model_server)
7071

7172
if self.model_server == ModelServer.MMS:
72-
self._ping_container = self._multi_model_server_deep_ping
73+
self._ping_local_server = self._multi_model_server_deep_ping
74+
self._start_serving()
75+
76+
# allow some time for server to be ready.
77+
time.sleep(1)
7378

7479
time_limit = datetime.now() + timedelta(seconds=5)
75-
while self._ping_container is not None:
80+
healthy = True
81+
while True:
7682
final_pull = datetime.now() > time_limit
77-
7883
if final_pull:
7984
break
8085

81-
time.sleep(10)
82-
83-
healthy, response = self._ping_container(predictor)
86+
healthy, response = self._ping_local_server(predictor)
8487
if healthy:
8588
logger.debug("Ping health check has passed. Returned %s", str(response))
8689
break
8790

91+
time.sleep(1)
92+
8893
if not healthy:
89-
raise LocalDeepPingException(_PING_HEALTH_CHECK_FAIL_MSG)
94+
raise InProcessDeepPingException(_PING_HEALTH_CHECK_FAIL_MSG)
95+
96+
def destroy_server(self):
97+
"""Placeholder docstring"""
98+
self._stop_serving()

src/sagemaker/serve/model_server/multi_model_server/server.py

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,16 @@
22

33
from __future__ import absolute_import
44

5+
import json
6+
57
import requests
68
import logging
79
import platform
810
from pathlib import Path
11+
912
from sagemaker import Session, fw_utils
1013
from sagemaker.serve.utils.exceptions import LocalModelInvocationException
14+
from sagemaker.serve.utils.exceptions import InProcessDeepPingException
1115
from sagemaker.base_predictor import PredictorBase
1216
from sagemaker.s3_utils import determine_bucket_and_prefix, parse_s3_url, s3_path_join
1317
from sagemaker.s3 import S3Uploader
@@ -25,16 +29,55 @@ class InProcessMultiModelServer:
2529

2630
def _start_serving(self):
2731
"""Initializes the start of the server"""
28-
return Exception("Not implemented")
32+
from sagemaker.serve.app import InProcessServer
2933

30-
def _invoke_multi_model_server_serving(self, request: object, content_type: str, accept: str):
31-
"""Invokes the MMS server by sending POST request"""
32-
return Exception("Not implemented")
34+
if hasattr(self, "inference_spec"):
35+
model_id = self.inference_spec.get_model()
36+
if not model_id:
37+
raise ValueError("Model id was not provided in Inference Spec.")
38+
else:
39+
model_id = None
40+
self.server = InProcessServer(model_id=model_id)
41+
42+
self.server.start_server()
43+
44+
def _stop_serving(self):
45+
"""Stops the server"""
46+
self.server.stop_server()
47+
48+
def _invoke_multi_model_server_serving(self, request: bytes, content_type: str, accept: str):
49+
"""Placeholder docstring"""
50+
try:
51+
response = requests.post(
52+
f"http://{self.server.host}:{self.server.port}/generate",
53+
data=request,
54+
headers={"Content-Type": content_type, "Accept": accept},
55+
timeout=600,
56+
)
57+
response.raise_for_status()
58+
if isinstance(response.content, bytes):
59+
return json.loads(response.content.decode("utf-8"))
60+
return response.content
61+
except Exception as e:
62+
if "Connection refused" in str(e):
63+
raise Exception(
64+
"Unable to send request to the local server: Connection refused."
65+
) from e
66+
raise Exception("Unable to send request to the local server.") from e
3367

3468
def _multi_model_server_deep_ping(self, predictor: PredictorBase):
3569
"""Sends a deep ping to ensure prediction"""
70+
healthy = False
3671
response = None
37-
return (True, response)
72+
try:
73+
response = predictor.predict(self.schema_builder.sample_input)
74+
healthy = response is not None
75+
# pylint: disable=broad-except
76+
except Exception as e:
77+
if "422 Client Error: Unprocessable Entity for url" in str(e):
78+
raise InProcessDeepPingException(str(e))
79+
80+
return healthy, response
3881

3982

4083
class LocalMultiModelServer:

src/sagemaker/serve/utils/predictors.py

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from __future__ import absolute_import
44
import io
55
from typing import Type
6-
6+
import logging
77
from sagemaker import Session
88
from sagemaker.serve.mode.local_container_mode import LocalContainerMode
99
from sagemaker.serve.mode.in_process_mode import InProcessMode
@@ -16,6 +16,8 @@
1616

1717
APPLICATION_X_NPY = "application/x-npy"
1818

19+
logger = logging.getLogger(__name__)
20+
1921

2022
class TorchServeLocalPredictor(PredictorBase):
2123
"""Lightweight predictor for local deployment in IN_PROCESS and LOCAL_CONTAINER modes"""
@@ -211,7 +213,7 @@ def delete_predictor(self):
211213

212214

213215
class TransformersInProcessModePredictor(PredictorBase):
214-
"""Lightweight Transformers predictor for local deployment"""
216+
"""Lightweight Transformers predictor for in process mode deployment"""
215217

216218
def __init__(
217219
self,
@@ -225,18 +227,11 @@ def __init__(
225227

226228
def predict(self, data):
227229
"""Placeholder docstring"""
228-
return [
229-
self.deserializer.deserialize(
230-
io.BytesIO(
231-
self._mode_obj._invoke_multi_model_server_serving(
232-
self.serializer.serialize(data),
233-
self.content_type,
234-
self.deserializer.ACCEPT[0],
235-
)
236-
),
237-
self.content_type,
238-
)
239-
]
230+
return self._mode_obj._invoke_multi_model_server_serving(
231+
self.serializer.serialize(data),
232+
self.content_type,
233+
self.deserializer.ACCEPT[0],
234+
)
240235

241236
@property
242237
def content_type(self):

tests/unit/sagemaker/serve/builder/test_requirements_manager.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class TestRequirementsManager(unittest.TestCase):
2929
@patch(
3030
"sagemaker.serve.builder.requirements_manager.RequirementsManager._detect_conda_env_and_local_dependencies"
3131
)
32-
def test_capture_and_install_dependencies(
32+
def test_capture_and_install_dependencies_txt(
3333
self,
3434
mock_detect_conda_env_and_local_dependencies,
3535
mock_install_requirements_txt,
@@ -40,8 +40,7 @@ def test_capture_and_install_dependencies(
4040
RequirementsManager().capture_and_install_dependencies()
4141
mock_install_requirements_txt.assert_called_once()
4242

43-
mock_detect_conda_env_and_local_dependencies.side_effect = lambda: ".yml"
44-
RequirementsManager().capture_and_install_dependencies()
43+
RequirementsManager().capture_and_install_dependencies("conda.yml")
4544
mock_update_conda_env_in_path.assert_called_once()
4645

4746
@patch(

0 commit comments

Comments (0)