Skip to content

Commit 6c9b085

Browse files
committed
fix: estimator hyperparameters in script mode
1 parent b6f6e76 commit 6c9b085

File tree

2 files changed

+58
-2
lines changed

2 files changed

+58
-2
lines changed

src/sagemaker/estimator.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2629,6 +2629,8 @@ def __init__(
26292629
**kwargs,
26302630
)
26312631

2632+
self.set_hyperparameters(**self._hyperparameters)
2633+
26322634
def training_image_uri(self):
26332635
"""Returns the docker image to use for training.
26342636
@@ -2644,9 +2646,15 @@ def set_hyperparameters(self, **kwargs):
26442646
training code on SageMaker. For convenience, this accepts other types
26452647
for keys and values, but ``str()`` will be called to convert them before
26462648
training.
2649+
2650+
If a source directory is specified, this method escapes the dict argument as JSON,
2651+
and updates the private hyperparameter attribute.
26472652
"""
2648-
for k, v in kwargs.items():
2649-
self._hyperparameters[k] = v
2653+
if self.source_dir:
2654+
self._hyperparameters.update(EstimatorBase._json_encode_hyperparameters(kwargs))
2655+
else:
2656+
for k, v in kwargs.items():
2657+
self._hyperparameters[k] = v
26502658

26512659
def hyperparameters(self):
26522660
"""Returns the hyperparameters as a dictionary to use for training.

tests/unit/test_estimator.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4394,3 +4394,51 @@ def test_insert_invalid_source_code_args():
43944394
assert (
43954395
"The entry_point should not be a pipeline variable " "when source_dir is a local path"
43964396
) in str(err.value)
4397+
4398+
4399+
@patch("time.time", return_value=TIME)
@patch("sagemaker.estimator.tar_and_upload_dir")
@patch("sagemaker.model.Model._upload_code")
def test_script_mode_estimator_escapes_hyperparameters_as_json(
    patched_upload_code, patched_tar_and_upload_dir, sagemaker_session
):
    """Check that hyperparameters reach ``train`` in their JSON-encoded form.

    An ``Estimator`` is built with an S3 ``source_dir`` and a dict of mixed-type
    hyperparameters, then fitted. Every key/value pair produced by
    ``EstimatorBase._json_encode_hyperparameters`` for that dict must appear in
    the ``hyperparameters`` keyword argument of the first ``train`` call on the
    session. Extra entries in the ``train`` call are tolerated; only the encoded
    pairs are required to be present.
    """
    # NOTE(review): three ``@patch`` decorators are stacked but only two mock
    # parameters are declared, so the third injected mock (``time.time``) may be
    # bound to ``sagemaker_session`` instead of the fixture -- confirm against
    # how pytest resolves arguments for ``mock.patch``-decorated tests.
    sagemaker_session.boto_region_name = REGION
    patched_tar_and_upload_dir.return_value = UploadedCode(
        s3_prefix="s3://%s/%s" % ("bucket", "key"), script_name="script_name"
    )

    # Raw, un-encoded values covering int, str, numeric-looking str and float.
    raw_hyperparameters = {
        "int_hyperparam": 1,
        "string_hyperparam": "hello",
        "stringified_numeric_hyperparam": "44",
        "float_hyperparam": 1.234,
    }

    # An S3 tarball location is used as the source directory.
    jumpstart_source_dir = f"s3://{list(JUMPSTART_BUCKET_NAME_SET)[0]}/source_dirs/source.tar.gz"

    estimator_under_test = Estimator(
        entry_point=SCRIPT_PATH,
        role=ROLE,
        region=REGION,
        sagemaker_session=sagemaker_session,
        instance_count=1,
        instance_type="ml.p2.xlarge",
        source_dir=jumpstart_source_dir,
        image_uri=IMAGE_URI,
        model_uri=MODEL_DATA,
        hyperparameters=raw_hyperparameters,
    )
    estimator_under_test.fit("s3://bucket/mydata")

    expected_pairs = set(EstimatorBase._json_encode_hyperparameters(raw_hyperparameters).items())

    # Index 1 of a recorded call holds its keyword arguments.
    first_train_kwargs = sagemaker_session.train.call_args_list[0][1]
    submitted_pairs = set(first_train_kwargs["hyperparameters"].items())

    # Same check as "expected minus submitted is empty".
    assert expected_pairs.issubset(submitted_pairs)

0 commit comments

Comments
 (0)