Commit db69b5f

Brock Wade committed
fix: frameworkprocessor side effects, testing
1 parent 2a1274c commit db69b5f

9 files changed, +219 -18 lines changed

src/sagemaker/processing.py

Lines changed: 24 additions & 6 deletions
@@ -1704,6 +1704,7 @@ def _pack_and_upload_code(
         self, code, source_dir, dependencies, git_config, job_name, inputs, kms_key=None
     ):
         """Pack local code bundle and upload to Amazon S3."""
+        from sagemaker.workflow.utilities import _pipeline_config, hash_object
         if code.startswith("s3://"):
             return code, inputs, job_name
 
@@ -1737,12 +1738,29 @@ def _pack_and_upload_code(
             "runproc.sh",
         )
         script = estimator.uploaded_code.script_name
-        s3_runproc_sh = S3Uploader.upload_string_as_file_body(
-            self._generate_framework_script(script),
-            desired_s3_uri=entrypoint_s3_uri,
-            kms_key=kms_key,
-            sagemaker_session=self.sagemaker_session,
-        )
+
+        # If we are leveraging a pipeline session with optimized s3 artifact paths,
+        # we need to hash and upload the runproc.sh file to a separate location.
+        if _pipeline_config and _pipeline_config.pipeline_name:
+            runproc_file_str = self._generate_framework_script(script)
+            runproc_file_hash = hash_object(runproc_file_str)
+            s3_uri = (
+                f"s3://{self.sagemaker_session.default_bucket()}/{_pipeline_config.pipeline_name}/"
+                f"code/{runproc_file_hash}/runproc.sh"
+            )
+            s3_runproc_sh = S3Uploader.upload_string_as_file_body(
+                runproc_file_str,
+                desired_s3_uri=s3_uri,
+                kms_key=kms_key,
+                sagemaker_session=self.sagemaker_session,
+            )
+        else:
+            s3_runproc_sh = S3Uploader.upload_string_as_file_body(
+                self._generate_framework_script(script),
+                desired_s3_uri=entrypoint_s3_uri,
+                kms_key=kms_key,
+                sagemaker_session=self.sagemaker_session,
+            )
         logger.info("runproc.sh uploaded to %s", s3_runproc_sh)
 
         return s3_runproc_sh, inputs, job_name
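
Note: the new branch makes the runproc.sh upload content-addressed when a pipeline session is active; the generated script body is hashed and the hash becomes part of the S3 key, so repeated uploads of the same script land at a stable prefix and different scripts get distinct prefixes. A rough sketch of the idea, using hashlib as a stand-in for the SDK's hash_object helper (whose exact hashing scheme is not shown in this diff):

# Illustrative sketch only: hashlib stands in for sagemaker.workflow.utilities.hash_object.
import hashlib


def runproc_s3_uri(bucket: str, pipeline_name: str, runproc_body: str) -> str:
    """Build a content-addressed S3 key for runproc.sh, mirroring the pattern above."""
    body_hash = hashlib.sha256(runproc_body.encode("utf-8")).hexdigest()
    return f"s3://{bucket}/{pipeline_name}/code/{body_hash}/runproc.sh"


# Identical script bodies map to the same key, so pipeline step caching sees a stable
# artifact; any change to the generated script changes the hash and the upload location.
print(runproc_s3_uri("my-bucket", "my-pipeline", "#!/bin/bash\npython preprocess.py\n"))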

src/sagemaker/spark/processing.py

Lines changed: 15 additions & 5 deletions
@@ -279,6 +279,12 @@ def run(
     def _extend_processing_args(self, inputs, outputs, **kwargs):
         """Extends processing job args such as inputs."""
 
+        # make a copy of user outputs
+        outputs = outputs or []
+        extended_outputs = []
+        for user_output in outputs:
+            extended_outputs.append(user_output)
+
         if kwargs.get("spark_event_logs_s3_uri"):
             spark_event_logs_s3_uri = kwargs.get("spark_event_logs_s3_uri")
             self._validate_s3_uri(spark_event_logs_s3_uri)
@@ -297,16 +303,20 @@ def _extend_processing_args(self, inputs, outputs, **kwargs):
                 s3_upload_mode="Continuous",
             )
 
-            outputs = outputs or []
-            outputs.append(output)
+            extended_outputs.append(output)
+
+        # make a copy of user inputs
+        inputs = inputs or []
+        extended_inputs = []
+        for user_input in inputs:
+            extended_inputs.append(user_input)
 
         if kwargs.get("configuration"):
             configuration = kwargs.get("configuration")
             self._validate_configuration(configuration)
-            inputs = inputs or []
-            inputs.append(self._stage_configuration(configuration))
+            extended_inputs.append(self._stage_configuration(configuration))
 
-        return inputs, outputs
+        return extended_inputs, extended_outputs
 
     def start_history_server(self, spark_event_logs_s3_uri=None):
         """Starts a Spark history server.

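Note: the _extend_processing_args rewrite stops mutating the inputs and outputs lists supplied by the caller; the Spark event-log output and the staged configuration input are now appended to internal copies and returned. Previously, building the same step arguments twice grew the caller's lists each time. A minimal, self-contained sketch of that difference (toy function names, not SageMaker APIs):

# Toy illustration of the side effect being fixed; not SageMaker code.
def extend_in_place(outputs):
    outputs = outputs or []
    outputs.append("spark-event-logs")      # mutates the list the caller passed in
    return outputs


def extend_with_copy(outputs):
    extended = list(outputs or [])          # work on a copy, as the patched code does
    extended.append("spark-event-logs")
    return extended


user_outputs = ["evaluation"]
extend_in_place(user_outputs)
extend_in_place(user_outputs)
print(user_outputs)   # ['evaluation', 'spark-event-logs', 'spark-event-logs']

user_outputs = ["evaluation"]
extend_with_copy(user_outputs)
extend_with_copy(user_outputs)
print(user_outputs)   # ['evaluation'] -- the caller's list is untouched
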
src/sagemaker/workflow/utilities.py

Lines changed: 2 additions & 1 deletion
@@ -114,11 +114,12 @@ def get_code_hash(step: Entity) -> str:
     if isinstance(step, ProcessingStep) and step.step_args:
         kwargs = step.step_args.func_kwargs
         source_dir = kwargs.get("source_dir")
+        submit_class = kwargs.get("submit_class")
         dependencies = get_processing_dependencies(
             [
                 kwargs.get("dependencies"),
                 kwargs.get("submit_py_files"),
-                kwargs.get("submit_class"),
+                [submit_class] if submit_class else None,
                 kwargs.get("submit_jars"),
                 kwargs.get("submit_files"),
             ]
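
Note: submit_class is a single class-name string, while the other values fed to get_processing_dependencies are lists of file paths; wrapping it as [submit_class] keeps every entry list-shaped before the dependency set is hashed for step caching. Assuming the helper flattens its list-of-lists (an assumption; its body is not part of this diff), passing the bare string would leak individual characters into the hash input, roughly like this:

# Sketch of the shape mismatch, under the assumption that dependency groups are flattened.
def flatten(groups):
    flat = []
    for group in groups:
        if group:
            flat.extend(group)   # extending with a bare string adds it character by character
    return flat


print(flatten([["dep1.py"], "com.example.Main", ["job.jar"]]))
# ['dep1.py', 'c', 'o', 'm', '.', ...]  -- the class name dissolves into characters

print(flatten([["dep1.py"], ["com.example.Main"], ["job.jar"]]))
# ['dep1.py', 'com.example.Main', 'job.jar']  -- the wrapped form stays intact
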
framework_processor_data/evaluate.py (new file)

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+"""
+Integ test file evaluate.py
+"""
+
+def evaluate():
+    print("evaluate")

framework_processor_data/preprocess.py (new file)

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+"""
+Integ test file preprocess.py
+"""
+
+def preprocess():
+    print("preprocess")

framework_processor_data/query_data.py (new file)

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+"""
+Integ test file query_data.py
+"""
+
+def query_data():
+    print("query data")

framework_processor_data/train_test_split.py (new file)

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+"""
+Integ test file train_test_split.py
+"""
+
+def train_test_split():
+    print("train, test, split")

tests/integ/sagemaker/workflow/test_workflow.py

Lines changed: 151 additions & 0 deletions
@@ -19,11 +19,14 @@
 import time
 import shutil
 
+from pathlib import Path
 from contextlib import contextmanager
 import pytest
 
 from botocore.exceptions import WaiterError
 import pandas as pd
+from sagemaker.network import NetworkConfig
+from sagemaker.tensorflow import TensorFlow
 
 from tests.integ.s3_utils import extract_files_from_s3
 from sagemaker.workflow.model_step import (
@@ -47,6 +50,7 @@
     ProcessingOutput,
     FeatureStoreOutput,
     ScriptProcessor,
+    FrameworkProcessor
 )
 from sagemaker.s3 import S3Uploader
 from sagemaker.session import get_execution_role
@@ -83,6 +87,7 @@
     TransformInput,
     PropertyFile,
     TuningStep,
+    CacheConfig
 )
 from sagemaker.workflow.step_collections import RegisterModel
 from sagemaker.workflow.pipeline import Pipeline
@@ -1310,3 +1315,149 @@ def test_caching_behavior(
     except Exception:
         os.remove(script_dir + "/dummy_script.py")
         pass
+
+def test_processing_steps_with_framework_processor(pipeline_session, role):
+    default_bucket = pipeline_session.default_bucket()
+    cache_config = CacheConfig(enable_caching=True, expire_after="PT1H")
+    evaluation_report = PropertyFile(
+        name="EvaluationReport", output_name="evaluation", path="evaluation.json"
+    )
+    query_processor = ScriptProcessor(
+        command=["python3"],
+        image_uri="my-img",
+        role=role,
+        instance_count=1,
+        instance_type="ml.m5.xlarge",
+        network_config=NetworkConfig(
+            enable_network_isolation=False,
+            # VPC-Prod
+            subnets=["subnet-something"],
+            security_group_ids=["sg-something"],
+        ),
+        sagemaker_session=pipeline_session,
+    )
+
+    data_processor = FrameworkProcessor(
+        role=role,
+        instance_type="ml.m5.xlarge",
+        instance_count=1,
+        estimator_cls=TensorFlow,
+        framework_version="2.9",
+        py_version="py39",
+        sagemaker_session=pipeline_session,
+    )
+
+    query_step = ProcessingStep(
+        name="Query-Data",
+        step_args=query_processor.run(
+            code=os.path.join(DATA_DIR, "framework_processor_data/query_data.py"),
+            arguments=[
+                "--output-path",
+                "s3://out1",
+                "--region",
+                "s3://out2",
+            ],
+        ),
+        cache_config=cache_config,
+    )
+
+    input_path = "/opt/ml/processing/input"
+    output_path = "/opt/ml/processing/output"
+
+    prepare_step = ProcessingStep(
+        name="Prepare-Data",
+        step_args=data_processor.run(
+            code="preprocess.py",
+            source_dir=DATA_DIR + "/framework_processor_data",
+            inputs=[
+                ProcessingInput(
+                    input_name="task_preprocess_input",
+                    source=query_step.properties.ProcessingOutputConfig.Outputs["task_query_output"].S3Output.S3Uri,
+                    destination=input_path,
+                )
+            ],
+            arguments=[
+                "--input-path",
+                input_path,
+                "--output-path",
+                output_path,
+            ],
+        ),
+        cache_config=cache_config,
+    )
+
+    split_step = ProcessingStep(
+        name="Split-Data",
+        step_args=data_processor.run(
+            code="train_test_split.py",
+            source_dir=DATA_DIR + "/framework_processor_data",
+            inputs=[
+                ProcessingInput(
+                    source=prepare_step.properties.ProcessingOutputConfig.Outputs[
+                        "task_preprocess_output"
+                    ].S3Output.S3Uri,
+                    destination=input_path,
+                ),
+            ],
+            arguments=["--input-path", input_path, "--output-path", output_path],
+        ),
+        cache_config=cache_config,
+    )

+    sk_processor = FrameworkProcessor(
+        framework_version="1.0-1",
+        instance_type="ml.m5.xlarge",
+        instance_count=1,
+        base_job_name="my-job",
+        role=role,
+        estimator_cls=SKLearn,
+        sagemaker_session=pipeline_session,
+    )
+
+    evaluate_step = ProcessingStep(
+        name="Evaluate-Model",
+        step_args=sk_processor.run(
+            code="evaluate.py",
+            source_dir=DATA_DIR + "/framework_processor_data",
+            outputs=[
+                ProcessingOutput(
+                    output_name="evaluation",
+                    source="/opt/ml/processing/evaluation",
+                ),
+            ],
+        ),
+        property_files=[evaluation_report],
+        cache_config=cache_config,
+    )
+
+    pipeline = Pipeline(
+        name="test-fw-proc-steps-pipeline",
+        steps=[query_step, prepare_step, split_step, evaluate_step]
+    )
+    try:
+        # create pipeline
+        pipeline.create(role)
+        definition = json.loads(pipeline.definition())
+
+        source_dir_tar_prefix = f"s3://{default_bucket}/{pipeline.name}" \
+            f"/code/{hash_files_or_dirs([DATA_DIR + '/framework_processor_data'])}"
+
+        run_procs = []
+
+        for step in definition["Steps"]:
+            for input_obj in step["Arguments"]["ProcessingInputs"]:
+                if input_obj["InputName"] == "entrypoint":
+                    s3_uri = input_obj["S3Input"]["S3Uri"]
+                    run_procs.append(s3_uri)
+
+                    # verify runproc.sh prefix is different from code artifact prefix
+                    assert Path(s3_uri).parent != source_dir_tar_prefix
+
+        # verify all the run_proc.sh artifact paths are distinct
+        assert len(run_procs) == len(set(run_procs))
+
+    finally:
+        try:
+            pipeline.delete()
+        except Exception:
+            pass

tests/unit/sagemaker/workflow/test_pipeline.py

Lines changed: 3 additions & 6 deletions
@@ -17,7 +17,7 @@
 
 import pytest
 
-from mock import Mock
+from mock import Mock, patch
 
 from sagemaker import s3
 from sagemaker.workflow.condition_step import ConditionStep
@@ -78,6 +78,7 @@ def test_pipeline_create_with_parallelism_config(sagemaker_session_mock, role_arn):
     )
 
 
+@patch("sagemaker.spark.processing.S3Uploader.upload_string_as_file_body")
 def test_large_pipeline_create(sagemaker_session_mock, role_arn):
     parameter = ParameterString("MyStr")
     pipeline = Pipeline(
@@ -87,8 +88,6 @@ def test_large_pipeline_create(sagemaker_session_mock, role_arn):
         sagemaker_session=sagemaker_session_mock,
     )
 
-    s3.S3Uploader.upload_string_as_file_body = Mock()
-
     pipeline.create(role_arn=role_arn)
 
     assert s3.S3Uploader.upload_string_as_file_body.called_with(
@@ -150,7 +149,7 @@ def test_pipeline_update_with_parallelism_config(sagemaker_session_mock, role_arn):
        ParallelismConfiguration={"MaxParallelExecutionSteps": 10},
    )
 
-
+@patch("sagemaker.spark.processing.S3Uploader.upload_string_as_file_body")
 def test_large_pipeline_update(sagemaker_session_mock, role_arn):
     parameter = ParameterString("MyStr")
     pipeline = Pipeline(
@@ -160,8 +159,6 @@ def test_large_pipeline_update(sagemaker_session_mock, role_arn):
         sagemaker_session=sagemaker_session_mock,
     )
 
-    s3.S3Uploader.upload_string_as_file_body = Mock()
-
    pipeline.create(role_arn=role_arn)
 
    assert s3.S3Uploader.upload_string_as_file_body.called_with(
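
Note: swapping the in-test assignment s3.S3Uploader.upload_string_as_file_body = Mock() for a @patch decorator keeps the stub scoped to each test. mock.patch restores the original attribute when the test finishes, whereas the old assignment left the module monkey-patched for every test that ran afterwards, which is the same class of cross-test side effect this commit removes elsewhere. A small standalone illustration of the scoping difference (toy class, not the real uploader):

# Toy example: "Uploader" stands in for S3Uploader; only the patch scoping is the point.
from unittest import mock


class Uploader:
    @staticmethod
    def upload(body):
        return f"uploaded {len(body)} bytes"


def test_upload_is_stubbed_only_inside_the_patch():
    with mock.patch.object(Uploader, "upload", return_value="stubbed") as stub:
        assert Uploader.upload("abc") == "stubbed"
        stub.assert_called_once()
    # Once the patch exits, the real implementation is back for later tests.
    assert Uploader.upload("abc") == "uploaded 3 bytes"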
