Skip to content

Commit 5783194

Browse files
committed
Fix repr
1 parent 9d07415 commit 5783194

File tree

4 files changed

+19
-16
lines changed

4 files changed

+19
-16
lines changed

src/sagemaker/serve/builder/model_builder.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
from pathlib import Path
2222

23+
from accelerate.commands.estimate import estimate_command_parser, gather_data
2324
from sagemaker import Session
2425
from sagemaker.model import Model
2526
from sagemaker.base_predictor import PredictorBase
@@ -730,14 +731,11 @@ def _total_inference_model_size_mib(self):
730731
to add up to an additional 20% to the given model size as found by EleutherAI.
731732
"""
732733
try:
733-
import accelerate.commands.estimate.estimate_command_parser as estimate_parser
734-
import accelerate.commands.estimate.gather_data as estimate_gather
735-
736734
dtypes = self.env_vars.get("dtypes", "float32")
737-
parser = estimate_parser()
735+
parser = estimate_command_parser()
738736
args = parser.parse_args([self.model, "--dtypes", dtypes])
739737

740-
output = estimate_gather(
738+
output = gather_data(
741739
args
742740
) # "dtype", "Largest Layer", "Total Size Bytes", "Training using Adam"
743741
except ImportError as e:

src/sagemaker/serve/builder/schema_builder.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -208,13 +208,20 @@ def _get_inverse(self, obj):
208208

209209
def __repr__(self):
210210
"""Placeholder docstring"""
211-
return (
212-
f"SchemaBuilder(\n"
213-
f"input_serializer={self.input_serializer}\n"
214-
f"output_serializer={self.output_serializer}\n"
215-
f"input_deserializer={self.input_deserializer._deserializer}\n"
216-
f"output_deserializer={self.output_deserializer._deserializer})"
217-
)
211+
if hasattr(self, "input_serializer") and hasattr(self, "output_serializer"):
212+
return (
213+
f"SchemaBuilder(\n"
214+
f"input_serializer={self.input_serializer}\n"
215+
f"output_serializer={self.output_serializer}\n"
216+
f"input_deserializer={self.input_deserializer._deserializer}\n"
217+
f"output_deserializer={self.output_deserializer._deserializer})"
218+
)
219+
elif hasattr(self, "custom_input_translator") and hasattr(self, "custom_output_translator"):
220+
return (
221+
f"SchemaBuilder(\n"
222+
f"custom_input_translator={self.custom_input_translator}\n"
223+
f"custom_output_translator={self.custom_output_translator}\n"
224+
)
218225

219226
def generate_marshalling_map(self) -> dict:
220227
"""Generate marshalling map for the schema builder"""

tests/integ/sagemaker/serve/test_serve_pt_happy.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,6 @@ def model_builder(request):
181181
# ), f"{caught_ex} was thrown when running pytorch squeezenet local container test"
182182

183183

184-
@pytest.mark.skip(reason="Failing test. Fix is pending.")
185184
@pytest.mark.skipif(
186185
PYTHON_VERSION_IS_NOT_310, # or NOT_RUNNING_ON_INF_EXP_DEV_PIPELINE,
187186
reason="The goal of these test are to test the serving components of our feature",

tests/unit/sagemaker/serve/builder/test_model_builder.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1343,9 +1343,8 @@ def test_build_for_transformers_happy_case_with_valid_gpu_fallback(
13431343
self.assertEqual(model_builder._can_fit_on_single_gpu(), True)
13441344

13451345
@patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_transformers", Mock())
1346-
@patch("sagemaker.serve.builder.model_builder.accelerate.commands.estimate"
1347-
".estimate_command_parser")
1348-
@patch("sagemaker.serve.builder.model_builder.accelerate.commands.estimate.gather_data")
1346+
@patch("sagemaker.serve.builder.model_builder.estimate_command_parser")
1347+
@patch("sagemaker.serve.builder.model_builder.gather_data")
13491348
@patch("sagemaker.image_uris.retrieve")
13501349
@patch("sagemaker.djl_inference.model.urllib")
13511350
@patch("sagemaker.djl_inference.model.json")

0 commit comments

Comments
 (0)