@@ -4394,3 +4394,51 @@ def test_insert_invalid_source_code_args():
4394
4394
assert (
4395
4395
"The entry_point should not be a pipeline variable " "when source_dir is a local path"
4396
4396
) in str (err .value )
4397
+
4398
+
4399
+ @patch ("time.time" , return_value = TIME )
4400
+ @patch ("sagemaker.estimator.tar_and_upload_dir" )
4401
+ @patch ("sagemaker.model.Model._upload_code" )
4402
+ def test_script_mode_estimator_escapes_hyperparameters_as_json (
4403
+ patched_upload_code , patched_tar_and_upload_dir , sagemaker_session
4404
+ ):
4405
+ patched_tar_and_upload_dir .return_value = UploadedCode (
4406
+ s3_prefix = "s3://%s/%s" % ("bucket" , "key" ), script_name = "script_name"
4407
+ )
4408
+ sagemaker_session .boto_region_name = REGION
4409
+
4410
+ instance_type = "ml.p2.xlarge"
4411
+ instance_count = 1
4412
+
4413
+ training_data_uri = "s3://bucket/mydata"
4414
+
4415
+ jumpstart_source_dir = f"s3://{ list (JUMPSTART_BUCKET_NAME_SET )[0 ]} /source_dirs/source.tar.gz"
4416
+
4417
+ hyperparameters = {
4418
+ "int_hyperparam" : 1 ,
4419
+ "string_hyperparam" : "hello" ,
4420
+ "stringified_numeric_hyperparam" : "44" ,
4421
+ "float_hyperparam" : 1.234 ,
4422
+ }
4423
+
4424
+ generic_estimator = Estimator (
4425
+ entry_point = SCRIPT_PATH ,
4426
+ role = ROLE ,
4427
+ region = REGION ,
4428
+ sagemaker_session = sagemaker_session ,
4429
+ instance_count = instance_count ,
4430
+ instance_type = instance_type ,
4431
+ source_dir = jumpstart_source_dir ,
4432
+ image_uri = IMAGE_URI ,
4433
+ model_uri = MODEL_DATA ,
4434
+ hyperparameters = hyperparameters ,
4435
+ )
4436
+ generic_estimator .fit (training_data_uri )
4437
+
4438
+ formatted_hyperparams = EstimatorBase ._json_encode_hyperparameters (hyperparameters )
4439
+
4440
+ assert (
4441
+ set (formatted_hyperparams .items ())
4442
+ - set (sagemaker_session .train .call_args_list [0 ][1 ]["hyperparameters" ].items ())
4443
+ == set ()
4444
+ )
0 commit comments