Skip to content

Commit 9e64c40

Browse files
committed
move accelerate to utils
1 parent 5783194 commit 9e64c40

File tree

3 files changed

+47
-96
lines changed

3 files changed

+47
-96
lines changed

src/sagemaker/serve/builder/model_builder.py

Lines changed: 7 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
 
 from pathlib import Path
 
-from accelerate.commands.estimate import estimate_command_parser, gather_data
+
 from sagemaker import Session
 from sagemaker.model import Model
 from sagemaker.base_predictor import PredictorBase
@@ -43,7 +43,8 @@
 from sagemaker.serve.utils import task
 from sagemaker.serve.utils.exceptions import TaskNotFoundException
 from sagemaker.serve.utils.predictors import _get_local_mode_predictor
-from sagemaker.serve.utils.hardware_detector import _get_gpu_info, _get_gpu_info_fallback
+from sagemaker.serve.utils.hardware_detector import _get_gpu_info, _get_gpu_info_fallback,\
+    _total_inference_model_size_mib
 from sagemaker.serve.detector.image_detector import (
     auto_detect_container,
     _detect_framework_and_version,
@@ -70,13 +71,6 @@
     ModelServer.DJL_SERVING,
 }
 
-MIB_CONVERSION_FACTOR = 0.00000095367431640625
-MEMORY_BUFFER_MULTIPLIER = 1.2  # 20% buffer
-VERSION_DETECTION_ERROR = (
-    "Please install accelerate and transformers for HuggingFace (HF) model "
-    "size calculations e.g. pip install 'sagemaker[huggingface]'"
-)
-
 
 # pylint: disable=attribute-defined-outside-init, disable=E1101
 @dataclass
@@ -723,43 +717,19 @@ def _schema_builder_init(self, model_task: str):
         except ValueError:
             raise TaskNotFoundException(f"Schema builder for {model_task} could not be found.")
 
-    def _total_inference_model_size_mib(self):
-        """Calculates the model size from HF accelerate
-
-        This function gets the model size from accelerate. It also adds a
-        padding and converts to size MiB. When performing inference, expect
-        to add up to an additional 20% to the given model size as found by EleutherAI.
-        """
-        try:
-            dtypes = self.env_vars.get("dtypes", "float32")
-            parser = estimate_command_parser()
-            args = parser.parse_args([self.model, "--dtypes", dtypes])
-
-            output = gather_data(
-                args
-            )  # "dtype", "Largest Layer", "Total Size Bytes", "Training using Adam"
-        except ImportError as e:
-            logger.warning(VERSION_DETECTION_ERROR)
-            raise e
-
-        if output is None:
-            raise ValueError(f"Could not get Model size for {self.model}")
-
-        total_memory_size_mib = MEMORY_BUFFER_MULTIPLIER * output[0][2] * MIB_CONVERSION_FACTOR
-        logger.info("Total memory size MIB: %s", total_memory_size_mib)
-        return total_memory_size_mib
-
     def _can_fit_on_single_gpu(self) -> Type[bool]:
         """Check if model can fit on a single GPU
 
         If the size of the model is <= single gpu memory size, returns True else False
         """
         try:
             single_gpu_size_mib = self._try_fetch_gpu_info()
-            if self._total_inference_model_size_mib() <= single_gpu_size_mib:
+            if _total_inference_model_size_mib(self.model, self.env_vars.get("dtypes", "float32")) \
+                    <= single_gpu_size_mib:
                 logger.info(
                     "Total inference model size MIB %s, single GPU size for instance MIB %s",
-                    self._total_inference_model_size_mib(),
+                    _total_inference_model_size_mib(self.model, self.env_vars.get("dtypes",
+                                                                                  "float32")),
                     single_gpu_size_mib,
                 )
                 return True

src/sagemaker/serve/utils/hardware_detector.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,21 @@
 
 from botocore.exceptions import ClientError
 
+from accelerate.commands.estimate import estimate_command_parser, gather_data
 from sagemaker import Session
+from sagemaker.model import Model
 from sagemaker import instance_types_gpu_info
 
 logger = logging.getLogger(__name__)
 
 
