Skip to content

Commit bceefd1

Browse files
bryannahm1Bryannah Hernandezsage-maker
authored
feat: Pulling in dependencies (in_process mode) using conda environment (#4807)
* InferenceSpec support for HF * feat: InferenceSpec support for MMS and testing * Introduce changes for InProcess Mode * mb_inprocess updates * In_Process mode for TGI transformers, edits * Remove InfSpec from branch * changes to support in_process * changes to get pre-checks passing * pylint fix * unit test, test mb * period missing, added * suggestions and test added * pre-push fix * missing an @ * fixes to test, added stubbing * removing for fixes * variable fixes * init fix * tests for in process mode * prepush fix * deps and mb * changes * fixing pkl * testing * save pkl debug * changes * conda create * Conda fixes * random dep * subproces * requirementsmanager.py script * requires manag * changing command * changing command * print * shell=true * minor fix * changes * check=true * unit test * testing * unit test for requirementsmanager * removing in_process and minor edits * format * .txt file * renaming functions * fix path * making .txt evaluate to true --------- Co-authored-by: Bryannah Hernandez <[email protected]> Co-authored-by: sage-maker <[email protected]>
1 parent 97a6be3 commit bceefd1

File tree

8 files changed

+398
-5
lines changed

8 files changed

+398
-5
lines changed

src/sagemaker/serve/builder/model_builder.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -603,7 +603,6 @@ def _overwrite_mode_in_deploy(self, overwrite_mode: str):
603603
s3_upload_path, env_vars_sagemaker = self._prepare_for_mode()
604604
self.pysdk_model.model_data = s3_upload_path
605605
self.pysdk_model.env.update(env_vars_sagemaker)
606-
607606
elif overwrite_mode == Mode.LOCAL_CONTAINER:
608607
self.mode = self.pysdk_model.mode = Mode.LOCAL_CONTAINER
609608
self._prepare_for_mode()
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Requirements Manager class to pull in client dependencies from a .txt or .yml file"""
14+
from __future__ import absolute_import
15+
import logging
16+
import os
17+
import subprocess
18+
19+
from typing import Optional
20+
21+
logger = logging.getLogger(__name__)
22+
23+
24+
class RequirementsManager:
25+
"""Manages dependency installation by detecting file types"""
26+
27+
def capture_and_install_dependencies(self, dependencies: Optional[str] = None) -> str:
28+
"""Detects the type of file dependencies will be installed from
29+
30+
If a req.txt or conda.yml file is provided, it verifies their existence and
31+
returns the local file path
32+
33+
Args:
34+
dependencies (str): Local path where dependencies file exists.
35+
36+
Returns:
37+
file path of the existing or generated dependencies file
38+
"""
39+
_dependencies = dependencies or self._detect_conda_env_and_local_dependencies()
40+
41+
# Dependencies specified as either req.txt or conda_env.yml
42+
if _dependencies.endswith(".txt"):
43+
self._install_requirements_txt()
44+
elif _dependencies.endswith(".yml"):
45+
self._update_conda_env_in_path()
46+
else:
47+
raise ValueError(f'Invalid dependencies provided: "{_dependencies}"')
48+
49+
def _install_requirements_txt(self):
50+
"""Install requirements.txt file using pip"""
51+
logger.info("Running command to pip install")
52+
subprocess.run("pip install -r in_process_requirements.txt", shell=True, check=True)
53+
logger.info("Command ran successfully")
54+
55+
def _update_conda_env_in_path(self):
56+
"""Update conda env using conda yml file"""
57+
logger.info("Updating conda env")
58+
subprocess.run("conda env update -f conda_in_process.yml", shell=True, check=True)
59+
logger.info("Conda env updated successfully")
60+
61+
def _get_active_conda_env_name(self) -> str:
62+
"""Returns the conda environment name from the set environment variable. None otherwise."""
63+
return os.getenv("CONDA_DEFAULT_ENV")
64+
65+
def _get_active_conda_env_prefix(self) -> str:
66+
"""Returns the conda prefix from the set environment variable. None otherwise."""
67+
return os.getenv("CONDA_PREFIX")
68+
69+
def _detect_conda_env_and_local_dependencies(self) -> str:
70+
"""Generates dependencies list from the user's local runtime.
71+
72+
Raises RuntimeEnvironmentError if not able to.
73+
74+
Currently supports: conda environments
75+
"""
76+
77+
# Try to capture dependencies from the conda environment, if any.
78+
conda_env_name = self._get_active_conda_env_name()
79+
logger.info("Found conda_env_name: '%s'", conda_env_name)
80+
conda_env_prefix = None
81+
82+
if conda_env_name is None:
83+
conda_env_prefix = self._get_active_conda_env_prefix()
84+
85+
if conda_env_name is None and conda_env_prefix is None:
86+
local_dependencies_path = os.path.join(os.getcwd(), "in_process_requirements.txt")
87+
logger.info(local_dependencies_path)
88+
89+
return local_dependencies_path
90+
91+
if conda_env_name == "base":
92+
logger.warning(
93+
"We recommend using an environment other than base to "
94+
"isolate your project dependencies from conda dependencies"
95+
)
96+
97+
local_dependencies_path = os.path.join(os.getcwd(), "conda_in_process.yml")
98+
logger.info(local_dependencies_path)
99+
100+
return local_dependencies_path

src/sagemaker/serve/builder/transformers_builder.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from abc import ABC, abstractmethod
1818
from typing import Type
1919
from pathlib import Path
20+
import subprocess
2021
from packaging.version import Version
2122

2223
from sagemaker.model import Model
@@ -41,6 +42,8 @@
4142
from sagemaker.serve.utils.telemetry_logger import _capture_telemetry
4243
from sagemaker.base_predictor import PredictorBase
4344
from sagemaker.huggingface.llm_utils import get_huggingface_model_metadata
45+
from sagemaker.serve.builder.requirements_manager import RequirementsManager
46+
4447

4548
logger = logging.getLogger(__name__)
4649
DEFAULT_TIMEOUT = 1800
@@ -376,6 +379,9 @@ def _build_for_transformers(self):
376379
save_pkl(code_path, (self.inference_spec, self.schema_builder))
377380
logger.info("PKL file saved to file: %s", code_path)
378381

382+
if self.mode == Mode.IN_PROCESS:
383+
self._create_conda_env()
384+
379385
self._auto_detect_container()
380386

381387
self.secret_key = prepare_for_mms(
@@ -394,3 +400,11 @@ def _build_for_transformers(self):
394400
if self.sagemaker_session:
395401
self.pysdk_model.sagemaker_session = self.sagemaker_session
396402
return self.pysdk_model
403+
404+
def _create_conda_env(self):
405+
"""Creating conda environment by running commands"""
406+
407+
try:
408+
RequirementsManager().capture_and_install_dependencies(self)
409+
except subprocess.CalledProcessError:
410+
print("Failed to create and activate conda environment.")

src/sagemaker/serve/model_server/multi_model_server/server.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def _start_serving(
3131
secret_key: str,
3232
env_vars: dict,
3333
):
34-
"""Placeholder docstring"""
34+
"""Initializes the start of the server"""
3535
env = {
3636
"SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code",
3737
"SAGEMAKER_PROGRAM": "inference.py",
@@ -59,7 +59,7 @@ def _start_serving(
5959
)
6060

6161
def _invoke_multi_model_server_serving(self, request: object, content_type: str, accept: str):
62-
"""Placeholder docstring"""
62+
"""Invokes MMS server by hitting the docker host"""
6363
try:
6464
response = requests.post(
6565
f"http://{get_docker_host()}:8080/invocations",
@@ -73,7 +73,7 @@ def _invoke_multi_model_server_serving(self, request: object, content_type: str,
7373
raise Exception("Unable to send request to the local container server") from e
7474

7575
def _multi_model_server_deep_ping(self, predictor: PredictorBase):
76-
"""Placeholder docstring"""
76+
"""Deep ping in order to ensure prediction"""
7777
response = None
7878
try:
7979
response = predictor.predict(self.schema_builder.sample_input)
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
name: conda_env
2+
channels:
3+
- defaults
4+
dependencies:
5+
- accelerate>=0.24.1,<=0.27.0
6+
- sagemaker_schema_inference_artifacts>=0.0.5
7+
- uvicorn>=0.30.1
8+
- fastapi>=0.111.0
9+
- nest-asyncio
10+
- pip>=23.0.1
11+
- attrs>=23.1.0,<24
12+
- boto3>=1.34.142,<2.0
13+
- cloudpickle==2.2.1
14+
- google-pasta
15+
- numpy>=1.9.0,<2.0
16+
- protobuf>=3.12,<5.0
17+
- smdebug_rulesconfig==1.0.1
18+
- importlib-metadata>=1.4.0,<7.0
19+
- packaging>=20.0
20+
- pandas
21+
- pathos
22+
- schema
23+
- PyYAML~=6.0
24+
- jsonschema
25+
- platformdirs
26+
- tblib>=1.7.0,<4
27+
- urllib3>=1.26.8,<3.0.0
28+
- requests
29+
- docker
30+
- tqdm
31+
- psutil
32+
- pip:
33+
- altair>=4.2.2
34+
- anyio>=3.6.2
35+
- awscli>=1.27.114
36+
- blinker>=1.6.2
37+
- botocore>=1.29.114
38+
- cachetools>=5.3.0
39+
- certifi==2022.12.7
40+
- harset-normalizer>=3.1.0
41+
- click>=8.1.3
42+
- cloudpickle>=2.2.1
43+
- colorama>=0.4.4
44+
- contextlib2>=21.6.0
45+
- decorator>=5.1.1
46+
- dill>=0.3.6
47+
- docutils>=0.16
48+
- entrypoints>=0.4
49+
- filelock>=3.11.0
50+
- gitdb>=4.0.10
51+
- gitpython>=3.1.31
52+
- gunicorn>=20.1.0
53+
- h11>=0.14.0
54+
- huggingface-hub>=0.13.4
55+
- idna>=3.4
56+
- importlib-metadata>=4.13.0
57+
- jinja2>=3.1.2
58+
- jmespath>=1.0.1
59+
- jsonschema>=4.17.3
60+
- markdown-it-py>=2.2.0
61+
- markupsafe>=2.1.2
62+
- mdurl>=0.1.2
63+
- mpmath>=1.3.0
64+
- multiprocess>=0.70.14
65+
- networkx>=3.1
66+
- packaging>=23.1
67+
- pandas>=1.5.3
68+
- pathos>=0.3.0
69+
- pillow>=9.5.0
70+
- platformdirs>=3.2.0
71+
- pox>=0.3.2
72+
- ppft>=1.7.6.6
73+
- protobuf>=3.20.3
74+
- protobuf3-to-dict>=0.1.5
75+
- pyarrow>=11.0.0
76+
- pyasn1>=0.4.8
77+
- pydantic>=1.10.7
78+
- pydeck>=0.8.1b0
79+
- pygments>=2.15.1
80+
- pympler>=1.0.1
81+
- pyrsistent>=0.19.3
82+
- python-dateutil>=2.8.2
83+
- pytz>=2023.3
84+
- pytz-deprecation-shim>=0.1.0.post0
85+
- pyyaml>=5.4.1
86+
- regex>=2023.3.23
87+
- requests>=2.28.2
88+
- rich>=13.3.4
89+
- rsa>=4.7.2
90+
- s3transfer>=0.6.0
91+
- sagemaker>=2.148.0
92+
- schema>=0.7.5
93+
- six>=1.16.0
94+
- smdebug-rulesconfig>=1.0.1
95+
- smmap==5.0.0
96+
- sniffio>=1.3.0
97+
- starlette>=0.26.1
98+
- streamlit>=1.21.0
99+
- sympy>=1.11.1
100+
- tblib>=1.7.0
101+
- tokenizers>=0.13.3
102+
- toml>=0.10.2
103+
- toolz>=0.12.0
104+
- torch>=2.0.0
105+
- tornado>=6.3
106+
- tqdm>=4.65.0
107+
- transformers>=4.28.1
108+
- typing-extensions>=4.5.0
109+
- tzdata>=2023.3
110+
- tzlocal>=4.3
111+
- urllib3>=1.26.15
112+
- validators>=0.20.0
113+
- zipp>=3.15.0

src/sagemaker/serve/utils/exceptions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Placeholder Docstring"""
1+
"""Exceptions used across different model builder invocations"""
22

33
from __future__ import absolute_import
44

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
altair>=4.2.2
2+
anyio>=3.6.2
3+
awscli>=1.27.114
4+
blinker>=1.6.2
5+
botocore>=1.29.114
6+
cachetools>=5.3.0
7+
certifi==2022.12.7
8+
harset-normalizer>=3.1.0
9+
click>=8.1.3
10+
cloudpickle>=2.2.1
11+
colorama>=0.4.4
12+
contextlib2>=21.6.0
13+
decorator>=5.1.1
14+
dill>=0.3.6
15+
docutils>=0.16
16+
entrypoints>=0.4
17+
filelock>=3.11.0
18+
gitdb>=4.0.10
19+
gitpython>=3.1.31
20+
gunicorn>=20.1.0
21+
h11>=0.14.0
22+
huggingface-hub>=0.13.4
23+
idna>=3.4
24+
importlib-metadata>=4.13.0
25+
jinja2>=3.1.2
26+
jmespath>=1.0.1
27+
jsonschema>=4.17.3
28+
markdown-it-py>=2.2.0
29+
markupsafe>=2.1.2
30+
mdurl>=0.1.2
31+
mpmath>=1.3.0
32+
multiprocess>=0.70.14
33+
networkx>=3.1
34+
packaging>=23.1
35+
pandas>=1.5.3
36+
pathos>=0.3.0
37+
pillow>=9.5.0
38+
platformdirs>=3.2.0
39+
pox>=0.3.2
40+
ppft>=1.7.6.6
41+
protobuf>=3.20.3
42+
protobuf3-to-dict>=0.1.5
43+
pyarrow>=11.0.0
44+
pyasn1>=0.4.8
45+
pydantic>=1.10.7
46+
pydeck>=0.8.1b0
47+
pygments>=2.15.1
48+
pympler>=1.0.1
49+
pyrsistent>=0.19.3
50+
python-dateutil>=2.8.2
51+
pytz>=2023.3
52+
pytz-deprecation-shim>=0.1.0.post0
53+
pyyaml>=5.4.1
54+
regex>=2023.3.23
55+
requests>=2.28.2
56+
rich>=13.3.4
57+
rsa>=4.7.2
58+
s3transfer>=0.6.0
59+
sagemaker>=2.148.0
60+
schema>=0.7.5
61+
six>=1.16.0
62+
smdebug-rulesconfig>=1.0.1
63+
smmap==5.0.0
64+
sniffio>=1.3.0
65+
starlette>=0.26.1
66+
streamlit>=1.21.0
67+
sympy>=1.11.1
68+
tblib>=1.7.0
69+
tokenizers>=0.13.3
70+
toml>=0.10.2
71+
toolz>=0.12.0
72+
torch>=2.0.0
73+
tornado>=6.3
74+
tqdm>=4.65.0
75+
transformers>=4.28.1
76+
typing-extensions>=4.5.0
77+
tzdata>=2023.3
78+
tzlocal>=4.3
79+
urllib3>=1.26.15
80+
validators>=0.20.0
81+
zipp>=3.15.0
82+
uvicorn>=0.30.1
83+
fastapi>=0.111.0
84+
nest-asyncio
85+
transformers

0 commit comments

Comments
 (0)