@@ -291,6 +291,7 @@ def test_s3_input_all_arguments():
291
291
}
292
292
293
293
COMPLETED_DESCRIBE_JOB_RESULT = dict (DEFAULT_EXPECTED_TRAIN_JOB_ARGS )
294
+ COMPLETED_DESCRIBE_JOB_RESULT .update ({'TrainingJobArn' : 'arn:aws:sagemaker:us-west-2:336:training-job/' + JOB_NAME })
294
295
COMPLETED_DESCRIBE_JOB_RESULT .update ({'TrainingJobStatus' : 'Completed' })
295
296
COMPLETED_DESCRIBE_JOB_RESULT .update (
296
297
{'ModelArtifacts' : {
@@ -861,18 +862,21 @@ def test_create_model_failure(expand_container_def, sagemaker_session):
861
862
def test_create_model_from_job (sagemaker_session ):
862
863
ims = sagemaker_session
863
864
ims .sagemaker_client .describe_training_job .return_value = COMPLETED_DESCRIBE_JOB_RESULT
865
+ ims .sagemaker_client .list_tags .return_value = {'Tags' : TAGS }
864
866
ims .create_model_from_job (JOB_NAME )
865
867
866
868
assert call (TrainingJobName = JOB_NAME ) in ims .sagemaker_client .describe_training_job .call_args_list
867
869
ims .sagemaker_client .create_model .assert_called_with (ExecutionRoleArn = EXPANDED_ROLE ,
868
870
ModelName = JOB_NAME ,
869
871
PrimaryContainer = PRIMARY_CONTAINER ,
870
- VpcConfig = VPC_CONFIG )
872
+ VpcConfig = VPC_CONFIG ,
873
+ Tags = TAGS )
871
874
872
875
873
876
def test_create_model_from_job_with_image (sagemaker_session ):
874
877
ims = sagemaker_session
875
878
ims .sagemaker_client .describe_training_job .return_value = COMPLETED_DESCRIBE_JOB_RESULT
879
+ ims .sagemaker_client .list_tags .return_value = {'Tags' : TAGS }
876
880
ims .create_model_from_job (JOB_NAME , primary_container_image = 'some-image' )
877
881
[create_model_call ] = ims .sagemaker_client .create_model .call_args_list
878
882
assert dict (create_model_call [1 ]['PrimaryContainer' ])['Image' ] == 'some-image'
@@ -881,6 +885,7 @@ def test_create_model_from_job_with_image(sagemaker_session):
881
885
def test_create_model_from_job_with_container_def (sagemaker_session ):
882
886
ims = sagemaker_session
883
887
ims .sagemaker_client .describe_training_job .return_value = COMPLETED_DESCRIBE_JOB_RESULT
888
+ ims .sagemaker_client .list_tags .return_value = {'Tags' : TAGS }
884
889
ims .create_model_from_job (JOB_NAME , primary_container_image = 'some-image' , model_data_url = 'some-data' ,
885
890
env = {'a' : 'b' })
886
891
[create_model_call ] = ims .sagemaker_client .create_model .call_args_list
@@ -895,6 +900,7 @@ def test_create_model_from_job_with_vpc_config_override(sagemaker_session):
895
900
896
901
ims = sagemaker_session
897
902
ims .sagemaker_client .describe_training_job .return_value = COMPLETED_DESCRIBE_JOB_RESULT
903
+ ims .sagemaker_client .list_tags .return_value = {'Tags' : TAGS }
898
904
ims .create_model_from_job (JOB_NAME , vpc_config_override = vpc_config_override )
899
905
assert ims .sagemaker_client .create_model .call_args [1 ]['VpcConfig' ] == vpc_config_override
900
906
0 commit comments