+MIB_CONVERSION_FACTOR = 0.00000095367431640625
+MEMORY_BUFFER_MULTIPLIER = 1.2  # 20% buffer
+VERSION_DETECTION_ERROR = (
+    "Please install accelerate and transformers for HuggingFace (HF) model "
+    "size calculations e.g. pip install 'sagemaker[huggingface]'"
+)
+
 def _get_gpu_info(instance_type: str, session: Session) -> Tuple[int, int]:
     """Get GPU info for the provided instance
 
@@ -108,3 +117,30 @@ def _format_instance_type(instance_type: str) -> str:
 
     ec2_instance = ".".join(split_instance)
     return ec2_instance
+
+
+def _total_inference_model_size_mib(model: Model, dtype: str) -> int:
+    """Calculates the model size from HF accelerate
+
+    This function gets the model size from accelerate. It also adds a
+    padding and converts to size MiB. When performing inference, expect
+    to add up to an additional 20% to the given model size as found by EleutherAI.
+    """
+    try:
+        dtypes = dtype
+        parser = estimate_command_parser()
+        args = parser.parse_args([model, "--dtypes", dtypes])
+
+        output = gather_data(
+            args
+        )  # "dtype", "Largest Layer", "Total Size Bytes", "Training using Adam"
+    except ImportError as e:
+        logger.warning(VERSION_DETECTION_ERROR)
+        raise e
+
+    if output is None:
+        raise ValueError(f"Could not get Model size for {model}")
+
+    total_memory_size_mib = MEMORY_BUFFER_MULTIPLIER * output[0][2] * MIB_CONVERSION_FACTOR
+    logger.info("Total memory size MIB: %s", total_memory_size_mib)
+    return total_memory_size_mib

tests/unit/sagemaker/serve/builder/test_model_builder.py

Lines changed: 4 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1205,7 +1205,7 @@ def test_build_for_transformers_happy_case(
 
     @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_transformers")
     @patch("sagemaker.serve.builder.model_builder.ModelBuilder._try_fetch_gpu_info")
-    @patch("sagemaker.serve.builder.model_builder.ModelBuilder._total_inference_model_size_mib")
+    @patch("sagemaker.serve.builder.model_builder._total_inference_model_size_mib")
    @patch("sagemaker.image_uris.retrieve")
    @patch("sagemaker.djl_inference.model.urllib")
    @patch("sagemaker.djl_inference.model.json")
@@ -1248,7 +1248,7 @@ def test_build_for_transformers_happy_case_with_values(
 
     @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_djl", Mock())
     @patch("sagemaker.serve.builder.model_builder._get_gpu_info")
-    @patch("sagemaker.serve.builder.model_builder.ModelBuilder._total_inference_model_size_mib")
+    @patch("sagemaker.serve.builder.model_builder._total_inference_model_size_mib")
     @patch("sagemaker.image_uris.retrieve")
     @patch("sagemaker.djl_inference.model.urllib")
     @patch("sagemaker.djl_inference.model.json")
@@ -1293,7 +1293,7 @@ def test_build_for_transformers_happy_case_with_valid_gpu_info(
     @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_transformers", Mock())
     @patch("sagemaker.serve.builder.model_builder._get_gpu_info")
     @patch("sagemaker.serve.builder.model_builder._get_gpu_info_fallback")
-    @patch("sagemaker.serve.builder.model_builder.ModelBuilder._total_inference_model_size_mib")
+    @patch("sagemaker.serve.builder.model_builder._total_inference_model_size_mib")
     @patch("sagemaker.image_uris.retrieve")
     @patch("sagemaker.djl_inference.model.urllib")
     @patch("sagemaker.djl_inference.model.json")
@@ -1342,61 +1342,6 @@ def test_build_for_transformers_happy_case_with_valid_gpu_fallback(
         )
         self.assertEqual(model_builder._can_fit_on_single_gpu(), True)
 
-    @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_transformers", Mock())
-    @patch("sagemaker.serve.builder.model_builder.estimate_command_parser")
-    @patch("sagemaker.serve.builder.model_builder.gather_data")
-    @patch("sagemaker.image_uris.retrieve")
-    @patch("sagemaker.djl_inference.model.urllib")
-    @patch("sagemaker.djl_inference.model.json")
-    @patch("sagemaker.huggingface.llm_utils.urllib")
-    @patch("sagemaker.huggingface.llm_utils.json")
-    @patch("sagemaker.model_uris.retrieve")
-    @patch("sagemaker.serve.builder.model_builder._ServeSettings")
-    def test_build_for_transformers_happy_case_hugging_face_responses(
-        self,
-        mock_serveSettings,
-        mock_model_uris_retrieve,
-        mock_llm_utils_json,
-        mock_llm_utils_urllib,
-        mock_model_json,
-        mock_model_urllib,
-        mock_image_uris_retrieve,
-        mock_gather_data,
-        mock_parser,
-    ):
-        mock_setting_object = mock_serveSettings.return_value
-        mock_setting_object.role_arn = mock_role_arn
-        mock_setting_object.s3_model_data_url = mock_s3_model_data_url
-
-        mock_model_uris_retrieve.side_effect = KeyError
-        mock_llm_utils_json.load.return_value = {"pipeline_tag": "text-classification"}
-        mock_llm_utils_urllib.request.Request.side_effect = Mock()
-
-        mock_model_json.load.return_value = {"some": "config"}
-        mock_model_urllib.request.Request.side_effect = Mock()
-        mock_image_uris_retrieve.return_value = "https://some-image-uri"
-
-        mock_parser.return_value = Mock()
-        mock_gather_data.return_value = [[1, 1, 1, 1]]
-        product = MIB_CONVERSION_FACTOR * 1 * MEMORY_BUFFER_MULTIPLIER
-
-        model_builder = ModelBuilder(
-            model="stable-diffusion",
-            sagemaker_session=mock_session,
-            instance_type=mock_instance_type,
-        )
-        self.assertEqual(model_builder._total_inference_model_size_mib(), product)
-
-        mock_parser.return_value = Mock()
-        mock_gather_data.return_value = None
-        model_builder = ModelBuilder(
-            model="stable-diffusion",
-            sagemaker_session=mock_session,
-            instance_type=mock_instance_type,
-        )
-        with self.assertRaises(ValueError) as _:
-            model_builder._total_inference_model_size_mib()
-
     @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_djl")
     @patch("sagemaker.serve.builder.model_builder.ModelBuilder._can_fit_on_single_gpu")
     @patch("sagemaker.image_uris.retrieve")
@@ -1556,7 +1501,7 @@ def test_try_fetch_gpu_info_throws(
         self.assertEqual(model_builder._can_fit_on_single_gpu(), False)
 
     @patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_transformers", Mock())
-    @patch("sagemaker.serve.builder.model_builder.ModelBuilder._total_inference_model_size_mib")
+    @patch("sagemaker.serve.builder.model_builder._total_inference_model_size_mib")
     @patch("sagemaker.image_uris.retrieve")
     @patch("sagemaker.djl_inference.model.urllib")
     @patch("sagemaker.djl_inference.model.json")

0 commit comments

Comments (0)