@@ -687,3 +687,95 @@ def test_script_mode_model_uses_proper_sagemaker_submit_dir(repack_model, sagema
687
687
]
688
688
== "/opt/ml/model/code"
689
689
)
690
+
691
+
692
+ model_name = model_package_name = "my-flower-detection-model"
693
+ model_description = "This model accepts petal length, petal width, sepal length, sepal width and predicts whether flower is of type setosa, versicolor, or virginica"
694
+
695
+ supported_realtime_inference_instance_types = ["ml.m4.xlarge" ]
696
+ supported_batch_transform_instance_types = ["ml.m4.xlarge" ]
697
+
698
+ supported_content_types = ["text/csv" , "application/json" , "application/jsonlines" ]
699
+ supported_response_MIME_types = ["application/json" , "text/csv" , "application/jsonlines" ]
700
+
701
+ validation_file_name = "input.csv"
702
+ validation_input_path = "s3://" + BUCKET_NAME + "/validation-input-csv/"
703
+ validation_output_path = "s3://" + BUCKET_NAME + "/validation-output-csv/"
704
+
705
+ ValidationSpecification = {
706
+ "ValidationRole" : "some_role" ,
707
+ "ValidationProfiles" : [
708
+ {
709
+ "ProfileName" : "Validation-test" ,
710
+ "TransformJobDefinition" : {
711
+ "BatchStrategy" : "SingleRecord" ,
712
+ "TransformInput" : {
713
+ "DataSource" : {
714
+ "S3DataSource" : {
715
+ "S3DataType" : "S3Prefix" ,
716
+ "S3Uri" : validation_input_path ,
717
+ }
718
+ },
719
+ "ContentType" : supported_content_types [0 ],
720
+ },
721
+ "TransformOutput" : {
722
+ "S3OutputPath" : validation_output_path ,
723
+ },
724
+ "TransformResources" : {
725
+ "InstanceType" : supported_batch_transform_instance_types [0 ],
726
+ "InstanceCount" : 1 ,
727
+ },
728
+ },
729
+ },
730
+ ],
731
+ }
732
+
733
+ import pdb
734
+
735
+ @patch ("sagemaker.get_model_package_args" )
736
+ def test_call_to_get_model_package_args (get_model_package_args , sagemaker_session ):
737
+
738
+ source_dir = "s3://blah/blah/blah"
739
+ t = Model (
740
+ entry_point = ENTRY_POINT_INFERENCE ,
741
+ role = ROLE ,
742
+ sagemaker_session = sagemaker_session ,
743
+ source_dir = source_dir ,
744
+ image_uri = IMAGE_URI ,
745
+ model_data = MODEL_DATA ,
746
+ )
747
+
748
+ t .register (
749
+ supported_content_types ,
750
+ supported_response_MIME_types ,
751
+ supported_realtime_inference_instance_types ,
752
+ supported_batch_transform_instance_types ,
753
+ marketplace_cert = True ,
754
+ description = model_description ,
755
+ model_package_name = model_package_name ,
756
+ validation_specification = ValidationSpecification ,
757
+
758
+ )
759
+
760
+ # check that the kwarg validation_specification was passed to the internal method 'get_model_package_args'
761
+ assert (
762
+ "validation_specification" in get_model_package_args .call_args_list [0 ][1 ],
763
+ "validation_specification kwarg was not passed to get_model_package_args"
764
+ )
765
+
766
+ # check that the kwarg validation_specification is identical to the one passed into the method 'register'
767
+ assert (
768
+ ValidationSpecification == get_model_package_args .call_args_list [0 ][1 ]["validation_specification" ],
769
+ """ValidationSpecification from model.register method is not identical to validation_spec from
770
+ get_model_package_args"""
771
+ )
772
+
773
+
774
+
775
+
776
+
777
+
778
+
779
+
780
+
781
+
0 commit comments