Skip to content

Commit e0c363e

Browse files
Committed: "move accelerate to utils" (commit e0c363e; 1 parent: 5783194)

File tree

3 files changed: +54 additions, −96 deletions

src/sagemaker/serve/builder/model_builder.py

Lines changed: 13 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from pathlib import Path
2222

23-
from accelerate.commands.estimate import estimate_command_parser, gather_data
23+
2424
from sagemaker import Session
2525
from sagemaker.model import Model
2626
from sagemaker.base_predictor import PredictorBase
@@ -43,7 +43,11 @@
4343
from sagemaker.serve.utils import task
4444
from sagemaker.serve.utils.exceptions import TaskNotFoundException
4545
from sagemaker.serve.utils.predictors import _get_local_mode_predictor
46-
from sagemaker.serve.utils.hardware_detector import _get_gpu_info, _get_gpu_info_fallback
46+
from sagemaker.serve.utils.hardware_detector import (
47+
_get_gpu_info,
48+
_get_gpu_info_fallback,
49+
_total_inference_model_size_mib,
50+
)
4751
from sagemaker.serve.detector.image_detector import (
4852
auto_detect_container,
4953
_detect_framework_and_version,
@@ -70,13 +74,6 @@
7074
ModelServer.DJL_SERVING,
7175
}
7276

73-
MIB_CONVERSION_FACTOR = 0.00000095367431640625
74-
MEMORY_BUFFER_MULTIPLIER = 1.2 # 20% buffer
75-
VERSION_DETECTION_ERROR = (
76-
"Please install accelerate and transformers for HuggingFace (HF) model "
77-
"size calculations e.g. pip install 'sagemaker[huggingface]'"
78-
)
79-
8077

8178
# pylint: disable=attribute-defined-outside-init, disable=E1101
8279
@dataclass
@@ -723,43 +720,22 @@ def _schema_builder_init(self, model_task: str):
723720
except ValueError:
724721
raise TaskNotFoundException(f"Schema builder for {model_task} could not be found.")
725722

726-
def _total_inference_model_size_mib(self):
727-
"""Calculates the model size from HF accelerate
728-
729-
This function gets the model size from accelerate. It also adds a
730-
padding and converts to size MiB. When performing inference, expect
731-
to add up to an additional 20% to the given model size as found by EleutherAI.
732-
"""
733-
try:
734-
dtypes = self.env_vars.get("dtypes", "float32")
735-
parser = estimate_command_parser()
736-
args = parser.parse_args([self.model, "--dtypes", dtypes])
737-
738-
output = gather_data(
739-
args
740-
) # "dtype", "Largest Layer", "Total Size Bytes", "Training using Adam"
741-
except ImportError as e:
742-
logger.warning(VERSION_DETECTION_ERROR)
743-
raise e
744-
745-
if output is None:
746-
raise ValueError(f"Could not get Model size for {self.model}")
747-
748-
total_memory_size_mib = MEMORY_BUFFER_MULTIPLIER * output[0][2] * MIB_CONVERSION_FACTOR
749-
logger.info("Total memory size MIB: %s", total_memory_size_mib)
750-
return total_memory_size_mib
751-
752723
def _can_fit_on_single_gpu(self) -> Type[bool]:
753724
"""Check if model can fit on a single GPU
754725
755726
If the size of the model is <= single gpu memory size, returns True else False
756727
"""
757728
try:
758729
single_gpu_size_mib = self._try_fetch_gpu_info()
759-
if self._total_inference_model_size_mib() <= single_gpu_size_mib:
730+
if (
731+
_total_inference_model_size_mib(self.model, self.env_vars.get("dtypes", "float32"))
732+
<= single_gpu_size_mib
733+
):
760734
logger.info(
761735
"Total inference model size MIB %s, single GPU size for instance MIB %s",
762-
self._total_inference_model_size_mib(),
736+
_total_inference_model_size_mib(
737+
self.model, self.env_vars.get("dtypes", "float32")
738+
),
763739
single_gpu_size_mib,
764740
)
765741
return True

