
Commit 85252ef

Enhance model builder selection logic to include model size
1 parent 0900405 commit 85252ef

File tree

6 files changed: +85 −6 lines changed

  requirements/extras/huggingface_requirements.txt (new)
  requirements/extras/test_requirements.txt
  setup.py
  src/sagemaker/serve/builder/model_builder.py
  tests/integ/sagemaker/serve/test_serve_model_builder_gpu.py (renamed from test_model_builder_gpu.py)
  tests/integ/sagemaker/serve/test_serve_transformers.py
requirements/extras/huggingface_requirements.txt (new file; the name is inferred from the setup.py change below)

Lines changed: 7 additions & 0 deletions

@@ -0,0 +1,7 @@
+accelerate
+numpy>=1.17
+packaging>=20.0
+psutil
+pyyaml
+torch>=1.10.0
+huggingface_hub

requirements/extras/test_requirements.txt

Lines changed: 7 additions & 0 deletions

@@ -39,3 +39,10 @@ tritonclient[http]<2.37.0
 onnx==1.14.1
 # tf2onnx==1.15.1
 nbformat>=5.9,<6
+accelerate
+numpy>=1.17
+packaging>=20.0
+psutil
+pyyaml
+torch>=1.10.0
+huggingface_hub

setup.py

Lines changed: 1 addition & 0 deletions

@@ -79,6 +79,7 @@ def read_requirements(filename):
     "feature-processor": read_requirements(
         "requirements/extras/feature-processor_requirements.txt"
     ),
+    "huggingface": read_requirements("requirements/extras/huggingface_requirements.txt"),
 }
 # Meta dependency groups
 extras["all"] = [item for group in extras.values() for item in group]

src/sagemaker/serve/builder/model_builder.py

Lines changed: 66 additions & 1 deletion

@@ -20,7 +20,9 @@
 
 from pathlib import Path
 
+from accelerate.commands.estimate import estimate_command_parser, gather_data
 from sagemaker import Session
+from sagemaker.djl_inference import defaults
 from sagemaker.model import Model
 from sagemaker.base_predictor import PredictorBase
 from sagemaker.serializers import NumpySerializer, TorchTensorSerializer
@@ -39,6 +41,7 @@
 from sagemaker.serve.save_retrive.version_1_0_0.metadata.metadata import Metadata
 from sagemaker.serve.spec.inference_spec import InferenceSpec
 from sagemaker.serve.utils.predictors import _get_local_mode_predictor
+from sagemaker.serve.utils.hardware_detector import _get_gpu_info, _get_gpu_info_fallback
 from sagemaker.serve.detector.image_detector import (
     auto_detect_container,
     _detect_framework_and_version,
@@ -65,6 +68,9 @@
     ModelServer.DJL_SERVING,
 }
 
+MIB_CONVERSION_FACTOR = 0.00000095367431640625
+MEMORY_BUFFER_MULTIPLIER = 1.2  # 20% buffer
+
 
 # pylint: disable=attribute-defined-outside-init
 @dataclass
@@ -567,7 +573,7 @@ def wrapper(*args, **kwargs):
     # It supports two modes of deployment
     # 1/ SageMaker Endpoint
     # 2/ Local launch with container
-    def build(
+    def build(  # pylint: disable=R0911
         self,
         mode: Type[Mode] = None,
         role_arn: str = None,
@@ -616,6 +622,10 @@ def build(
             )
             if hf_model_md.get("pipeline_tag") == "text-generation":  # pylint: disable=R1705
                 return self._build_for_tgi()
+            elif self.can_fit_on_single_gpu():
+                return self._build_for_transformers()
+            elif self.model in defaults.FASTER_TRANSFORMER_SUPPORTED_ARCHITECTURES:
+                return self._build_for_djl()
             else:
                 return self._build_for_transformers()
@@ -672,3 +682,58 @@ def validate(self, model_dir: str) -> Type[bool]:
         """
 
         return get_metadata(model_dir)
+
+    def total_inference_model_size_mib(self):
+        """Calculates the model size from HF accelerate
+
+        This function gets the model size from accelerate. It also adds a
+        padding and converts to size MiB. When performing inference, expect
+        to add up to an additional 20% to the given model size as found by EleutherAI.
+        """
+        dtypes = "float32"
+        try:
+            if self.env_vars.get("dtypes"):
+                dtypes = self.env_vars.get("dtypes")
+
+            parser = estimate_command_parser()
+            args = parser.parse_args([self.model, "--dtypes", dtypes])
+        except ValueError:
+            logging.error("Args specified incorrect for model %s", self.model)
+
+        output = gather_data(
+            args
+        )  # "dtype", "Largest Layer", "Total Size Bytes", "Training using Adam"
+
+        total_memory_size_mib = MEMORY_BUFFER_MULTIPLIER * output[0][2] * MIB_CONVERSION_FACTOR
+        logger.info("Total memory size MIB: %s", total_memory_size_mib)
+        return total_memory_size_mib
+
+    def can_fit_on_single_gpu(self):
+        """Check if model can fit on a single GPU
+
+        This function gets the GPU info or fallback to set the size of a single GPU.
+        If the size of the model is <= single gpu memory size, returns true.
+        """
+        try:
+            gpu_info = _get_gpu_info(self.instance_type, self.sagemaker_session)
+            logger.info("GPU info %s for instance %s", gpu_info, self.instance_type)
+            single_gpu_size_mib = gpu_info[1] / gpu_info[0]
+        except ValueError:
+            gpu_fallback = _get_gpu_info_fallback(
+                self.instance_type, self.sagemaker_session.boto_region_name
+            )
+            logger.info("GPU fallback picked up %s", gpu_fallback)
+            single_gpu_size_mib = gpu_fallback[1] / gpu_fallback[0]
+
+        if single_gpu_size_mib is None:
+            logger.info("Unable to determine single GPU size for instance %s", self.instance_type)
+            return False
+
+        if self.total_inference_model_size_mib() <= single_gpu_size_mib:
+            logger.info(
+                "Total inference model size MIB %s, single GPU size for instance MIB %s",
+                self.total_inference_model_size_mib(),
+                single_gpu_size_mib,
+            )
+            return True
+        return False
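
Taken together, the new dispatch in build() is: text-generation models go to TGI; models whose padded footprint fits on one GPU of the target instance take the transformers path; FasterTransformer-supported architectures fall through to DJL Serving; and everything else defaults to transformers. The sizing arithmetic behind the single-GPU check is easy to verify in isolation. The following is a minimal, self-contained sketch of that math, not the commit's code: the 7B-parameter model and the GPU figures are illustrative assumptions, and in the real helper the byte count comes from accelerate's gather_data rather than a hand estimate.

# Standalone sketch of the math in total_inference_model_size_mib() and
# can_fit_on_single_gpu(). All concrete numbers are assumptions for
# illustration only.

MIB_CONVERSION_FACTOR = 1 / (1024**2)  # bytes -> MiB; equals 0.00000095367431640625
MEMORY_BUFFER_MULTIPLIER = 1.2  # 20% inference-time padding


def padded_model_size_mib(total_size_bytes: float) -> float:
    """Apply the 20% buffer and convert a raw byte count to MiB."""
    return MEMORY_BUFFER_MULTIPLIER * total_size_bytes * MIB_CONVERSION_FACTOR


def fits_on_single_gpu(model_mib: float, gpu_count: int, total_memory_mib: float) -> bool:
    """Compare the padded model size against the memory of one GPU."""
    return model_mib <= total_memory_mib / gpu_count


# In the committed helper the byte count comes from accelerate's estimator:
#   parser = estimate_command_parser()
#   args = parser.parse_args(["<hf-model-id>", "--dtypes", "float32"])
#   total_size_bytes = gather_data(args)[0][2]
# Here we hand-compute it instead: a 7B-parameter float32 model is ~7e9 * 4 bytes.
model_mib = padded_model_size_mib(7e9 * 4)  # ~32,043 MiB after the 20% buffer

print(fits_on_single_gpu(model_mib, gpu_count=1, total_memory_mib=16 * 1024))      # False on a 16 GiB GPU
print(fits_on_single_gpu(model_mib, gpu_count=4, total_memory_mib=4 * 40 * 1024))  # True with 40 GiB per GPU

Note that MIB_CONVERSION_FACTOR is exactly 1/2^20 (a bytes-to-mebibytes factor), and the 1.2 multiplier is the up-to-20% inference-time padding the docstring attributes to EleutherAI's measurements.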

tests/integ/sagemaker/serve/test_model_builder_gpu.py renamed to tests/integ/sagemaker/serve/test_serve_model_builder_gpu.py

Lines changed: 3 additions & 4 deletions

@@ -13,9 +13,8 @@
 from __future__ import absolute_import
 
 import pytest
-from sagemaker.serve import Mode
-from sagemaker.serve.builder.model_builder import ModelBuilder
 from sagemaker.serve.builder.schema_builder import SchemaBuilder
+from sagemaker.serve.builder.model_builder import ModelBuilder, Mode
 from tests.integ.sagemaker.serve.constants import (
     HF_DIR,
     PYTHON_VERSION_IS_NOT_310,
@@ -90,10 +89,10 @@ def model_builder(request):
 def test_non_text_generation_model_single_GPU(sagemaker_session, model_builder, model_input):
     iam_client = sagemaker_session.boto_session.client("iam")
     role_arn = iam_client.get_role(RoleName="SageMakerRole")["Role"]["Arn"]
+    model = model_builder.build(role_arn=role_arn, sagemaker_session=sagemaker_session)
     caught_ex = None
     with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT):
         try:
-            model = model_builder.build(role_arn=role_arn, sagemaker_session=sagemaker_session)
             logger.info("Running in SAGEMAKER_ENDPOINT mode")
             predictor = model.deploy(
                 mode=Mode.SAGEMAKER_ENDPOINT,
@@ -137,9 +136,9 @@ def test_non_text_generation_model_multi_GPU(sagemaker_session, model_builder, model_input):
     iam_client = sagemaker_session.boto_session.client("iam")
     role_arn = iam_client.get_role(RoleName="SageMakerRole")["Role"]["Arn"]
     caught_ex = None
+    model = model_builder.build(role_arn=role_arn, sagemaker_session=sagemaker_session)
     with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT):
         try:
-            model = model_builder.build(role_arn=role_arn, sagemaker_session=sagemaker_session)
             logger.info("Running in SAGEMAKER_ENDPOINT mode")
             predictor = model.deploy(
                 mode=Mode.SAGEMAKER_ENDPOINT,

tests/integ/sagemaker/serve/test_serve_transformers.py

Lines changed: 1 addition & 1 deletion

@@ -106,7 +106,7 @@ def test_pytorch_transformers_sagemaker_endpoint(
     with timeout(minutes=SERVE_SAGEMAKER_ENDPOINT_TIMEOUT):
         try:
             logger.info("Deploying and predicting in SAGEMAKER_ENDPOINT mode...")
-            predictor = model.deploy(instance_type=gpu_instance_type, initial_instance_count=1)
+            predictor = model.deploy(instance_type="ml.g4dn.xlarge", initial_instance_count=1)
             logger.info("Endpoint successfully deployed.")
             predictor.predict(input)
         except Exception as e:
