|
20 | 20 |
|
21 | 21 | from pathlib import Path
|
22 | 22 |
|
23 |
| -from accelerate.commands.estimate import estimate_command_parser, gather_data |
24 | 23 | from sagemaker import Session
|
25 | 24 | from sagemaker.model import Model
|
26 | 25 | from sagemaker.base_predictor import PredictorBase
|
|
72 | 71 |
|
73 | 72 | MIB_CONVERSION_FACTOR = 0.00000095367431640625
|
74 | 73 | MEMORY_BUFFER_MULTIPLIER = 1.2 # 20% buffer
|
| 74 | +VERSION_DETECTION_ERROR = "Please install accelerate and transformers for HuggingFace (HF) model " \ |
| 75 | + "size calculations pip install 'sagemaker[huggingface]'" |
75 | 76 |
|
76 | 77 |
|
77 | 78 | # pylint: disable=attribute-defined-outside-init
|
@@ -726,13 +727,20 @@ def _total_inference_model_size_mib(self):
|
726 | 727 | padding and converts to size MiB. When performing inference, expect
|
727 | 728 | to add up to an additional 20% to the given model size as found by EleutherAI.
|
728 | 729 | """
|
729 |
| - dtypes = self.env_vars.get("dtypes", "float32") |
730 |
| - parser = estimate_command_parser() |
731 |
| - args = parser.parse_args([self.model, "--dtypes", dtypes]) |
732 |
| - |
733 |
| - output = gather_data( |
734 |
| - args |
735 |
| - ) # "dtype", "Largest Layer", "Total Size Bytes", "Training using Adam" |
| 730 | + try: |
| 731 | + import accelerate.commands.estimate.estimate_command_parser |
| 732 | + import accelerate.commands.estimate.gather_data |
| 733 | + |
| 734 | + dtypes = self.env_vars.get("dtypes", "float32") |
| 735 | + parser = accelerate.commands.estimate.estimate_command_parser.estimate_command_parser() |
| 736 | + args = parser.parse_args([self.model, "--dtypes", dtypes]) |
| 737 | + |
| 738 | + output = accelerate.commands.estimate.gather_data.gather_data( |
| 739 | + args |
| 740 | + ) # "dtype", "Largest Layer", "Total Size Bytes", "Training using Adam" |
| 741 | + except ImportError as e: |
| 742 | + logger.warning(VERSION_DETECTION_ERROR) |
| 743 | + raise e |
736 | 744 |
|
737 | 745 | if output is None:
|
738 | 746 | raise ValueError(f"Could not get Model size for {self.model}")
|
|
0 commit comments