src/sagemaker/serve/utils/hardware_detector.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,22 @@
1818

1919
from botocore.exceptions import ClientError
2020

21+
from accelerate.commands.estimate import estimate_command_parser, gather_data
2122
from sagemaker import Session
23+
from sagemaker.model import Model
2224
from sagemaker import instance_types_gpu_info
2325

2426
logger = logging.getLogger(__name__)
2527

2628

29+
MIB_CONVERSION_FACTOR = 0.00000095367431640625
30+
MEMORY_BUFFER_MULTIPLIER = 1.2 # 20% buffer
31+
VERSION_DETECTION_ERROR = (
32+
"Please install accelerate and transformers for HuggingFace (HF) model "
33+
"size calculations e.g. pip install 'sagemaker[huggingface]'"
34+
)
35+
36+
2737
def _get_gpu_info(instance_type: str, session: Session) -> Tuple[int, int]:
2838
"""Get GPU info for the provided instance
2939
@@ -108,3 +118,30 @@ def _format_instance_type(instance_type: str) -> str:
108118

109119
ec2_instance = ".".join(split_instance)
110120
return ec2_instance
121+
122+
123+
def _total_inference_model_size_mib(model: Model, dtype: str) -> int:
124+
"""Calculates the model size from HF accelerate
125+
126+
This function gets the model size from accelerate. It also adds a
127+
padding and converts to size MiB. When performing inference, expect
128+
to add up to an additional 20% to the given model size as found by EleutherAI.
129+
"""
130+
try:
131+
dtypes = dtype
132+
parser = estimate_command_parser()
133+
args = parser.parse_args([model, "--dtypes", dtypes])
134+
135+
output = gather_data(
136+
args
137+
) # "dtype", "Largest Layer", "Total Size Bytes", "Training using Adam"
138+
except ImportError as e:
139+
logger.warning(VERSION_DETECTION_ERROR)
140+
raise e
141+
142+
if output is None:
143+
raise ValueError(f"Could not get Model size for {model}")
144+
145+
total_memory_size_mib = MEMORY_BUFFER_MULTIPLIER * output[0][2] * MIB_CONVERSION_FACTOR
146+
logger.info("Total memory size MIB: %s", total_memory_size_mib)
147+
return total_memory_size_mib

tests/unit/sagemaker/serve/builder/test_model_builder.py

Lines changed: 4 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1205,7 +1205,7 @@ def test_build_for_transformers_happy_case(
12051205

12061206
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_transformers")
12071207
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._try_fetch_gpu_info")
1208-
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._total_inference_model_size_mib")
1208+
@patch("sagemaker.serve.builder.model_builder._total_inference_model_size_mib")
12091209
@patch("sagemaker.image_uris.retrieve")
12101210
@patch("sagemaker.djl_inference.model.urllib")
12111211
@patch("sagemaker.djl_inference.model.json")
@@ -1248,7 +1248,7 @@ def test_build_for_transformers_happy_case_with_values(
12481248

12491249
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_djl", Mock())
12501250
@patch("sagemaker.serve.builder.model_builder._get_gpu_info")
1251-
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._total_inference_model_size_mib")
1251+
@patch("sagemaker.serve.builder.model_builder._total_inference_model_size_mib")
12521252
@patch("sagemaker.image_uris.retrieve")
12531253
@patch("sagemaker.djl_inference.model.urllib")
12541254
@patch("sagemaker.djl_inference.model.json")
@@ -1293,7 +1293,7 @@ def test_build_for_transformers_happy_case_with_valid_gpu_info(
12931293
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_transformers", Mock())
12941294
@patch("sagemaker.serve.builder.model_builder._get_gpu_info")
12951295
@patch("sagemaker.serve.builder.model_builder._get_gpu_info_fallback")
1296-
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._total_inference_model_size_mib")
1296+
@patch("sagemaker.serve.builder.model_builder._total_inference_model_size_mib")
12971297
@patch("sagemaker.image_uris.retrieve")
12981298
@patch("sagemaker.djl_inference.model.urllib")
12991299
@patch("sagemaker.djl_inference.model.json")
@@ -1342,61 +1342,6 @@ def test_build_for_transformers_happy_case_with_valid_gpu_fallback(
13421342
)
13431343
self.assertEqual(model_builder._can_fit_on_single_gpu(), True)
13441344

1345-
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_transformers", Mock())
1346-
@patch("sagemaker.serve.builder.model_builder.estimate_command_parser")
1347-
@patch("sagemaker.serve.builder.model_builder.gather_data")
1348-
@patch("sagemaker.image_uris.retrieve")
1349-
@patch("sagemaker.djl_inference.model.urllib")
1350-
@patch("sagemaker.djl_inference.model.json")
1351-
@patch("sagemaker.huggingface.llm_utils.urllib")
1352-
@patch("sagemaker.huggingface.llm_utils.json")
1353-
@patch("sagemaker.model_uris.retrieve")
1354-
@patch("sagemaker.serve.builder.model_builder._ServeSettings")
1355-
def test_build_for_transformers_happy_case_hugging_face_responses(
1356-
self,
1357-
mock_serveSettings,
1358-
mock_model_uris_retrieve,
1359-
mock_llm_utils_json,
1360-
mock_llm_utils_urllib,
1361-
mock_model_json,
1362-
mock_model_urllib,
1363-
mock_image_uris_retrieve,
1364-
mock_gather_data,
1365-
mock_parser,
1366-
):
1367-
mock_setting_object = mock_serveSettings.return_value
1368-
mock_setting_object.role_arn = mock_role_arn
1369-
mock_setting_object.s3_model_data_url = mock_s3_model_data_url
1370-
1371-
mock_model_uris_retrieve.side_effect = KeyError
1372-
mock_llm_utils_json.load.return_value = {"pipeline_tag": "text-classification"}
1373-
mock_llm_utils_urllib.request.Request.side_effect = Mock()
1374-
1375-
mock_model_json.load.return_value = {"some": "config"}
1376-
mock_model_urllib.request.Request.side_effect = Mock()
1377-
mock_image_uris_retrieve.return_value = "https://some-image-uri"
1378-
1379-
mock_parser.return_value = Mock()
1380-
mock_gather_data.return_value = [[1, 1, 1, 1]]
1381-
product = MIB_CONVERSION_FACTOR * 1 * MEMORY_BUFFER_MULTIPLIER
1382-
1383-
model_builder = ModelBuilder(
1384-
model="stable-diffusion",
1385-
sagemaker_session=mock_session,
1386-
instance_type=mock_instance_type,
1387-
)
1388-
self.assertEqual(model_builder._total_inference_model_size_mib(), product)
1389-
1390-
mock_parser.return_value = Mock()
1391-
mock_gather_data.return_value = None
1392-
model_builder = ModelBuilder(
1393-
model="stable-diffusion",
1394-
sagemaker_session=mock_session,
1395-
instance_type=mock_instance_type,
1396-
)
1397-
with self.assertRaises(ValueError) as _:
1398-
model_builder._total_inference_model_size_mib()
1399-
14001345
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_djl")
14011346
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._can_fit_on_single_gpu")
14021347
@patch("sagemaker.image_uris.retrieve")
@@ -1556,7 +1501,7 @@ def test_try_fetch_gpu_info_throws(
15561501
self.assertEqual(model_builder._can_fit_on_single_gpu(), False)
15571502

15581503
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_transformers", Mock())
1559-
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._total_inference_model_size_mib")
1504+
@patch("sagemaker.serve.builder.model_builder._total_inference_model_size_mib")
15601505
@patch("sagemaker.image_uris.retrieve")
15611506
@patch("sagemaker.djl_inference.model.urllib")
15621507
@patch("sagemaker.djl_inference.model.json")

Comments: 0