Skip to content

Commit dd7a869

Browse files
author
Verdi March
committed
Merge branch 'test-sklearn-all-modular' into pr-framework-processor-round-02
2 parents 85dd50f + 24534b3 commit dd7a869

File tree

1 file changed

+155
-9
lines changed

1 file changed

+155
-9
lines changed

tests/unit/test_processing.py

Lines changed: 155 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -105,6 +105,7 @@ def test_sklearn_with_all_parameters(
105105
botocore_resolver.return_value.construct_endpoint.return_value = {"hostname": ECR_HOSTNAME}
106106

107107
processor = SKLearnProcessor(
108+
s3_prefix=MOCKED_S3_URI,
108109
role=ROLE,
109110
framework_version=sklearn_version,
110111
instance_type="ml.m4.xlarge",
@@ -126,7 +127,7 @@ def test_sklearn_with_all_parameters(
126127
)
127128

128129
processor.run(
129-
code="/local/path/to/processing_code.py",
130+
entry_point="/local/path/to/processing_code.py",
130131
inputs=_get_data_inputs_all_parameters(),
131132
outputs=_get_data_outputs_all_parameters(),
132133
arguments=["--drop-columns", "'SelfEmployed'"],
@@ -136,7 +137,7 @@ def test_sklearn_with_all_parameters(
136137
experiment_config={"ExperimentName": "AnExperiment"},
137138
)
138139

139-
expected_args = _get_expected_args_all_parameters(processor._current_job_name)
140+
expected_args = _get_expected_args_all_parameters_modular_code(processor._current_job_name)
140141
sklearn_image_uri = (
141142
"246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-scikit-learn:{}-cpu-py3"
142143
).format(sklearn_version)
@@ -154,6 +155,7 @@ def test_sklearn_with_all_parameters_via_run_args(
154155
botocore_resolver.return_value.construct_endpoint.return_value = {"hostname": ECR_HOSTNAME}
155156

156157
processor = SKLearnProcessor(
158+
s3_prefix=MOCKED_S3_URI,
157159
role=ROLE,
158160
framework_version=sklearn_version,
159161
instance_type="ml.m4.xlarge",
@@ -174,6 +176,8 @@ def test_sklearn_with_all_parameters_via_run_args(
174176
sagemaker_session=sagemaker_session,
175177
)
176178

179+
# FIXME: check FrameworkProcessor.get_run_args(), and possibly fix it to support
180+
# source_dir and dependencies.
177181
run_args = processor.get_run_args(
178182
code="/local/path/to/processing_code.py",
179183
inputs=_get_data_inputs_all_parameters(),
@@ -182,7 +186,7 @@ def test_sklearn_with_all_parameters_via_run_args(
182186
)
183187

184188
processor.run(
185-
code=run_args.code,
189+
entry_point=run_args.code,
186190
inputs=run_args.inputs,
187191
outputs=run_args.outputs,
188192
arguments=run_args.arguments,
@@ -191,7 +195,7 @@ def test_sklearn_with_all_parameters_via_run_args(
191195
experiment_config={"ExperimentName": "AnExperiment"},
192196
)
193197

194-
expected_args = _get_expected_args_all_parameters(processor._current_job_name)
198+
expected_args = _get_expected_args_all_parameters_modular_code(processor._current_job_name)
195199
sklearn_image_uri = (
196200
"246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-scikit-learn:{}-cpu-py3"
197201
).format(sklearn_version)
@@ -209,6 +213,7 @@ def test_sklearn_with_all_parameters_via_run_args_called_twice(
209213
botocore_resolver.return_value.construct_endpoint.return_value = {"hostname": ECR_HOSTNAME}
210214

211215
processor = SKLearnProcessor(
216+
s3_prefix=MOCKED_S3_URI,
212217
role=ROLE,
213218
framework_version=sklearn_version,
214219
instance_type="ml.m4.xlarge",
@@ -244,7 +249,7 @@ def test_sklearn_with_all_parameters_via_run_args_called_twice(
244249
)
245250

246251
processor.run(
247-
code=run_args.code,
252+
entry_point=run_args.code,
248253
inputs=run_args.inputs,
249254
outputs=run_args.outputs,
250255
arguments=run_args.arguments,
@@ -253,7 +258,7 @@ def test_sklearn_with_all_parameters_via_run_args_called_twice(
253258
experiment_config={"ExperimentName": "AnExperiment"},
254259
)
255260

256-
expected_args = _get_expected_args_all_parameters(processor._current_job_name)
261+
expected_args = _get_expected_args_all_parameters_modular_code(processor._current_job_name)
257262
sklearn_image_uri = (
258263
"246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-scikit-learn:{}-cpu-py3"
259264
).format(sklearn_version)
@@ -748,9 +753,9 @@ def _get_data_inputs_all_parameters():
748753
input_name="redshift_dataset_definition",
749754
app_managed=True,
750755
dataset_definition=DatasetDefinition(
751-
local_path="/opt/ml/processing/input/dd",
752756
data_distribution_type="FullyReplicated",
753757
input_mode="File",
758+
local_path="/opt/ml/processing/input/dd",
754759
redshift_dataset_definition=RedshiftDatasetDefinition(
755760
cluster_id="cluster_id",
756761
database="database",
@@ -768,15 +773,15 @@ def _get_data_inputs_all_parameters():
768773
input_name="athena_dataset_definition",
769774
app_managed=True,
770775
dataset_definition=DatasetDefinition(
771-
local_path="/opt/ml/processing/input/dd",
772776
data_distribution_type="FullyReplicated",
773777
input_mode="File",
778+
local_path="/opt/ml/processing/input/dd",
774779
athena_dataset_definition=AthenaDatasetDefinition(
775780
catalog="catalog",
776781
database="database",
777-
work_group="workgroup",
778782
query_string="query_string",
779783
output_s3_uri="output_s3_uri",
784+
work_group="workgroup",
780785
kms_key_id="kms_key_id",
781786
output_format="AVRO",
782787
output_compression="ZLIB",
@@ -802,6 +807,147 @@ def _get_data_outputs_all_parameters():
802807
]
803808

804809

810+
def _get_expected_args_all_parameters_modular_code(job_name, code_s3_uri=MOCKED_S3_URI):
    """Return the expected processing-job arguments for the ``entry_point`` tests.

    The result mirrors the four data inputs configured by the "all parameters"
    tests and appends two code-related inputs: ``input-5`` (the
    ``sourcedir.tar.gz`` payload) and ``code`` (the ``runproc.sh`` launcher).
    Both S3 URIs are derived from ``code_s3_uri`` and ``job_name``.

    Args:
        job_name: Job name stored under ``"job_name"`` and interpolated into
            the S3 URIs of the two code inputs.
        code_s3_uri: S3 prefix under which the code artifacts are expected.
            Defaults to ``MOCKED_S3_URI``.

    Returns:
        dict: Freshly built expected-arguments mapping.
    """

    def s3_input(name, uri, local_path):
        # Every S3-backed input in this fixture differs only by name, URI and path.
        return {
            "InputName": name,
            "AppManaged": False,
            "S3Input": {
                "S3Uri": uri,
                "LocalPath": local_path,
                "S3DataType": "S3Prefix",
                "S3InputMode": "File",
                "S3DataDistributionType": "FullyReplicated",
                "S3CompressionType": "None",
            },
        }

    def dataset_input(name, definition_key, definition):
        # App-managed dataset-definition inputs share the same outer envelope.
        return {
            "InputName": name,
            "AppManaged": True,
            "DatasetDefinition": {
                "DataDistributionType": "FullyReplicated",
                "InputMode": "File",
                "LocalPath": "/opt/ml/processing/input/dd",
                definition_key: definition,
            },
        }

    census_uri = "s3://path/to/my/dataset/census.csv"
    source_prefix = f"{code_s3_uri}/{job_name}/source"

    redshift_definition = {
        "ClusterId": "cluster_id",
        "Database": "database",
        "DbUser": "db_user",
        "QueryString": "query_string",
        "ClusterRoleArn": "cluster_role_arn",
        "OutputS3Uri": "output_s3_uri",
        "KmsKeyId": "kms_key_id",
        "OutputFormat": "CSV",
        "OutputCompression": "SNAPPY",
    }
    athena_definition = {
        "Catalog": "catalog",
        "Database": "database",
        "QueryString": "query_string",
        "OutputS3Uri": "output_s3_uri",
        "WorkGroup": "workgroup",
        "KmsKeyId": "kms_key_id",
        "OutputFormat": "AVRO",
        "OutputCompression": "ZLIB",
    }

    expected_inputs = [
        s3_input("my_dataset", census_uri, "/container/path/"),
        s3_input("s3_input", census_uri, "/container/path/"),
        dataset_input(
            "redshift_dataset_definition", "RedshiftDatasetDefinition", redshift_definition
        ),
        dataset_input("athena_dataset_definition", "AthenaDatasetDefinition", athena_definition),
        # Code-related inputs: the packaged source payload, then the launcher script.
        s3_input(
            "input-5",
            f"{source_prefix}/sourcedir.tar.gz",
            "/opt/ml/processing/input/code/payload/",
        ),
        s3_input("code", f"{source_prefix}/runproc.sh", "/opt/ml/processing/input/code"),
    ]

    expected_outputs = [
        {
            "OutputName": "my_output",
            "AppManaged": False,
            "S3Output": {
                "S3Uri": "s3://uri/",
                "LocalPath": "/container/path/",
                "S3UploadMode": "EndOfJob",
            },
        },
        {
            "OutputName": "feature_store_output",
            "AppManaged": True,
            "FeatureStoreOutput": {"FeatureGroupName": "FeatureGroupName"},
        },
    ]

    return {
        "inputs": expected_inputs,
        "output_config": {
            "Outputs": expected_outputs,
            "KmsKeyId": "arn:aws:kms:us-west-2:012345678901:key/output-kms-key",
        },
        "experiment_config": {"ExperimentName": "AnExperiment"},
        "job_name": job_name,
        "resources": {
            "ClusterConfig": {
                "InstanceType": "ml.m4.xlarge",
                "InstanceCount": 1,
                "VolumeSizeInGB": 100,
                "VolumeKmsKeyId": "arn:aws:kms:us-west-2:012345678901:key/volume-kms-key",
            }
        },
        "stopping_condition": {"MaxRuntimeInSeconds": 3600},
        "app_specification": {
            "ImageUri": "012345678901.dkr.ecr.us-west-2.amazonaws.com/my-custom-image-uri",
            "ContainerArguments": ["--drop-columns", "'SelfEmployed'"],
            "ContainerEntrypoint": ["/bin/bash", "/opt/ml/processing/input/code/runproc.sh"],
        },
        "environment": {"my_env_variable": "my_env_variable_value"},
        "network_config": {
            "EnableNetworkIsolation": True,
            "EnableInterContainerTrafficEncryption": True,
            "VpcConfig": {
                "SecurityGroupIds": ["my_security_group_id"],
                "Subnets": ["my_subnet_id"],
            },
        },
        "role_arn": ROLE,
        "tags": [{"Key": "my-tag", "Value": "my-tag-value"}],
    }
949+
950+
805951
def _get_expected_args_all_parameters(job_name):
806952
return {
807953
"inputs": [

0 commit comments

Comments
 (0)