@@ -943,3 +943,75 @@ def test_algorithm_no_required_hyperparameters(session):
943
943
train_instance_count = 1 ,
944
944
sagemaker_session = session ,
945
945
)
946
+
947
+
948
+ def test_algorithm_attach_from_hyperparameter_tuning ():
949
+ session = Mock ()
950
+ job_name = "training-job-that-is-part-of-a-tuning-job"
951
+ algo_arn = "arn:aws:sagemaker:us-east-2:000000000000:algorithm/scikit-decision-trees"
952
+ role_arn = "arn:aws:iam::123412341234:role/SageMakerRole"
953
+ instance_count = 1
954
+ instance_type = "ml.m4.xlarge"
955
+ train_volume_size = 30
956
+ input_mode = "File"
957
+
958
+ session .sagemaker_client .list_tags .return_value = {"Tags" : []}
959
+ session .sagemaker_client .describe_algorithm .return_value = DESCRIBE_ALGORITHM_RESPONSE
960
+ session .sagemaker_client .describe_training_job .return_value = {
961
+ "TrainingJobName" : job_name ,
962
+ "TrainingJobArn" : "arn:aws:sagemaker:us-east-2:123412341234:training-job/%s" % job_name ,
963
+ "TuningJobArn" : "arn:aws:sagemaker:us-east-2:123412341234:hyper-parameter-tuning-job/%s"
964
+ % job_name ,
965
+ "ModelArtifacts" : {
966
+ "S3ModelArtifacts" : "s3://sagemaker-us-east-2-123412341234/output/model.tar.gz"
967
+ },
968
+ "TrainingJobOutput" : {
969
+ "S3TrainingJobOutput" : "s3://sagemaker-us-east-2-123412341234/output/output.tar.gz"
970
+ },
971
+ "TrainingJobStatus" : "Succeeded" ,
972
+ "HyperParameters" : {
973
+ "_tuning_objective_metric" : "validation:accuracy" ,
974
+ "max_leaf_nodes" : 1 ,
975
+ "free_text_hp1" : "foo" ,
976
+ },
977
+ "AlgorithmSpecification" : {"AlgorithmName" : algo_arn , "TrainingInputMode" : input_mode },
978
+ "MetricDefinitions" : [
979
+ {"Name" : "validation:accuracy" , "Regex" : "validation-accuracy: (\\ S+)" }
980
+ ],
981
+ "RoleArn" : role_arn ,
982
+ "InputDataConfig" : [
983
+ {
984
+ "ChannelName" : "training" ,
985
+ "DataSource" : {
986
+ "S3DataSource" : {
987
+ "S3DataType" : "S3Prefix" ,
988
+ "S3Uri" : "s3://sagemaker-us-east-2-123412341234/input/training.csv" ,
989
+ "S3DataDistributionType" : "FullyReplicated" ,
990
+ }
991
+ },
992
+ "CompressionType" : "None" ,
993
+ "RecordWrapperType" : "None" ,
994
+ }
995
+ ],
996
+ "OutputDataConfig" : {
997
+ "KmsKeyId" : "" ,
998
+ "S3OutputPath" : "s3://sagemaker-us-east-2-123412341234/output" ,
999
+ "RemoveJobNameFromS3OutputPath" : False ,
1000
+ },
1001
+ "ResourceConfig" : {
1002
+ "InstanceType" : instance_type ,
1003
+ "InstanceCount" : instance_count ,
1004
+ "VolumeSizeInGB" : train_volume_size ,
1005
+ },
1006
+ "StoppingCondition" : {"MaxRuntimeInSeconds" : 86400 },
1007
+ }
1008
+
1009
+ estimator = AlgorithmEstimator .attach (job_name , sagemaker_session = session )
1010
+ assert estimator .hyperparameters () == {"max_leaf_nodes" : 1 , "free_text_hp1" : "foo" }
1011
+ assert estimator .algorithm_arn == algo_arn
1012
+ assert estimator .role == role_arn
1013
+ assert estimator .train_instance_count == instance_count
1014
+ assert estimator .train_instance_type == instance_type
1015
+ assert estimator .train_volume_size == train_volume_size
1016
+ assert estimator .input_mode == input_mode
1017
+ assert estimator .sagemaker_session == session
0 commit comments