Skip to content

Commit e884ba0

Browse files
committed
Fix all tests that rely on _get_expected_args()
1 parent 163c668 commit e884ba0

File tree

1 file changed

+45
-8
lines changed

1 file changed

+45
-8
lines changed

tests/unit/test_processing.py

Lines changed: 45 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
ROLE = "arn:aws:iam::012345678901:role/SageMakerRole"
3838
ECR_HOSTNAME = "ecr.us-west-2.amazonaws.com"
3939
CUSTOM_IMAGE_URI = "012345678901.dkr.ecr.us-west-2.amazonaws.com/my-custom-image-uri"
40+
MOCKED_S3_URI = "s3://mocked_s3_uri_from_upload_data"
4041

4142

4243
@pytest.fixture(autouse=True)
@@ -57,9 +58,7 @@ def sagemaker_session():
5758
)
5859
session_mock.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME)
5960

60-
session_mock.upload_data = Mock(
61-
name="upload_data", return_value="mocked_s3_uri_from_upload_data"
62-
)
61+
session_mock.upload_data = Mock(name="upload_data", return_value=MOCKED_S3_URI)
6362
session_mock.download_data = Mock(name="download_data")
6463
session_mock.expand_role.return_value = ROLE
6564
session_mock.describe_processing_job = MagicMock(
@@ -77,7 +76,7 @@ def test_sklearn_processor_with_required_parameters(
7776
botocore_resolver.return_value.construct_endpoint.return_value = {"hostname": ECR_HOSTNAME}
7877

7978
processor = SKLearnProcessor(
80-
s3_prefix="s3://abcd/ef",
79+
s3_prefix=MOCKED_S3_URI,
8180
role=ROLE,
8281
instance_type="ml.m4.xlarge",
8382
framework_version=sklearn_version,
@@ -87,7 +86,7 @@ def test_sklearn_processor_with_required_parameters(
8786

8887
processor.run(entry_point="/local/path/to/processing_code.py")
8988

90-
expected_args = _get_expected_args(processor._current_job_name, "s3://abcd/ef")
89+
expected_args = _get_expected_args_modular_code(processor._current_job_name)
9190

9291
sklearn_image_uri = (
9392
"246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-scikit-learn:{}-cpu-py3"
@@ -619,7 +618,45 @@ def _get_script_processor(sagemaker_session):
619618
)
620619

621620

622-
def _get_expected_args(job_name, code_s3_uri="mocked_s3_uri_from_upload_data"):
621+
def _get_expected_args(job_name, code_s3_uri=MOCKED_S3_URI):
622+
return {
623+
"inputs": [
624+
{
625+
"InputName": "code",
626+
"AppManaged": False,
627+
"S3Input": {
628+
"S3Uri": code_s3_uri,
629+
"LocalPath": "/opt/ml/processing/input/code",
630+
"S3DataType": "S3Prefix",
631+
"S3InputMode": "File",
632+
"S3DataDistributionType": "FullyReplicated",
633+
"S3CompressionType": "None",
634+
},
635+
},
636+
],
637+
"output_config": {"Outputs": []},
638+
"job_name": job_name,
639+
"resources": {
640+
"ClusterConfig": {
641+
"InstanceType": "ml.m4.xlarge",
642+
"InstanceCount": 1,
643+
"VolumeSizeInGB": 30,
644+
}
645+
},
646+
"stopping_condition": None,
647+
"app_specification": {
648+
"ImageUri": CUSTOM_IMAGE_URI,
649+
"ContainerEntrypoint": ["python3", "/opt/ml/processing/input/code/processing_code.py"],
650+
},
651+
"environment": None,
652+
"network_config": None,
653+
"role_arn": ROLE,
654+
"tags": None,
655+
"experiment_config": None,
656+
}
657+
658+
659+
def _get_expected_args_modular_code(job_name, code_s3_uri=MOCKED_S3_URI):
623660
return {
624661
"inputs": [
625662
{
@@ -674,7 +711,7 @@ def _get_data_input():
674711
"InputName": "input-1",
675712
"AppManaged": False,
676713
"S3Input": {
677-
"S3Uri": "mocked_s3_uri_from_upload_data",
714+
"S3Uri": MOCKED_S3_URI,
678715
"LocalPath": "/data/",
679716
"S3DataType": "S3Prefix",
680717
"S3InputMode": "File",
@@ -835,7 +872,7 @@ def _get_expected_args_all_parameters(job_name):
835872
"InputName": "code",
836873
"AppManaged": False,
837874
"S3Input": {
838-
"S3Uri": "mocked_s3_uri_from_upload_data",
875+
"S3Uri": MOCKED_S3_URI,
839876
"LocalPath": "/opt/ml/processing/input/code",
840877
"S3DataType": "S3Prefix",
841878
"S3InputMode": "File",

0 commit comments

Comments
 (0)