@@ -105,6 +105,7 @@ def test_sklearn_with_all_parameters(
105
105
botocore_resolver .return_value .construct_endpoint .return_value = {"hostname" : ECR_HOSTNAME }
106
106
107
107
processor = SKLearnProcessor (
108
+ s3_prefix = MOCKED_S3_URI ,
108
109
role = ROLE ,
109
110
framework_version = sklearn_version ,
110
111
instance_type = "ml.m4.xlarge" ,
@@ -126,7 +127,7 @@ def test_sklearn_with_all_parameters(
126
127
)
127
128
128
129
processor .run (
129
- code = "/local/path/to/processing_code.py" ,
130
+ entry_point = "/local/path/to/processing_code.py" ,
130
131
inputs = _get_data_inputs_all_parameters (),
131
132
outputs = _get_data_outputs_all_parameters (),
132
133
arguments = ["--drop-columns" , "'SelfEmployed'" ],
@@ -136,7 +137,7 @@ def test_sklearn_with_all_parameters(
136
137
experiment_config = {"ExperimentName" : "AnExperiment" },
137
138
)
138
139
139
- expected_args = _get_expected_args_all_parameters (processor ._current_job_name )
140
+ expected_args = _get_expected_args_all_parameters_modular_code (processor ._current_job_name )
140
141
sklearn_image_uri = (
141
142
"246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-scikit-learn:{}-cpu-py3"
142
143
).format (sklearn_version )
@@ -748,9 +749,9 @@ def _get_data_inputs_all_parameters():
748
749
input_name = "redshift_dataset_definition" ,
749
750
app_managed = True ,
750
751
dataset_definition = DatasetDefinition (
751
- local_path = "/opt/ml/processing/input/dd" ,
752
752
data_distribution_type = "FullyReplicated" ,
753
753
input_mode = "File" ,
754
+ local_path = "/opt/ml/processing/input/dd" ,
754
755
redshift_dataset_definition = RedshiftDatasetDefinition (
755
756
cluster_id = "cluster_id" ,
756
757
database = "database" ,
@@ -768,15 +769,15 @@ def _get_data_inputs_all_parameters():
768
769
input_name = "athena_dataset_definition" ,
769
770
app_managed = True ,
770
771
dataset_definition = DatasetDefinition (
771
- local_path = "/opt/ml/processing/input/dd" ,
772
772
data_distribution_type = "FullyReplicated" ,
773
773
input_mode = "File" ,
774
+ local_path = "/opt/ml/processing/input/dd" ,
774
775
athena_dataset_definition = AthenaDatasetDefinition (
775
776
catalog = "catalog" ,
776
777
database = "database" ,
777
- work_group = "workgroup" ,
778
778
query_string = "query_string" ,
779
779
output_s3_uri = "output_s3_uri" ,
780
+ work_group = "workgroup" ,
780
781
kms_key_id = "kms_key_id" ,
781
782
output_format = "AVRO" ,
782
783
output_compression = "ZLIB" ,
@@ -802,6 +803,147 @@ def _get_data_outputs_all_parameters():
802
803
]
803
804
804
805
806
def _get_expected_args_all_parameters_modular_code(job_name, code_s3_uri=MOCKED_S3_URI):
    """Return the expected processing-job arguments when code is passed as ``entry_point``.

    The ``inputs`` list holds the four data inputs (``my_dataset``, ``s3_input``,
    ``redshift_dataset_definition``, ``athena_dataset_definition``) followed by
    two code inputs, both located under ``{code_s3_uri}/{job_name}/source/``:

    * ``input-5`` -- the ``sourcedir.tar.gz`` payload, placed in
      ``/opt/ml/processing/input/code/payload/``.
    * ``code`` -- the ``runproc.sh`` script, placed in
      ``/opt/ml/processing/input/code``; the expected container entrypoint
      runs that script with ``/bin/bash``.

    Args:
        job_name: Name of the processing job. Used as ``job_name`` in the
            result and as a path segment of the code S3 URIs.
        code_s3_uri: S3 prefix under which the code artifacts are expected.
            Defaults to ``MOCKED_S3_URI``.

    Returns:
        dict: Expected keyword arguments, keyed by ``inputs``,
        ``output_config``, ``experiment_config``, ``job_name``, ``resources``,
        ``stopping_condition``, ``app_specification``, ``environment``,
        ``network_config``, ``role_arn`` and ``tags``.
    """
    # Both code artifacts share the same per-job S3 location. The URI must not
    # contain whitespace between its path segments.
    source_s3_prefix = f"{code_s3_uri}/{job_name}/source"

    return {
        "inputs": [
            {
                "InputName": "my_dataset",
                "AppManaged": False,
                "S3Input": {
                    "S3Uri": "s3://path/to/my/dataset/census.csv",
                    "LocalPath": "/container/path/",
                    "S3DataType": "S3Prefix",
                    "S3InputMode": "File",
                    "S3DataDistributionType": "FullyReplicated",
                    "S3CompressionType": "None",
                },
            },
            {
                "InputName": "s3_input",
                "AppManaged": False,
                "S3Input": {
                    "S3Uri": "s3://path/to/my/dataset/census.csv",
                    "LocalPath": "/container/path/",
                    "S3DataType": "S3Prefix",
                    "S3InputMode": "File",
                    "S3DataDistributionType": "FullyReplicated",
                    "S3CompressionType": "None",
                },
            },
            {
                "InputName": "redshift_dataset_definition",
                "AppManaged": True,
                "DatasetDefinition": {
                    "DataDistributionType": "FullyReplicated",
                    "InputMode": "File",
                    "LocalPath": "/opt/ml/processing/input/dd",
                    "RedshiftDatasetDefinition": {
                        "ClusterId": "cluster_id",
                        "Database": "database",
                        "DbUser": "db_user",
                        "QueryString": "query_string",
                        "ClusterRoleArn": "cluster_role_arn",
                        "OutputS3Uri": "output_s3_uri",
                        "KmsKeyId": "kms_key_id",
                        "OutputFormat": "CSV",
                        "OutputCompression": "SNAPPY",
                    },
                },
            },
            {
                "InputName": "athena_dataset_definition",
                "AppManaged": True,
                "DatasetDefinition": {
                    "DataDistributionType": "FullyReplicated",
                    "InputMode": "File",
                    "LocalPath": "/opt/ml/processing/input/dd",
                    "AthenaDatasetDefinition": {
                        "Catalog": "catalog",
                        "Database": "database",
                        "QueryString": "query_string",
                        "OutputS3Uri": "output_s3_uri",
                        "WorkGroup": "workgroup",
                        "KmsKeyId": "kms_key_id",
                        "OutputFormat": "AVRO",
                        "OutputCompression": "ZLIB",
                    },
                },
            },
            {
                # Packaged source directory uploaded alongside the entry point.
                "InputName": "input-5",
                "AppManaged": False,
                "S3Input": {
                    "S3Uri": f"{source_s3_prefix}/sourcedir.tar.gz",
                    "LocalPath": "/opt/ml/processing/input/code/payload/",
                    "S3DataType": "S3Prefix",
                    "S3InputMode": "File",
                    "S3DataDistributionType": "FullyReplicated",
                    "S3CompressionType": "None",
                },
            },
            {
                # Shell script that the container entrypoint below executes.
                "InputName": "code",
                "AppManaged": False,
                "S3Input": {
                    "S3Uri": f"{source_s3_prefix}/runproc.sh",
                    "LocalPath": "/opt/ml/processing/input/code",
                    "S3DataType": "S3Prefix",
                    "S3InputMode": "File",
                    "S3DataDistributionType": "FullyReplicated",
                    "S3CompressionType": "None",
                },
            },
        ],
        "output_config": {
            "Outputs": [
                {
                    "OutputName": "my_output",
                    "AppManaged": False,
                    "S3Output": {
                        "S3Uri": "s3://uri/",
                        "LocalPath": "/container/path/",
                        "S3UploadMode": "EndOfJob",
                    },
                },
                {
                    "OutputName": "feature_store_output",
                    "AppManaged": True,
                    "FeatureStoreOutput": {"FeatureGroupName": "FeatureGroupName"},
                },
            ],
            "KmsKeyId": "arn:aws:kms:us-west-2:012345678901:key/output-kms-key",
        },
        "experiment_config": {"ExperimentName": "AnExperiment"},
        "job_name": job_name,
        "resources": {
            "ClusterConfig": {
                "InstanceType": "ml.m4.xlarge",
                "InstanceCount": 1,
                "VolumeSizeInGB": 100,
                "VolumeKmsKeyId": "arn:aws:kms:us-west-2:012345678901:key/volume-kms-key",
            }
        },
        "stopping_condition": {"MaxRuntimeInSeconds": 3600},
        "app_specification": {
            "ImageUri": "012345678901.dkr.ecr.us-west-2.amazonaws.com/my-custom-image-uri",
            "ContainerArguments": ["--drop-columns", "'SelfEmployed'"],
            "ContainerEntrypoint": ["/bin/bash", "/opt/ml/processing/input/code/runproc.sh"],
        },
        "environment": {"my_env_variable": "my_env_variable_value"},
        "network_config": {
            "EnableNetworkIsolation": True,
            "EnableInterContainerTrafficEncryption": True,
            "VpcConfig": {
                "SecurityGroupIds": ["my_security_group_id"],
                "Subnets": ["my_subnet_id"],
            },
        },
        "role_arn": ROLE,
        "tags": [{"Key": "my-tag", "Value": "my-tag-value"}],
    }
945
+
946
+
805
947
def _get_expected_args_all_parameters (job_name ):
806
948
return {
807
949
"inputs" : [
0 commit comments