
Commit 8d7f82e

Author: Brock Wade

fix: unrelated test suite pollution, and update spark base class

1 parent: 864a90a

File tree: 3 files changed, +26 / -10 lines changed

src/sagemaker/spark/processing.py

Lines changed: 15 additions & 5 deletions
@@ -279,6 +279,12 @@ def run(
     def _extend_processing_args(self, inputs, outputs, **kwargs):
         """Extends processing job args such as inputs."""

+        # make a copy of user outputs
+        outputs = outputs or []
+        extended_outputs = []
+        for user_output in outputs:
+            extended_outputs.append(user_output)
+
         if kwargs.get("spark_event_logs_s3_uri"):
             spark_event_logs_s3_uri = kwargs.get("spark_event_logs_s3_uri")
             self._validate_s3_uri(spark_event_logs_s3_uri)
@@ -297,16 +303,20 @@ def _extend_processing_args(self, inputs, outputs, **kwargs):
                 s3_upload_mode="Continuous",
             )

-            outputs = outputs or []
-            outputs.append(output)
+            extended_outputs.append(output)
+
+        # make a copy of user inputs
+        inputs = inputs or []
+        extended_inputs = []
+        for user_input in inputs:
+            extended_inputs.append(user_input)

         if kwargs.get("configuration"):
             configuration = kwargs.get("configuration")
             self._validate_configuration(configuration)
-            inputs = inputs or []
-            inputs.append(self._stage_configuration(configuration))
+            extended_inputs.append(self._stage_configuration(configuration))

-        return inputs, outputs
+        return extended_inputs, extended_outputs

     def start_history_server(self, spark_event_logs_s3_uri=None):
         """Starts a Spark history server.

tests/unit/sagemaker/workflow/test_pipeline.py

Lines changed: 3 additions & 5 deletions
@@ -17,7 +17,7 @@

 import pytest

-from mock import Mock
+from mock import Mock, patch

 from sagemaker import s3
 from sagemaker.workflow.condition_step import ConditionStep
@@ -78,6 +78,7 @@ def test_pipeline_create_with_parallelism_config(sagemaker_session_mock, role_arn):
     )


+@patch("sagemaker.s3.S3Uploader.upload_string_as_file_body")
 def test_large_pipeline_create(sagemaker_session_mock, role_arn):
     parameter = ParameterString("MyStr")
     pipeline = Pipeline(
@@ -87,8 +88,6 @@ def test_large_pipeline_create(sagemaker_session_mock, role_arn):
         sagemaker_session=sagemaker_session_mock,
     )

-    s3.S3Uploader.upload_string_as_file_body = Mock()
-
     pipeline.create(role_arn=role_arn)

     assert s3.S3Uploader.upload_string_as_file_body.called_with(
@@ -151,6 +150,7 @@ def test_pipeline_update_with_parallelism_config(sagemaker_session_mock, role_arn):
     )


+@patch("sagemaker.s3.S3Uploader.upload_string_as_file_body")
 def test_large_pipeline_update(sagemaker_session_mock, role_arn):
     parameter = ParameterString("MyStr")
     pipeline = Pipeline(
@@ -160,8 +160,6 @@ def test_large_pipeline_update(sagemaker_session_mock, role_arn):
         sagemaker_session=sagemaker_session_mock,
     )

-    s3.S3Uploader.upload_string_as_file_body = Mock()
-
     pipeline.create(role_arn=role_arn)

     assert s3.S3Uploader.upload_string_as_file_body.called_with(
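
Here the bare attribute assignment (s3.S3Uploader.upload_string_as_file_body = Mock()) is replaced with a @patch decorator; the assignment swapped in a Mock for the rest of the test session and leaked into unrelated tests, while patch restores the original attribute when each test returns. A small self-contained sketch of the difference, using a stand-in Uploader class rather than sagemaker.s3.S3Uploader:

# Stand-in for sagemaker.s3.S3Uploader; shows why @patch avoids test pollution.
from unittest.mock import Mock, patch

class Uploader:
    @staticmethod
    def upload(body):
        return "real upload"

def polluting_test():
    Uploader.upload = Mock()       # replaces the method for every test that runs later

@patch(f"{__name__}.Uploader.upload")
def clean_test(mock_upload):
    Uploader.upload("body")        # patched only for the duration of this test
    mock_upload.assert_called_once_with("body")

clean_test()
assert Uploader.upload("body") == "real upload"   # original behavior restored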

tests/unit/sagemaker/workflow/test_processing_step.py

Lines changed: 8 additions & 0 deletions
@@ -992,6 +992,10 @@ def test_spark_processor(spark_processor, processing_input, pipeline_session):
                     SPARK_SUBMIT_FILE2,
                 ],
                 "spark_event_logs_s3_uri": ParameterString("MySparkEventLogS3Uri"),
+                "configuration": {
+                    "Classification": "core-site",
+                    "Properties": {"hadoop.security.groups.cache.secs": "250"},
+                },
             },
         ),
         (
@@ -1016,6 +1020,10 @@ def test_spark_processor(spark_processor, processing_input, pipeline_session):
                 "submit_jars": [SPARK_DEP_JAR],
                 "submit_files": [SPARK_SUBMIT_FILE1, SPARK_SUBMIT_FILE2],
                 "spark_event_logs_s3_uri": ParameterString("MySparkEventLogS3Uri"),
+                "configuration": {
+                    "Classification": "core-site",
+                    "Properties": {"hadoop.security.groups.cache.secs": "250"},
+                },
             },
         ),
     ],
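
The new parametrize entries route an EMR-style configuration dict (a Classification plus Properties) through the Spark processors, so the pipeline step test now exercises the _stage_configuration branch together with the event-log output handled by the base-class change above. For reference, a sketch of how such a configuration might be passed at the call site, assuming the usual PySparkProcessor.run signature; the job name, role, and S3 paths below are placeholders:

# Placeholder role, bucket, and script; the configuration shape mirrors the test data above.
from sagemaker.spark.processing import PySparkProcessor

processor = PySparkProcessor(
    base_job_name="spark-config-demo",
    framework_version="3.1",
    role="arn:aws:iam::111122223333:role/SageMakerRole",
    instance_count=2,
    instance_type="ml.m5.xlarge",
)

processor.run(
    submit_app="s3://my-bucket/code/preprocess.py",
    configuration={
        "Classification": "core-site",
        "Properties": {"hadoop.security.groups.cache.secs": "250"},
    },
    spark_event_logs_s3_uri="s3://my-bucket/spark-events",
)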
