
Commit c63c268

samrudsjiapinw authored and committed
change: Enhance model builder selection logic to include model size (aws#4429)
* change: Enhance model builder selection logic to include model size
* Fix conflicts
* Address PR comments
* fix formatting
* fix formatting of test
* Fix token in tasks.json
* Increase coverage for tests
* fix formatting
* Fix requirements
* Import code instead of importing accelerate
* Fix formatting
* Setup dependencies
1 parent 9849a98 commit c63c268

File tree

10 files changed: +796 −10 lines changed


doc/requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -4,3 +4,4 @@ docutils==0.15.2
 packaging==20.9
 jinja2==3.1.3
 schema==0.7.5
+accelerate>=0.24.1,<=0.27.0
requirements/extras/huggingface_requirements.txt

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+accelerate>=0.24.1,<=0.27.0

requirements/extras/test_requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -39,3 +39,4 @@ tritonclient[http]<2.37.0
 onnx==1.14.1
 # tf2onnx==1.15.1
 nbformat>=5.9,<6
+accelerate>=0.24.1,<=0.27.0

setup.py

Lines changed: 1 addition & 0 deletions
@@ -79,6 +79,7 @@ def read_requirements(filename):
     "feature-processor": read_requirements(
         "requirements/extras/feature-processor_requirements.txt"
     ),
+    "huggingface": read_requirements("requirements/extras/huggingface_requirements.txt"),
 }
 # Meta dependency groups
 extras["all"] = [item for group in extras.values() for item in group]

src/sagemaker/serve/builder/model_builder.py

Lines changed: 77 additions & 1 deletion
@@ -20,9 +20,11 @@
 
 from pathlib import Path
 
+from accelerate.commands.estimate import estimate_command_parser, gather_data
 from sagemaker import Session
 from sagemaker.model import Model
 from sagemaker.base_predictor import PredictorBase
+from sagemaker.djl_inference import defaults
 from sagemaker.serializers import NumpySerializer, TorchTensorSerializer
 from sagemaker.deserializers import JSONDeserializer, TorchTensorDeserializer
 from sagemaker.serve.builder.schema_builder import SchemaBuilder

@@ -42,6 +44,7 @@
 from sagemaker.serve.utils import task
 from sagemaker.serve.utils.exceptions import TaskNotFoundException
 from sagemaker.serve.utils.predictors import _get_local_mode_predictor
+from sagemaker.serve.utils.hardware_detector import _get_gpu_info, _get_gpu_info_fallback
 from sagemaker.serve.detector.image_detector import (
     auto_detect_container,
     _detect_framework_and_version,

@@ -69,6 +72,9 @@
     ModelServer.FASTAPI
 }
 
+MIB_CONVERSION_FACTOR = 0.00000095367431640625
+MEMORY_BUFFER_MULTIPLIER = 1.2  # 20% buffer
+
 
 # pylint: disable=attribute-defined-outside-init
 @dataclass

@@ -571,7 +577,7 @@ def wrapper(*args, **kwargs):
     # It supports two modes of deployment
     # 1/ SageMaker Endpoint
     # 2/ Local launch with container
-    def build(
+    def build(  # pylint: disable=R0911
         self,
         mode: Type[Mode] = None,
         role_arn: str = None,

@@ -627,6 +633,13 @@ def build(
 
         if model_task == "text-generation":  # pylint: disable=R1705
             return self._build_for_tgi()
+        elif self._can_fit_on_single_gpu():
+            return self._build_for_transformers()
+        elif (
+            self.model in defaults.DEEPSPEED_RECOMMENDED_ARCHITECTURES
+            or self.model in defaults.FASTER_TRANSFORMER_RECOMMENDED_ARCHITECTURES
+        ):
+            return self._build_for_djl()
         else:
             return self._build_for_transformers()
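Read as a decision chain, the routing added above is: TGI for text-generation tasks; the transformers backend whenever the model fits on a single GPU; DJL only for architectures on the DeepSpeed or FasterTransformer recommended lists; and transformers as the final fallback. A minimal standalone sketch of that order (the function name and string returns are illustrative, not SDK API):

def choose_backend(model_task, model_id, fits_on_single_gpu,
                   deepspeed_archs, faster_transformer_archs):
    """Illustrative sketch of the elif chain added to ModelBuilder.build()."""
    if model_task == "text-generation":
        return "tgi"
    if fits_on_single_gpu:  # the new size-aware check runs before the DJL check
        return "transformers"
    if model_id in deepspeed_archs or model_id in faster_transformer_archs:
        return "djl"
    return "transformers"  # fallback, unchanged from before this commit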

@@ -704,3 +717,66 @@ def _schema_builder_init(self, model_task: str):
             self.schema_builder = SchemaBuilder(sample_inputs, sample_outputs)
         except ValueError:
             raise TaskNotFoundException(f"Schema builder for {model_task} could not be found.")
+
+    def _total_inference_model_size_mib(self):
+        """Calculates the model size from HF accelerate
+
+        This function gets the model size from accelerate. It also adds a
+        padding and converts to size MiB. When performing inference, expect
+        to add up to an additional 20% to the given model size as found by EleutherAI.
+        """
+        dtypes = self.env_vars.get("dtypes", "float32")
+        parser = estimate_command_parser()
+        args = parser.parse_args([self.model, "--dtypes", dtypes])
+
+        output = gather_data(
+            args
+        )  # "dtype", "Largest Layer", "Total Size Bytes", "Training using Adam"
+
+        if output is None:
+            raise ValueError(f"Could not get Model size for {self.model}")
+
+        total_memory_size_mib = MEMORY_BUFFER_MULTIPLIER * output[0][2] * MIB_CONVERSION_FACTOR
+        logger.info("Total memory size MIB: %s", total_memory_size_mib)
+        return total_memory_size_mib
+
+    def _can_fit_on_single_gpu(self) -> Type[bool]:
+        """Check if model can fit on a single GPU
+
+        If the size of the model is <= single gpu memory size, returns True else False
+        """
+        try:
+            single_gpu_size_mib = self._try_fetch_gpu_info()
+            if self._total_inference_model_size_mib() <= single_gpu_size_mib:
+                logger.info(
+                    "Total inference model size MIB %s, single GPU size for instance MIB %s",
+                    self._total_inference_model_size_mib(),
+                    single_gpu_size_mib,
+                )
+                return True
+            return False
+        except ValueError:
+            logger.info("Unable to determine single GPU size for instance %s", self.instance_type)
+            return False
+
+    def _try_fetch_gpu_info(self):
+        """Get GPU info
+
+        This function gets the GPU info or fallback to set the size of a single GPU
+        """
+        try:
+            gpu_info = _get_gpu_info(self.instance_type, self.sagemaker_session)
+            logger.info("GPU info %s for instance %s", gpu_info, self.instance_type)
+            return gpu_info[1] / gpu_info[0]
+        except ValueError:
+            pass
+        try:
+            gpu_fallback = _get_gpu_info_fallback(
+                self.instance_type, self.sagemaker_session.boto_region_name
+            )
+            logger.info("GPU fallback picked up %s", gpu_fallback)
+            return gpu_fallback[1] / gpu_fallback[0]
+        except ValueError:
+            raise ValueError(
+                f"Unable to determine single GPU size for instance: [{self.instance_type}]"
+            )
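For context, _total_inference_model_size_mib leans on the internals of accelerate's estimate command: gather_data returns rows of (dtype, largest layer, total size in bytes, training footprint), and the method takes the first row's total bytes, adds the 20% inference buffer, and converts bytes to MiB (1 MiB = 1,048,576 bytes, which is where the 0.00000095367431640625 = 1 / 1024**2 factor comes from). A runnable sketch of the same arithmetic, assuming accelerate within the pinned 0.24.1-0.27.0 range; the model id argument is only an example:

from accelerate.commands.estimate import estimate_command_parser, gather_data

MIB_CONVERSION_FACTOR = 1 / 1024**2  # 0.00000095367431640625, bytes -> MiB
MEMORY_BUFFER_MULTIPLIER = 1.2  # 20% headroom for inference

def estimate_size_mib(model_id: str, dtypes: str = "float32") -> float:
    # Same call sequence as the new method, minus the ModelBuilder plumbing.
    args = estimate_command_parser().parse_args([model_id, "--dtypes", dtypes])
    output = gather_data(args)  # rows: dtype, largest layer, total bytes, ...
    if output is None:
        raise ValueError(f"Could not get model size for {model_id}")
    return MEMORY_BUFFER_MULTIPLIER * output[0][2] * MIB_CONVERSION_FACTOR

The per-GPU budget this estimate is compared against comes from gpu_info[1] / gpu_info[0] above, which suggests total instance GPU memory split evenly across the GPU count; for example, an instance reporting 4 GPUs and 98,304 MiB total would yield a 24,576 MiB single-GPU budget.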

src/sagemaker/serve/schema/task.json

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
     "fill-mask": {
         "sample_inputs": {
             "properties": {
-                "inputs": "Paris is the <mask> of France.",
+                "inputs": "Paris is the [MASK] of France.",
                 "parameters": {}
             }
         },
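The token fix matters because mask tokens are tokenizer-specific: BERT-family tokenizers expect the literal [MASK], while <mask> is the RoBERTa-style spelling, so the old sample input would not be recognized by bert-base-uncased, the model the new tests use. A quick check with the standard transformers pipeline API:

from transformers import pipeline

# bert-base-uncased only recognizes [MASK] as its mask token; passing
# "<mask>" fails because the tokenizer finds no mask in the input.
fill_mask = pipeline("fill-mask", model="bert-base-uncased")
print(fill_mask("Paris is the [MASK] of France.")[0]["token_str"])  # expect "capital"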
Lines changed: 184 additions & 0 deletions
@@ -0,0 +1,184 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+import pytest
+from sagemaker.serve.builder.schema_builder import SchemaBuilder
+from sagemaker.serve.builder.model_builder import ModelBuilder, Mode
+import tests.integ
+from tests.integ.sagemaker.serve.constants import (
+    HF_DIR,
+    PYTHON_VERSION_IS_NOT_310,
+    SERVE_SAGEMAKER_ENDPOINT_TIMEOUT,
+)
+from tests.integ.timeout import timeout
+from tests.integ.utils import cleanup_model_resources, gpu_list, retry_with_instance_list
+import logging
+
+logger = logging.getLogger(__name__)
+
+model_id = "bert-base-uncased"
+
+sample_input = {"inputs": "Hello I'm a [MASK] model."}
+
+sample_output = [
+    {
+        "score": 0.10731109976768494,
+        "token": 4827,
+        "token_str": "fashion",
+        "sequence": "hello i'm a fashion model.",
+    },
+    {
+        "score": 0.08774465322494507,
+        "token": 2535,
+        "token_str": "role",
+        "sequence": "hello i'm a role model.",
+    },
+    {
+        "score": 0.05338414013385773,
+        "token": 2047,
+        "token_str": "new",
+        "sequence": "hello i'm a new model.",
+    },
+    {
+        "score": 0.04667224362492561,
+        "token": 3565,
+        "token_str": "super",
+        "sequence": "hello i'm a super model.",
+    },
+    {
+        "score": 0.027096163481473923,
+        "token": 2986,
+        "token_str": "fine",
+        "sequence": "hello i'm a fine model.",
+    },
+]
+
+
+@pytest.fixture
+def model_input():
+    return {"inputs": "The man worked as a [MASK]."}
+
+
+@pytest.fixture
+def model_builder_model_schema_builder():
+    return ModelBuilder(
+        model_path=HF_DIR, model=model_id, schema_builder=SchemaBuilder(sample_input, sample_output)
+    )
+
+
+@pytest.fixture
+def model_builder(request):
+    return request.getfixturevalue(request.param)
+
+
+@pytest.mark.skipif(
+    PYTHON_VERSION_IS_NOT_310,
+    tests.integ.test_region() in tests.integ.TRAINING_NO_P2_REGIONS
+    and tests.integ.test_region() in tests.integ.TRAINING_NO_P3_REGIONS,
+    reason="no ml.p2 or ml.p3 instances in this region",
+)
+@retry_with_instance_list(gpu_list(tests.integ.test_region()))
+@pytest.mark.parametrize("model_builder", ["model_builder_model_schema_builder"], indirect=True)
+def test_non_text_generation_model_single_GPU(
+    sagemaker_session, model_builder, model_input, **kwargs
+):
+    iam_client = sagemaker_session.boto_session.client("iam")
+    role_arn = iam_client.get_role(RoleName="SageMakerRole")["Role"]["Arn"]
+    model = model_builder.build(role_arn=role_arn, sagemaker_session=sagemaker_session)
+    caught_ex = None
+    with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT):
+        try:
+            logger.info("Running in SAGEMAKER_ENDPOINT mode")
+            predictor = model.deploy(
+                mode=Mode.SAGEMAKER_ENDPOINT,
+                instance_type=kwargs["instance_type"],
+                initial_instance_count=1,
+            )
+            logger.info("Endpoint successfully deployed.")
+            prediction = predictor.predict(model_input)
+            assert prediction is not None
+
+            endpoint_name = predictor.endpoint_name
+            sagemaker_client = sagemaker_session.boto_session.client("sagemaker")
+            endpoint_config_name = sagemaker_client.describe_endpoint(EndpointName=endpoint_name)[
+                "EndpointConfigName"
+            ]
+            actual_instance_type = sagemaker_client.describe_endpoint_config(
+                EndpointConfigName=endpoint_config_name
+            )["ProductionVariants"][0]["InstanceType"]
+            assert kwargs["instance_type"] == actual_instance_type
+        except Exception as e:
+            caught_ex = e
+        finally:
+            cleanup_model_resources(
+                sagemaker_session=model_builder.sagemaker_session,
+                model_name=model.name,
+                endpoint_name=model.endpoint_name,
+            )
+            if caught_ex:
+                logger.exception(caught_ex)
+                assert (
+                    False
+                ), f"Exception {caught_ex} was thrown when running model builder single GPU test"
+
+
+@pytest.mark.skipif(
+    PYTHON_VERSION_IS_NOT_310,
+    tests.integ.test_region() in tests.integ.TRAINING_NO_P2_REGIONS
+    and tests.integ.test_region() in tests.integ.TRAINING_NO_P3_REGIONS,
+    reason="no ml.p2 or ml.p3 instances in this region",
+)
+@retry_with_instance_list(gpu_list(tests.integ.test_region()))
+@pytest.mark.parametrize("model_builder", ["model_builder_model_schema_builder"], indirect=True)
+def test_non_text_generation_model_multi_GPU(
+    sagemaker_session, model_builder, model_input, **kwargs
+):
+    iam_client = sagemaker_session.boto_session.client("iam")
+    role_arn = iam_client.get_role(RoleName="SageMakerRole")["Role"]["Arn"]
+    caught_ex = None
+    model = model_builder.build(role_arn=role_arn, sagemaker_session=sagemaker_session)
+    with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT):
+        try:
+            logger.info("Running in SAGEMAKER_ENDPOINT mode")
+            predictor = model.deploy(
+                mode=Mode.SAGEMAKER_ENDPOINT,
+                instance_type=kwargs["instance_type"],
+                initial_instance_count=1,
+            )
+            logger.info("Endpoint successfully deployed.")
+            prediction = predictor.predict(model_input)
+            assert prediction is not None
+
+            endpoint_name = predictor.endpoint_name
+            sagemaker_client = sagemaker_session.boto_session.client("sagemaker")
+            endpoint_config_name = sagemaker_client.describe_endpoint(EndpointName=endpoint_name)[
+                "EndpointConfigName"
+            ]
+            actual_instance_type = sagemaker_client.describe_endpoint_config(
+                EndpointConfigName=endpoint_config_name
+            )["ProductionVariants"][0]["InstanceType"]
+            assert kwargs["instance_type"] == actual_instance_type
+        except Exception as e:
+            caught_ex = e
+        finally:
+            cleanup_model_resources(
+                sagemaker_session=model_builder.sagemaker_session,
+                model_name=model.name,
+                endpoint_name=model.endpoint_name,
+            )
+            if caught_ex:
+                logger.exception(caught_ex)
+                assert (
+                    False
+                ), f"Exception {caught_ex} was thrown when running model builder multi GPU test"

tests/integ/sagemaker/serve/test_serve_transformers.py

Lines changed: 13 additions & 7 deletions
@@ -15,15 +15,15 @@
 import pytest
 from sagemaker.serve.builder.schema_builder import SchemaBuilder
 from sagemaker.serve.builder.model_builder import ModelBuilder, Mode
-
+import tests.integ
 from tests.integ.sagemaker.serve.constants import (
     HF_DIR,
     PYTHON_VERSION_IS_NOT_310,
     SERVE_SAGEMAKER_ENDPOINT_TIMEOUT,
 )
 
 from tests.integ.timeout import timeout
-from tests.integ.utils import cleanup_model_resources
+from tests.integ.utils import cleanup_model_resources, gpu_list, retry_with_instance_list
 import logging
 
 logger = logging.getLogger(__name__)

@@ -67,7 +67,7 @@
 
 
 @pytest.fixture
-def input():
+def model_input():
     return {"inputs": "The man worked as a [MASK]."}
 
 

@@ -87,11 +87,14 @@ def model_builder(request):
 
 @pytest.mark.skipif(
     PYTHON_VERSION_IS_NOT_310,
-    reason="Testing feature",
+    tests.integ.test_region() in tests.integ.TRAINING_NO_P2_REGIONS
+    and tests.integ.test_region() in tests.integ.TRAINING_NO_P3_REGIONS,
+    reason="no ml.p2 or ml.p3 instances in this region",
 )
+@retry_with_instance_list(gpu_list(tests.integ.test_region()))
 @pytest.mark.parametrize("model_builder", ["model_builder_model_schema_builder"], indirect=True)
 def test_pytorch_transformers_sagemaker_endpoint(
-    sagemaker_session, model_builder, gpu_instance_type, input
+    sagemaker_session, model_builder, model_input, **kwargs
 ):
     logger.info("Running in SAGEMAKER_ENDPOINT mode...")
     caught_ex = None

@@ -106,9 +109,12 @@ def test_pytorch_transformers_sagemaker_endpoint(
     with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT):
         try:
             logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...")
-            predictor = model.deploy(instance_type=gpu_instance_type, initial_instance_count=1)
+            predictor = model.deploy(
+                instance_type=kwargs["instance_type"], initial_instance_count=2
+            )
             logger.info("Endpoint successfully deployed.")
-            predictor.predict(input)
+            predictor.predict(model_input)
+            assert predictor is not None
         except Exception as e:
             caught_ex = e
         finally:
