Skip to content

Commit f95df5b

Browse files
author
Basil Beirouti
committed
added unit test
1 parent 89493cf commit f95df5b

File tree

1 file changed

+92
-0
lines changed

1 file changed

+92
-0
lines changed

tests/unit/sagemaker/model/test_model.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -687,3 +687,95 @@ def test_script_mode_model_uses_proper_sagemaker_submit_dir(repack_model, sagema
687687
]
688688
== "/opt/ml/model/code"
689689
)
690+
691+
692+
model_name = model_package_name = "my-flower-detection-model"
693+
model_description = "This model accepts petal length, petal width, sepal length, sepal width and predicts whether flower is of type setosa, versicolor, or virginica"
694+
695+
supported_realtime_inference_instance_types = ["ml.m4.xlarge"]
696+
supported_batch_transform_instance_types = ["ml.m4.xlarge"]
697+
698+
supported_content_types = ["text/csv", "application/json", "application/jsonlines"]
699+
supported_response_MIME_types = ["application/json", "text/csv", "application/jsonlines"]
700+
701+
validation_file_name = "input.csv"
702+
validation_input_path = "s3://" + BUCKET_NAME + "/validation-input-csv/"
703+
validation_output_path = "s3://" + BUCKET_NAME + "/validation-output-csv/"
704+
705+
ValidationSpecification = {
706+
"ValidationRole": "some_role",
707+
"ValidationProfiles": [
708+
{
709+
"ProfileName": "Validation-test",
710+
"TransformJobDefinition": {
711+
"BatchStrategy": "SingleRecord",
712+
"TransformInput": {
713+
"DataSource": {
714+
"S3DataSource": {
715+
"S3DataType": "S3Prefix",
716+
"S3Uri": validation_input_path,
717+
}
718+
},
719+
"ContentType": supported_content_types[0],
720+
},
721+
"TransformOutput": {
722+
"S3OutputPath": validation_output_path,
723+
},
724+
"TransformResources": {
725+
"InstanceType": supported_batch_transform_instance_types[0],
726+
"InstanceCount": 1,
727+
},
728+
},
729+
},
730+
],
731+
}
732+
733+
import pdb
734+
735+
@patch("sagemaker.get_model_package_args")
736+
def test_call_to_get_model_package_args(get_model_package_args, sagemaker_session):
737+
738+
source_dir = "s3://blah/blah/blah"
739+
t = Model(
740+
entry_point=ENTRY_POINT_INFERENCE,
741+
role=ROLE,
742+
sagemaker_session=sagemaker_session,
743+
source_dir=source_dir,
744+
image_uri=IMAGE_URI,
745+
model_data=MODEL_DATA,
746+
)
747+
748+
t.register(
749+
supported_content_types,
750+
supported_response_MIME_types,
751+
supported_realtime_inference_instance_types,
752+
supported_batch_transform_instance_types,
753+
marketplace_cert=True,
754+
description=model_description,
755+
model_package_name=model_package_name,
756+
validation_specification=ValidationSpecification,
757+
758+
)
759+
760+
# check that the kwarg validation_specification was passed to the internal method 'get_model_package_args'
761+
assert(
762+
"validation_specification" in get_model_package_args.call_args_list[0][1],
763+
"validation_specification kwarg was not passed to get_model_package_args"
764+
)
765+
766+
# check that the kwarg validation_specification is identical to the one passed into the method 'register'
767+
assert(
768+
ValidationSpecification == get_model_package_args.call_args_list[0][1]["validation_specification"],
769+
"""ValidationSpecification from model.register method is not identical to validation_spec from
770+
get_model_package_args"""
771+
)
772+
773+
774+
775+
776+
777+
778+
779+
780+
781+

0 commit comments

Comments
 (0)