@@ -105,6 +105,7 @@ def test_sklearn_with_all_parameters(
105
105
botocore_resolver .return_value .construct_endpoint .return_value = {"hostname" : ECR_HOSTNAME }
106
106
107
107
processor = SKLearnProcessor (
108
+ s3_prefix = MOCKED_S3_URI ,
108
109
role = ROLE ,
109
110
framework_version = sklearn_version ,
110
111
instance_type = "ml.m4.xlarge" ,
@@ -126,7 +127,7 @@ def test_sklearn_with_all_parameters(
126
127
)
127
128
128
129
processor .run (
129
- code = "/local/path/to/processing_code.py" ,
130
+ entry_point = "/local/path/to/processing_code.py" ,
130
131
inputs = _get_data_inputs_all_parameters (),
131
132
outputs = _get_data_outputs_all_parameters (),
132
133
arguments = ["--drop-columns" , "'SelfEmployed'" ],
@@ -136,7 +137,7 @@ def test_sklearn_with_all_parameters(
136
137
experiment_config = {"ExperimentName" : "AnExperiment" },
137
138
)
138
139
139
- expected_args = _get_expected_args_all_parameters (processor ._current_job_name )
140
+ expected_args = _get_expected_args_all_parameters_modular_code (processor ._current_job_name )
140
141
sklearn_image_uri = (
141
142
"246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-scikit-learn:{}-cpu-py3"
142
143
).format (sklearn_version )
@@ -154,6 +155,7 @@ def test_sklearn_with_all_parameters_via_run_args(
154
155
botocore_resolver .return_value .construct_endpoint .return_value = {"hostname" : ECR_HOSTNAME }
155
156
156
157
processor = SKLearnProcessor (
158
+ s3_prefix = MOCKED_S3_URI ,
157
159
role = ROLE ,
158
160
framework_version = sklearn_version ,
159
161
instance_type = "ml.m4.xlarge" ,
@@ -174,6 +176,8 @@ def test_sklearn_with_all_parameters_via_run_args(
174
176
sagemaker_session = sagemaker_session ,
175
177
)
176
178
179
+ # FIXME: check FrameworkProcessor.get_run_args(); this may need a fix involving
180
+ # source_dir and dependencies.
177
181
run_args = processor .get_run_args (
178
182
code = "/local/path/to/processing_code.py" ,
179
183
inputs = _get_data_inputs_all_parameters (),
@@ -182,7 +186,7 @@ def test_sklearn_with_all_parameters_via_run_args(
182
186
)
183
187
184
188
processor .run (
185
- code = run_args .code ,
189
+ entry_point = run_args .code ,
186
190
inputs = run_args .inputs ,
187
191
outputs = run_args .outputs ,
188
192
arguments = run_args .arguments ,
@@ -191,7 +195,7 @@ def test_sklearn_with_all_parameters_via_run_args(
191
195
experiment_config = {"ExperimentName" : "AnExperiment" },
192
196
)
193
197
194
- expected_args = _get_expected_args_all_parameters (processor ._current_job_name )
198
+ expected_args = _get_expected_args_all_parameters_modular_code (processor ._current_job_name )
195
199
sklearn_image_uri = (
196
200
"246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-scikit-learn:{}-cpu-py3"
197
201
).format (sklearn_version )
@@ -209,6 +213,7 @@ def test_sklearn_with_all_parameters_via_run_args_called_twice(
209
213
botocore_resolver .return_value .construct_endpoint .return_value = {"hostname" : ECR_HOSTNAME }
210
214
211
215
processor = SKLearnProcessor (
216
+ s3_prefix = MOCKED_S3_URI ,
212
217
role = ROLE ,
213
218
framework_version = sklearn_version ,
214
219
instance_type = "ml.m4.xlarge" ,
@@ -244,7 +249,7 @@ def test_sklearn_with_all_parameters_via_run_args_called_twice(
244
249
)
245
250
246
251
processor .run (
247
- code = run_args .code ,
252
+ entry_point = run_args .code ,
248
253
inputs = run_args .inputs ,
249
254
outputs = run_args .outputs ,
250
255
arguments = run_args .arguments ,
@@ -253,7 +258,7 @@ def test_sklearn_with_all_parameters_via_run_args_called_twice(
253
258
experiment_config = {"ExperimentName" : "AnExperiment" },
254
259
)
255
260
256
- expected_args = _get_expected_args_all_parameters (processor ._current_job_name )
261
+ expected_args = _get_expected_args_all_parameters_modular_code (processor ._current_job_name )
257
262
sklearn_image_uri = (
258
263
"246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-scikit-learn:{}-cpu-py3"
259
264
).format (sklearn_version )
@@ -748,9 +753,9 @@ def _get_data_inputs_all_parameters():
748
753
input_name = "redshift_dataset_definition" ,
749
754
app_managed = True ,
750
755
dataset_definition = DatasetDefinition (
751
- local_path = "/opt/ml/processing/input/dd" ,
752
756
data_distribution_type = "FullyReplicated" ,
753
757
input_mode = "File" ,
758
+ local_path = "/opt/ml/processing/input/dd" ,
754
759
redshift_dataset_definition = RedshiftDatasetDefinition (
755
760
cluster_id = "cluster_id" ,
756
761
database = "database" ,
@@ -768,15 +773,15 @@ def _get_data_inputs_all_parameters():
768
773
input_name = "athena_dataset_definition" ,
769
774
app_managed = True ,
770
775
dataset_definition = DatasetDefinition (
771
- local_path = "/opt/ml/processing/input/dd" ,
772
776
data_distribution_type = "FullyReplicated" ,
773
777
input_mode = "File" ,
778
+ local_path = "/opt/ml/processing/input/dd" ,
774
779
athena_dataset_definition = AthenaDatasetDefinition (
775
780
catalog = "catalog" ,
776
781
database = "database" ,
777
- work_group = "workgroup" ,
778
782
query_string = "query_string" ,
779
783
output_s3_uri = "output_s3_uri" ,
784
+ work_group = "workgroup" ,
780
785
kms_key_id = "kms_key_id" ,
781
786
output_format = "AVRO" ,
782
787
output_compression = "ZLIB" ,
@@ -802,6 +807,147 @@ def _get_data_outputs_all_parameters():
802
807
]
803
808
804
809
810
+ def _get_expected_args_all_parameters_modular_code (job_name , code_s3_uri = MOCKED_S3_URI ):
811
+ # Expected inputs include two code-related entries: "input-5" (sourcedir.tar.gz payload) and "code" (runproc.sh).
812
+ return {
813
+ "inputs" : [
814
+ {
815
+ "InputName" : "my_dataset" ,
816
+ "AppManaged" : False ,
817
+ "S3Input" : {
818
+ "S3Uri" : "s3://path/to/my/dataset/census.csv" ,
819
+ "LocalPath" : "/container/path/" ,
820
+ "S3DataType" : "S3Prefix" ,
821
+ "S3InputMode" : "File" ,
822
+ "S3DataDistributionType" : "FullyReplicated" ,
823
+ "S3CompressionType" : "None" ,
824
+ },
825
+ },
826
+ {
827
+ "InputName" : "s3_input" ,
828
+ "AppManaged" : False ,
829
+ "S3Input" : {
830
+ "S3Uri" : "s3://path/to/my/dataset/census.csv" ,
831
+ "LocalPath" : "/container/path/" ,
832
+ "S3DataType" : "S3Prefix" ,
833
+ "S3InputMode" : "File" ,
834
+ "S3DataDistributionType" : "FullyReplicated" ,
835
+ "S3CompressionType" : "None" ,
836
+ },
837
+ },
838
+ {
839
+ "InputName" : "redshift_dataset_definition" ,
840
+ "AppManaged" : True ,
841
+ "DatasetDefinition" : {
842
+ "DataDistributionType" : "FullyReplicated" ,
843
+ "InputMode" : "File" ,
844
+ "LocalPath" : "/opt/ml/processing/input/dd" ,
845
+ "RedshiftDatasetDefinition" : {
846
+ "ClusterId" : "cluster_id" ,
847
+ "Database" : "database" ,
848
+ "DbUser" : "db_user" ,
849
+ "QueryString" : "query_string" ,
850
+ "ClusterRoleArn" : "cluster_role_arn" ,
851
+ "OutputS3Uri" : "output_s3_uri" ,
852
+ "KmsKeyId" : "kms_key_id" ,
853
+ "OutputFormat" : "CSV" ,
854
+ "OutputCompression" : "SNAPPY" ,
855
+ },
856
+ },
857
+ },
858
+ {
859
+ "InputName" : "athena_dataset_definition" ,
860
+ "AppManaged" : True ,
861
+ "DatasetDefinition" : {
862
+ "DataDistributionType" : "FullyReplicated" ,
863
+ "InputMode" : "File" ,
864
+ "LocalPath" : "/opt/ml/processing/input/dd" ,
865
+ "AthenaDatasetDefinition" : {
866
+ "Catalog" : "catalog" ,
867
+ "Database" : "database" ,
868
+ "QueryString" : "query_string" ,
869
+ "OutputS3Uri" : "output_s3_uri" ,
870
+ "WorkGroup" : "workgroup" ,
871
+ "KmsKeyId" : "kms_key_id" ,
872
+ "OutputFormat" : "AVRO" ,
873
+ "OutputCompression" : "ZLIB" ,
874
+ },
875
+ },
876
+ },
877
+ {
878
+ "InputName" : "input-5" ,
879
+ "AppManaged" : False ,
880
+ "S3Input" : {
881
+ "S3Uri" : f"{ code_s3_uri } /{ job_name } /source/sourcedir.tar.gz" ,
882
+ "LocalPath" : "/opt/ml/processing/input/code/payload/" ,
883
+ "S3DataType" : "S3Prefix" ,
884
+ "S3InputMode" : "File" ,
885
+ "S3DataDistributionType" : "FullyReplicated" ,
886
+ "S3CompressionType" : "None" ,
887
+ },
888
+ },
889
+ {
890
+ "InputName" : "code" ,
891
+ "AppManaged" : False ,
892
+ "S3Input" : {
893
+ "S3Uri" : f"{ code_s3_uri } /{ job_name } /source/runproc.sh" ,
894
+ "LocalPath" : "/opt/ml/processing/input/code" ,
895
+ "S3DataType" : "S3Prefix" ,
896
+ "S3InputMode" : "File" ,
897
+ "S3DataDistributionType" : "FullyReplicated" ,
898
+ "S3CompressionType" : "None" ,
899
+ },
900
+ },
901
+ ],
902
+ "output_config" : {
903
+ "Outputs" : [
904
+ {
905
+ "OutputName" : "my_output" ,
906
+ "AppManaged" : False ,
907
+ "S3Output" : {
908
+ "S3Uri" : "s3://uri/" ,
909
+ "LocalPath" : "/container/path/" ,
910
+ "S3UploadMode" : "EndOfJob" ,
911
+ },
912
+ },
913
+ {
914
+ "OutputName" : "feature_store_output" ,
915
+ "AppManaged" : True ,
916
+ "FeatureStoreOutput" : {"FeatureGroupName" : "FeatureGroupName" },
917
+ },
918
+ ],
919
+ "KmsKeyId" : "arn:aws:kms:us-west-2:012345678901:key/output-kms-key" ,
920
+ },
921
+ "experiment_config" : {"ExperimentName" : "AnExperiment" },
922
+ "job_name" : job_name ,
923
+ "resources" : {
924
+ "ClusterConfig" : {
925
+ "InstanceType" : "ml.m4.xlarge" ,
926
+ "InstanceCount" : 1 ,
927
+ "VolumeSizeInGB" : 100 ,
928
+ "VolumeKmsKeyId" : "arn:aws:kms:us-west-2:012345678901:key/volume-kms-key" ,
929
+ }
930
+ },
931
+ "stopping_condition" : {"MaxRuntimeInSeconds" : 3600 },
932
+ "app_specification" : {
933
+ "ImageUri" : "012345678901.dkr.ecr.us-west-2.amazonaws.com/my-custom-image-uri" ,
934
+ "ContainerArguments" : ["--drop-columns" , "'SelfEmployed'" ],
935
+ "ContainerEntrypoint" : ["/bin/bash" , "/opt/ml/processing/input/code/runproc.sh" ],
936
+ },
937
+ "environment" : {"my_env_variable" : "my_env_variable_value" },
938
+ "network_config" : {
939
+ "EnableNetworkIsolation" : True ,
940
+ "EnableInterContainerTrafficEncryption" : True ,
941
+ "VpcConfig" : {
942
+ "SecurityGroupIds" : ["my_security_group_id" ],
943
+ "Subnets" : ["my_subnet_id" ],
944
+ },
945
+ },
946
+ "role_arn" : ROLE ,
947
+ "tags" : [{"Key" : "my-tag" , "Value" : "my-tag-value" }],
948
+ }
949
+
950
+
805
951
def _get_expected_args_all_parameters (job_name ):
806
952
return {
807
953
"inputs" : [
0 commit comments