Skip to content

Commit ff67bb3

Browse files
committed
change: add unit test for framework ProcessingStep
Add a unit test checking that the SKLearn FrameworkProcessor behaves as ProcessingStep expects, to guard against a regression like the earlier one where SKLearnProcessor broke the SageMaker Pipelines examples.
1 parent 5fc171b commit ff67bb3

File tree

2 files changed

+54
-1
lines changed

2 files changed

+54
-1
lines changed

tests/unit/sagemaker/workflow/test_steps.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,7 @@ def test_processing_step_normalizes_args(mock_normalize_args, sagemaker_session)
395395
volume_kms_key="arn:aws:kms:us-west-2:012345678901:key/volume-kms-key",
396396
output_kms_key="arn:aws:kms:us-west-2:012345678901:key/output-kms-key",
397397
max_runtime_in_seconds=3600,
398-
base_job_name="my_sklearn_processor",
398+
base_job_name="my_script_processor",
399399
env={"my_env_variable": "my_env_variable_value"},
400400
tags=[{"Key": "my-tag", "Value": "my-tag-value"}],
401401
network_config=NetworkConfig(

tests/unit/test_processing.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,59 @@ def test_sklearn_with_all_parameters(
161161
sagemaker_session.process.assert_called_with(**expected_args)
162162

163163

164+
@patch("sagemaker.utils._botocore_resolver")
@patch("os.path.exists", return_value=True)
@patch("os.path.isfile", return_value=True)
def test_normalize_args_prepares_frameworkprocessor(
    exists_mock, isfile_mock, botocore_resolver, sklearn_version, sagemaker_session, uploaded_code
):
    """Verify that _normalize_args() alone fully prepares a FrameworkProcessor job.

    sagemaker.workflow.steps.ProcessingStep assumes that calling _normalize_args() on a
    Processor is sufficient to package any job code to S3 and produce the final
    ProcessingInputs — a regression here breaks SageMaker Pipelines examples that use
    SKLearnProcessor.
    """
    botocore_resolver.return_value.construct_endpoint.return_value = {"hostname": ECR_HOSTNAME}

    processor = SKLearnProcessor(
        role=ROLE,
        framework_version=sklearn_version,
        instance_type="ml.m4.xlarge",
        instance_count=1,
        sagemaker_session=sagemaker_session,
    )

    raw_job_inputs = _get_data_inputs_all_parameters()
    raw_job_outputs = _get_data_outputs_all_parameters()
    with patch("sagemaker.estimator.tar_and_upload_dir", return_value=uploaded_code):
        # ProcessingStep calls only _normalize_args(), so the code upload and input
        # preparation must all happen inside this one call:
        normalized_inputs, normalized_outputs = processor._normalize_args(
            inputs=raw_job_inputs,
            outputs=raw_job_outputs,
            code="processing_code.py",
            source_dir="/local/path/to/source_dir",
        )
        process_args = ProcessingJob._get_process_args(
            processor, normalized_inputs, normalized_outputs, experiment_config=dict()
        )

    # Code and entrypoint inputs should *both* have been added to the inputs.
    # NOTE: the original had a bare `normalized_inputs[0].input_name == "code"` with no
    # `assert` — a no-op; the explicit code_inputs assertions below cover that intent.
    assert len(normalized_inputs) == len(raw_job_inputs) + 2
    code_inputs = [i for i in normalized_inputs if i.input_name == "code"]
    assert len(code_inputs) == 1
    assert code_inputs[0].source == uploaded_code.s3_prefix
    entrypoint_inputs = [i for i in normalized_inputs if i.input_name == "entrypoint"]
    assert len(entrypoint_inputs) == 1

    # Outputs should pass through unchanged:
    assert len(normalized_outputs) == len(raw_job_outputs)

    # The job "entrypoint" should be the framework bootstrap script, *not* the user's script:
    job_command = process_args["app_specification"]["ContainerEntrypoint"]
    assert (
        job_command[0 : len(processor.framework_entrypoint_command)]
        == processor.framework_entrypoint_command
    )
    assert "processing_code.py" not in job_command[1]
215+
216+
164217
@patch("sagemaker.local.LocalSession.__init__", return_value=None)
165218
def test_local_mode_disables_local_code_by_default(localsession_mock):
166219
Processor(

0 commit comments

Comments
 (0)