Skip to content

Commit 8e89266

Browse files
committed
Fix: Move accelerate dependency closer to size calculations
1 parent 09fe1c6 commit 8e89266

File tree

1 file changed

+16
-8
lines changed

1 file changed

+16
-8
lines changed

src/sagemaker/serve/builder/model_builder.py

Lines changed: 16 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,6 @@
2020

2121
from pathlib import Path
2222

23-
from accelerate.commands.estimate import estimate_command_parser, gather_data
2423
from sagemaker import Session
2524
from sagemaker.model import Model
2625
from sagemaker.base_predictor import PredictorBase
@@ -72,6 +71,8 @@
7271

7372
MIB_CONVERSION_FACTOR = 0.00000095367431640625
7473
MEMORY_BUFFER_MULTIPLIER = 1.2 # 20% buffer
74+
VERSION_DETECTION_ERROR = "Please install accelerate and transformers for HuggingFace (HF) model " \
75+
"size calculations pip install 'sagemaker[huggingface]'"
7576

7677

7778
# pylint: disable=attribute-defined-outside-init
@@ -726,13 +727,20 @@ def _total_inference_model_size_mib(self):
726727
padding and converts to size MiB. When performing inference, expect
727728
to add up to an additional 20% to the given model size as found by EleutherAI.
728729
"""
729-
dtypes = self.env_vars.get("dtypes", "float32")
730-
parser = estimate_command_parser()
731-
args = parser.parse_args([self.model, "--dtypes", dtypes])
732-
733-
output = gather_data(
734-
args
735-
) # "dtype", "Largest Layer", "Total Size Bytes", "Training using Adam"
730+
try:
731+
import accelerate.commands.estimate.estimate_command_parser
732+
import accelerate.commands.estimate.gather_data
733+
734+
dtypes = self.env_vars.get("dtypes", "float32")
735+
parser = accelerate.commands.estimate.estimate_command_parser.estimate_command_parser()
736+
args = parser.parse_args([self.model, "--dtypes", dtypes])
737+
738+
output = accelerate.commands.estimate.gather_data.gather_data(
739+
args
740+
) # "dtype", "Largest Layer", "Total Size Bytes", "Training using Adam"
741+
except ImportError as e:
742+
logger.warning(VERSION_DETECTION_ERROR)
743+
raise e
736744

737745
if output is None:
738746
raise ValueError(f"Could not get Model size for {self.model}")

0 commit comments

Comments
 (0)