Skip to content

Commit dac7ed3

Browse files
author
Verdi March
committed
Fixed test_sklearn_with_all_parameters()
1 parent 85dd50f commit dac7ed3

File tree

1 file changed

+147
-5
lines changed

1 file changed

+147
-5
lines changed

tests/unit/test_processing.py

Lines changed: 147 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ def test_sklearn_with_all_parameters(
105105
botocore_resolver.return_value.construct_endpoint.return_value = {"hostname": ECR_HOSTNAME}
106106

107107
processor = SKLearnProcessor(
108+
s3_prefix=MOCKED_S3_URI,
108109
role=ROLE,
109110
framework_version=sklearn_version,
110111
instance_type="ml.m4.xlarge",
@@ -126,7 +127,7 @@ def test_sklearn_with_all_parameters(
126127
)
127128

128129
processor.run(
129-
code="/local/path/to/processing_code.py",
130+
entry_point="/local/path/to/processing_code.py",
130131
inputs=_get_data_inputs_all_parameters(),
131132
outputs=_get_data_outputs_all_parameters(),
132133
arguments=["--drop-columns", "'SelfEmployed'"],
@@ -136,7 +137,7 @@ def test_sklearn_with_all_parameters(
136137
experiment_config={"ExperimentName": "AnExperiment"},
137138
)
138139

139-
expected_args = _get_expected_args_all_parameters(processor._current_job_name)
140+
expected_args = _get_expected_args_all_parameters_modular_code(processor._current_job_name)
140141
sklearn_image_uri = (
141142
"246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-scikit-learn:{}-cpu-py3"
142143
).format(sklearn_version)
@@ -748,9 +749,9 @@ def _get_data_inputs_all_parameters():
748749
input_name="redshift_dataset_definition",
749750
app_managed=True,
750751
dataset_definition=DatasetDefinition(
751-
local_path="/opt/ml/processing/input/dd",
752752
data_distribution_type="FullyReplicated",
753753
input_mode="File",
754+
local_path="/opt/ml/processing/input/dd",
754755
redshift_dataset_definition=RedshiftDatasetDefinition(
755756
cluster_id="cluster_id",
756757
database="database",
@@ -768,15 +769,15 @@ def _get_data_inputs_all_parameters():
768769
input_name="athena_dataset_definition",
769770
app_managed=True,
770771
dataset_definition=DatasetDefinition(
771-
local_path="/opt/ml/processing/input/dd",
772772
data_distribution_type="FullyReplicated",
773773
input_mode="File",
774+
local_path="/opt/ml/processing/input/dd",
774775
athena_dataset_definition=AthenaDatasetDefinition(
775776
catalog="catalog",
776777
database="database",
777-
work_group="workgroup",
778778
query_string="query_string",
779779
output_s3_uri="output_s3_uri",
780+
work_group="workgroup",
780781
kms_key_id="kms_key_id",
781782
output_format="AVRO",
782783
output_compression="ZLIB",
@@ -802,6 +803,147 @@ def _get_data_outputs_all_parameters():
802803
]
803804

804805

806+
def _get_expected_args_all_parameters_modular_code(job_name, code_s3_uri=MOCKED_S3_URI):
807+
# Add something to inputs
808+
return {
809+
"inputs": [
810+
{
811+
"InputName": "my_dataset",
812+
"AppManaged": False,
813+
"S3Input": {
814+
"S3Uri": "s3://path/to/my/dataset/census.csv",
815+
"LocalPath": "/container/path/",
816+
"S3DataType": "S3Prefix",
817+
"S3InputMode": "File",
818+
"S3DataDistributionType": "FullyReplicated",
819+
"S3CompressionType": "None",
820+
},
821+
},
822+
{
823+
"InputName": "s3_input",
824+
"AppManaged": False,
825+
"S3Input": {
826+
"S3Uri": "s3://path/to/my/dataset/census.csv",
827+
"LocalPath": "/container/path/",
828+
"S3DataType": "S3Prefix",
829+
"S3InputMode": "File",
830+
"S3DataDistributionType": "FullyReplicated",
831+
"S3CompressionType": "None",
832+
},
833+
},
834+
{
835+
"InputName": "redshift_dataset_definition",
836+
"AppManaged": True,
837+
"DatasetDefinition": {
838+
"DataDistributionType": "FullyReplicated",
839+
"InputMode": "File",
840+
"LocalPath": "/opt/ml/processing/input/dd",
841+
"RedshiftDatasetDefinition": {
842+
"ClusterId": "cluster_id",
843+
"Database": "database",
844+
"DbUser": "db_user",
845+
"QueryString": "query_string",
846+
"ClusterRoleArn": "cluster_role_arn",
847+
"OutputS3Uri": "output_s3_uri",
848+
"KmsKeyId": "kms_key_id",
849+
"OutputFormat": "CSV",
850+
"OutputCompression": "SNAPPY",
851+
},
852+
},
853+
},
854+
{
855+
"InputName": "athena_dataset_definition",
856+
"AppManaged": True,
857+
"DatasetDefinition": {
858+
"DataDistributionType": "FullyReplicated",
859+
"InputMode": "File",
860+
"LocalPath": "/opt/ml/processing/input/dd",
861+
"AthenaDatasetDefinition": {
862+
"Catalog": "catalog",
863+
"Database": "database",
864+
"QueryString": "query_string",
865+
"OutputS3Uri": "output_s3_uri",
866+
"WorkGroup": "workgroup",
867+
"KmsKeyId": "kms_key_id",
868+
"OutputFormat": "AVRO",
869+
"OutputCompression": "ZLIB",
870+
},
871+
},
872+
},
873+
{
874+
"InputName": "input-5",
875+
"AppManaged": False,
876+
"S3Input": {
877+
"S3Uri": f"{code_s3_uri}/{job_name}/source/sourcedir.tar.gz",
878+
"LocalPath": "/opt/ml/processing/input/code/payload/",
879+
"S3DataType": "S3Prefix",
880+
"S3InputMode": "File",
881+
"S3DataDistributionType": "FullyReplicated",
882+
"S3CompressionType": "None",
883+
},
884+
},
885+
{
886+
"InputName": "code",
887+
"AppManaged": False,
888+
"S3Input": {
889+
"S3Uri": f"{code_s3_uri}/{job_name}/source/runproc.sh",
890+
"LocalPath": "/opt/ml/processing/input/code",
891+
"S3DataType": "S3Prefix",
892+
"S3InputMode": "File",
893+
"S3DataDistributionType": "FullyReplicated",
894+
"S3CompressionType": "None",
895+
},
896+
},
897+
],
898+
"output_config": {
899+
"Outputs": [
900+
{
901+
"OutputName": "my_output",
902+
"AppManaged": False,
903+
"S3Output": {
904+
"S3Uri": "s3://uri/",
905+
"LocalPath": "/container/path/",
906+
"S3UploadMode": "EndOfJob",
907+
},
908+
},
909+
{
910+
"OutputName": "feature_store_output",
911+
"AppManaged": True,
912+
"FeatureStoreOutput": {"FeatureGroupName": "FeatureGroupName"},
913+
},
914+
],
915+
"KmsKeyId": "arn:aws:kms:us-west-2:012345678901:key/output-kms-key",
916+
},
917+
"experiment_config": {"ExperimentName": "AnExperiment"},
918+
"job_name": job_name,
919+
"resources": {
920+
"ClusterConfig": {
921+
"InstanceType": "ml.m4.xlarge",
922+
"InstanceCount": 1,
923+
"VolumeSizeInGB": 100,
924+
"VolumeKmsKeyId": "arn:aws:kms:us-west-2:012345678901:key/volume-kms-key",
925+
}
926+
},
927+
"stopping_condition": {"MaxRuntimeInSeconds": 3600},
928+
"app_specification": {
929+
"ImageUri": "012345678901.dkr.ecr.us-west-2.amazonaws.com/my-custom-image-uri",
930+
"ContainerArguments": ["--drop-columns", "'SelfEmployed'"],
931+
"ContainerEntrypoint": ["/bin/bash", "/opt/ml/processing/input/code/runproc.sh"],
932+
},
933+
"environment": {"my_env_variable": "my_env_variable_value"},
934+
"network_config": {
935+
"EnableNetworkIsolation": True,
936+
"EnableInterContainerTrafficEncryption": True,
937+
"VpcConfig": {
938+
"SecurityGroupIds": ["my_security_group_id"],
939+
"Subnets": ["my_subnet_id"],
940+
},
941+
},
942+
"role_arn": ROLE,
943+
"tags": [{"Key": "my-tag", "Value": "my-tag-value"}],
944+
}
945+
946+
805947
def _get_expected_args_all_parameters(job_name):
806948
return {
807949
"inputs": [

0 commit comments

Comments
 (0)