@@ -161,6 +161,59 @@ def test_sklearn_with_all_parameters(
161
161
sagemaker_session .process .assert_called_with (** expected_args )
162
162
163
163
164
@patch("sagemaker.utils._botocore_resolver")
@patch("os.path.exists", return_value=True)
@patch("os.path.isfile", return_value=True)
def test_normalize_args_prepares_frameworkprocessor(
    isfile_mock, exists_mock, botocore_resolver, sklearn_version, sagemaker_session, uploaded_code
):
    """_normalize_args() alone must package code to S3 and build final ProcessingInputs.

    sagemaker.workflow.steps.ProcessingStep assumes that calling _normalize_args() on a
    Processor is sufficient to upload whatever code is needed and prepare the final
    ProcessingInputs for the job — no further preparation step is run before the job.

    NOTE: @patch decorators inject mocks bottom-up, so the first mock argument is the
    os.path.isfile patch and the second is os.path.exists (the original had these two
    names swapped; both are return_value=True mocks, so behavior was unaffected).
    """
    botocore_resolver.return_value.construct_endpoint.return_value = {"hostname": ECR_HOSTNAME}

    processor = SKLearnProcessor(
        role=ROLE,
        framework_version=sklearn_version,
        instance_type="ml.m4.xlarge",
        instance_count=1,
        sagemaker_session=sagemaker_session,
    )

    raw_job_inputs = _get_data_inputs_all_parameters()
    raw_job_outputs = _get_data_outputs_all_parameters()
    with patch("sagemaker.estimator.tar_and_upload_dir", return_value=uploaded_code):
        normalized_inputs, normalized_outputs = processor._normalize_args(
            inputs=raw_job_inputs,
            outputs=raw_job_outputs,
            code="processing_code.py",
            source_dir="/local/path/to/source_dir",
        )
        process_args = ProcessingJob._get_process_args(
            processor, normalized_inputs, normalized_outputs, experiment_config={}
        )

    # Code and entrypoint inputs should *both* have been added to the inputs:
    assert len(normalized_inputs) == len(raw_job_inputs) + 2
    # (A previous revision had a bare `normalized_inputs[0].input_name == "code"` here —
    # a no-op expression with no `assert`. Input ordering is not part of the contract and
    # the presence of the code input is verified by the filter checks below, so the dead
    # positional check has been dropped rather than promoted to an assert.)
    code_inputs = [i for i in normalized_inputs if i.input_name == "code"]
    assert len(code_inputs) == 1
    assert code_inputs[0].source == uploaded_code.s3_prefix
    entrypoint_inputs = [i for i in normalized_inputs if i.input_name == "entrypoint"]
    assert len(entrypoint_inputs) == 1

    # Outputs should be as per raw:
    assert len(normalized_outputs) == len(raw_job_outputs)

    # Job "entrypoint" should be the framework bootstrap script, *not* the user's script:
    job_command = process_args["app_specification"]["ContainerEntrypoint"]
    assert (
        job_command[: len(processor.framework_entrypoint_command)]
        == processor.framework_entrypoint_command
    )
    assert "processing_code.py" not in job_command[1]
216
+
164
217
@patch ("sagemaker.local.LocalSession.__init__" , return_value = None )
165
218
def test_local_mode_disables_local_code_by_default (localsession_mock ):
166
219
Processor (
0 commit comments