Skip to content

Commit 2432b26

Browse files
authored
fix: excessive jumpstart instance type logging (#4256)
1 parent 41cd2f6 commit 2432b26

File tree

4 files changed

+33
-6
lines changed

4 files changed

+33
-6
lines changed

src/sagemaker/jumpstart/factory/estimator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,7 @@ def get_deploy_kwargs(
336336
tolerate_vulnerable_model=tolerate_vulnerable_model,
337337
tolerate_deprecated_model=tolerate_deprecated_model,
338338
training_instance_type=training_instance_type,
339+
disable_instance_type_logging=True,
339340
)
340341

341342
estimator_deploy_kwargs: JumpStartEstimatorDeployKwargs = JumpStartEstimatorDeployKwargs(

src/sagemaker/jumpstart/factory/model.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,9 @@ def _add_vulnerable_and_deprecated_status_to_kwargs(
171171
return kwargs
172172

173173

174-
def _add_instance_type_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelInitKwargs:
174+
def _add_instance_type_to_kwargs(
175+
kwargs: JumpStartModelInitKwargs, disable_instance_type_logging: bool = False
176+
) -> JumpStartModelInitKwargs:
175177
"""Sets instance type based on default or override, returns full kwargs."""
176178

177179
orig_instance_type = kwargs.instance_type
@@ -187,7 +189,7 @@ def _add_instance_type_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartM
187189
training_instance_type=kwargs.training_instance_type,
188190
)
189191

190-
if orig_instance_type is None:
192+
if not disable_instance_type_logging and orig_instance_type is None:
191193
JUMPSTART_LOGGER.info(
192194
"No instance type selected for inference hosting endpoint. Defaulting to %s.",
193195
kwargs.instance_type,
@@ -551,9 +553,7 @@ def get_deploy_kwargs(
551553

552554
deploy_kwargs = _add_endpoint_name_to_kwargs(kwargs=deploy_kwargs)
553555

554-
deploy_kwargs = _add_instance_type_to_kwargs(
555-
kwargs=deploy_kwargs,
556-
)
556+
deploy_kwargs = _add_instance_type_to_kwargs(kwargs=deploy_kwargs)
557557

558558
deploy_kwargs.initial_instance_count = initial_instance_count or 1
559559

@@ -677,6 +677,7 @@ def get_init_kwargs(
677677
git_config: Optional[Dict[str, str]] = None,
678678
model_package_arn: Optional[str] = None,
679679
training_instance_type: Optional[str] = None,
680+
disable_instance_type_logging: bool = False,
680681
resources: Optional[ResourceRequirements] = None,
681682
) -> JumpStartModelInitKwargs:
682683
"""Returns kwargs required to instantiate `sagemaker.estimator.Model` object."""
@@ -720,7 +721,7 @@ def get_init_kwargs(
720721
model_init_kwargs = _add_model_name_to_kwargs(kwargs=model_init_kwargs)
721722

722723
model_init_kwargs = _add_instance_type_to_kwargs(
723-
kwargs=model_init_kwargs,
724+
kwargs=model_init_kwargs, disable_instance_type_logging=disable_instance_type_logging
724725
)
725726

726727
model_init_kwargs = _add_image_uri_to_kwargs(kwargs=model_init_kwargs)

tests/unit/sagemaker/jumpstart/estimator/test_estimator.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@
5757

5858

5959
class EstimatorTest(unittest.TestCase):
60+
@mock.patch("sagemaker.jumpstart.factory.estimator.JUMPSTART_LOGGER")
61+
@mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_LOGGER")
6062
@mock.patch("sagemaker.utils.sagemaker_timestamp")
6163
@mock.patch("sagemaker.jumpstart.estimator.is_valid_model_id")
6264
@mock.patch("sagemaker.jumpstart.factory.model.Session")
@@ -77,6 +79,8 @@ def test_non_prepacked(
7779
mock_session_model: mock.Mock,
7880
mock_is_valid_model_id: mock.Mock,
7981
mock_sagemaker_timestamp: mock.Mock,
82+
mock_jumpstart_model_factory_logger: mock.Mock,
83+
mock_jumpstart_estimator_factory_logger: mock.Mock,
8084
):
8185
mock_is_valid_model_id.return_value = True
8286

@@ -94,6 +98,9 @@ def test_non_prepacked(
9498
estimator = JumpStartEstimator(
9599
model_id=model_id,
96100
)
101+
mock_jumpstart_estimator_factory_logger.info.assert_called_once_with(
102+
"No instance type selected for training job. Defaulting to %s.", "ml.p3.2xlarge"
103+
)
97104

98105
mock_estimator_init.assert_called_once_with(
99106
instance_type="ml.p3.2xlarge",
@@ -131,13 +138,22 @@ def test_non_prepacked(
131138
f"{get_training_dataset_for_model_and_version(model_id, model_version)}",
132139
}
133140

141+
mock_jumpstart_estimator_factory_logger.info.reset_mock()
134142
estimator.fit(channels)
143+
mock_jumpstart_estimator_factory_logger.info.assert_not_called()
135144

136145
mock_estimator_fit.assert_called_once_with(
137146
inputs=channels, wait=True, job_name="blahblahblah-9876"
138147
)
139148

149+
mock_jumpstart_model_factory_logger.info.reset_mock()
150+
mock_jumpstart_estimator_factory_logger.info.reset_mock()
140151
estimator.deploy()
152+
mock_jumpstart_model_factory_logger.info.assert_called_once_with(
153+
"No instance type selected for inference hosting endpoint. Defaulting to %s.",
154+
"ml.p2.xlarge",
155+
)
156+
mock_jumpstart_estimator_factory_logger.info.assert_not_called()
141157

142158
mock_estimator_deploy.assert_called_once_with(
143159
instance_type="ml.p2.xlarge",

tests/unit/sagemaker/jumpstart/model/test_model.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ class ModelTest(unittest.TestCase):
5151

5252
mock_session_empty_config = MagicMock(sagemaker_config={})
5353

54+
@mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_LOGGER")
5455
@mock.patch("sagemaker.utils.sagemaker_timestamp")
5556
@mock.patch("sagemaker.jumpstart.model.is_valid_model_id")
5657
@mock.patch("sagemaker.jumpstart.factory.model.Session")
@@ -66,6 +67,7 @@ def test_non_prepacked(
6667
mock_session: mock.Mock,
6768
mock_is_valid_model_id: mock.Mock,
6869
mock_sagemaker_timestamp: mock.Mock,
70+
mock_jumpstart_model_factory_logger: mock.Mock,
6971
):
7072
mock_model_deploy.return_value = default_predictor
7173

@@ -78,9 +80,14 @@ def test_non_prepacked(
7880

7981
mock_session.return_value = sagemaker_session
8082

83+
mock_jumpstart_model_factory_logger.info.reset_mock()
8184
model = JumpStartModel(
8285
model_id=model_id,
8386
)
87+
mock_jumpstart_model_factory_logger.info.assert_called_once_with(
88+
"No " "instance type selected for inference hosting endpoint. " "Defaulting to %s.",
89+
"ml.p2.xlarge",
90+
)
8491

8592
mock_model_init.assert_called_once_with(
8693
image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/"
@@ -104,7 +111,9 @@ def test_non_prepacked(
104111
name="blahblahblah-7777",
105112
)
106113

114+
mock_jumpstart_model_factory_logger.info.reset_mock()
107115
model.deploy()
116+
mock_jumpstart_model_factory_logger.info.assert_not_called()
108117

109118
mock_model_deploy.assert_called_once_with(
110119
initial_instance_count=1,

0 commit comments

Comments
 (0)