Skip to content

Commit 20cd3b6

Browse files
nmadanrohangujarathisvia3ZhankuilDewen Qi
authored andcommitted
feature: Add Pipeline step decorator, NotebookJobStep, and scheduler (#1271)
Co-authored-by: Rohan Gujarathi <[email protected]> Co-authored-by: svia3 <[email protected]> Co-authored-by: Zhankui Lu <[email protected]> Co-authored-by: Dewen Qi <[email protected]> Co-authored-by: Edward Sun <[email protected]> Co-authored-by: Stephen Via <[email protected]> Co-authored-by: Namrata Madan <[email protected]> Co-authored-by: Stacia Choe <[email protected]> Co-authored-by: Edward Sun <[email protected]> Co-authored-by: Edward Sun <[email protected]> Co-authored-by: Rohan Gujarathi <[email protected]> fix: Multiple bug fixes including removing unsupported feature. (#1105) Fix some problems with pipeline compilation (#1125) fix: Refactor JsonGet s3 URI and add serialize_output_to_json flag (#1164) fix: invoke_function circular import (#1262) fix: pylint (#1264) fix: Add logging for docker build failures (#1267)
1 parent c3069b3 commit 20cd3b6

File tree

112 files changed

+13876
-1001
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

112 files changed

+13876
-1001
lines changed

doc/workflows/pipelines/sagemaker.workflow.pipelines.rst

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,10 @@ Pipeline Context
9696
.. autoclass:: sagemaker.workflow.pipeline_context.LocalPipelineSession
9797
:members:
9898

99+
Pipeline Schedule
100+
-----------------
101+
102+
.. autoclass:: sagemaker.workflow.triggers.PipelineSchedule
99103

100104
Parallelism Configuration
101105
-------------------------
@@ -120,7 +124,6 @@ Selective Execution Config
120124

121125
.. autoclass:: sagemaker.workflow.selective_execution_config.SelectiveExecutionConfig
122126

123-
124127
Properties
125128
----------
126129

@@ -162,6 +165,8 @@ Steps
162165

163166
.. autoclass:: sagemaker.workflow.steps.ProcessingStep
164167

168+
.. autoclass:: sagemaker.workflow.notebook_job_step.NotebookJobStep
169+
165170
.. autoclass:: sagemaker.workflow.steps.CreateModelStep
166171

167172
.. autoclass:: sagemaker.workflow.callback_step.CallbackStep
@@ -185,3 +190,14 @@ Steps
185190
.. autoclass:: sagemaker.workflow.emr_step.EMRStep
186191

187192
.. autoclass:: sagemaker.workflow.automl_step.AutoMLStep
193+
194+
@step decorator
195+
---------------
196+
197+
.. automethod:: sagemaker.workflow.function_step.step
198+
199+
.. autoclass:: sagemaker.workflow.function_step.DelayedReturn
200+
201+
.. autoclass:: sagemaker.workflow.step_outputs.StepOutput
202+
203+
.. autofunction:: sagemaker.workflow.step_outputs.get_step

requirements/extras/test_requirements.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,8 @@ sagemaker-feature-store-pyspark-3.3
3232
# TODO find workaround
3333
xgboost>=1.6.2,<=1.7.6
3434
pillow>=9.5.0,<=10.0.0
35-
torch@https://download.pytorch.org/whl/cpu/torch-2.0.0%2Bcpu-cp310-cp310-linux_x86_64.whl
36-
torchvision@https://download.pytorch.org/whl/cpu/torchvision-0.15.1%2Bcpu-cp310-cp310-linux_x86_64.whl
3735
transformers==4.32.0
3836
sentencepiece==0.1.99
3937
# https://github.com/triton-inference-server/server/issues/6246
4038
tritonclient[http]<2.37.0
39+
nbformat>=5.9,<6

src/sagemaker/config/config_schema.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@
112112
CONTAINER_ROOT = "container_root"
113113
REGION_NAME = "region_name"
114114
TELEMETRY_OPT_OUT = "TelemetryOptOut"
115+
NOTEBOOK_JOB = "NotebookJob"
115116

116117

117118
def _simple_path(*args: str):
@@ -276,6 +277,7 @@ def _simple_path(*args: str):
276277
MODEL_PACKAGE_VALIDATION_PROFILES_PATH = _simple_path(
277278
SAGEMAKER, MODEL_PACKAGE, VALIDATION_SPECIFICATION, VALIDATION_PROFILES
278279
)
280+
REMOTE_FUNCTION_PATH = _simple_path(SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION)
279281
REMOTE_FUNCTION_DEPENDENCIES = _simple_path(
280282
SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, DEPENDENCIES
281283
)
@@ -322,6 +324,20 @@ def _simple_path(*args: str):
322324
REMOTE_FUNCTION,
323325
ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION,
324326
)
327+
NOTEBOOK_JOB_ROLE_ARN = _simple_path(SAGEMAKER, PYTHON_SDK, MODULES, NOTEBOOK_JOB, ROLE_ARN)
328+
NOTEBOOK_JOB_S3_ROOT_URI = _simple_path(SAGEMAKER, PYTHON_SDK, MODULES, NOTEBOOK_JOB, S3_ROOT_URI)
329+
NOTEBOOK_JOB_S3_KMS_KEY_ID = _simple_path(
330+
SAGEMAKER, PYTHON_SDK, MODULES, NOTEBOOK_JOB, S3_KMS_KEY_ID
331+
)
332+
NOTEBOOK_JOB_VOLUME_KMS_KEY_ID = _simple_path(
333+
SAGEMAKER, PYTHON_SDK, MODULES, NOTEBOOK_JOB, VOLUME_KMS_KEY_ID
334+
)
335+
NOTEBOOK_JOB_VPC_CONFIG_SUBNETS = _simple_path(
336+
SAGEMAKER, PYTHON_SDK, MODULES, NOTEBOOK_JOB, VPC_CONFIG, SUBNETS
337+
)
338+
NOTEBOOK_JOB_VPC_CONFIG_SECURITY_GROUP_IDS = _simple_path(
339+
SAGEMAKER, PYTHON_SDK, MODULES, NOTEBOOK_JOB, VPC_CONFIG, SECURITY_GROUP_IDS
340+
)
325341
MONITORING_SCHEDULE_INTER_CONTAINER_ENCRYPTION_PATH = _simple_path(
326342
SAGEMAKER,
327343
MONITORING_SCHEDULE,
@@ -709,6 +725,16 @@ def _simple_path(*args: str):
709725
},
710726
IMAGE_URI: {TYPE: "string"},
711727
INCLUDE_LOCAL_WORKDIR: {TYPE: "boolean"},
728+
"WorkdirConfig": {
729+
TYPE: OBJECT,
730+
ADDITIONAL_PROPERTIES: False,
731+
PROPERTIES: {
732+
"IgnoreNamePatterns": {
733+
TYPE: "array",
734+
"items": {"type": "string"},
735+
}
736+
},
737+
},
712738
INSTANCE_TYPE: {TYPE: "string"},
713739
JOB_CONDA_ENV: {TYPE: "string"},
714740
ROLE_ARN: {"$ref": "#/definitions/roleArn"},
@@ -719,6 +745,17 @@ def _simple_path(*args: str):
719745
VPC_CONFIG: {"$ref": "#/definitions/vpcConfig"},
720746
},
721747
},
748+
NOTEBOOK_JOB: {
749+
TYPE: OBJECT,
750+
ADDITIONAL_PROPERTIES: False,
751+
PROPERTIES: {
752+
ROLE_ARN: {"$ref": "#/definitions/roleArn"},
753+
S3_ROOT_URI: {"$ref": "#/definitions/s3Uri"},
754+
S3_KMS_KEY_ID: {"$ref": "#/definitions/kmsKeyId"},
755+
VOLUME_KMS_KEY_ID: {"$ref": "#/definitions/kmsKeyId"},
756+
VPC_CONFIG: {"$ref": "#/definitions/vpcConfig"},
757+
},
758+
},
722759
SERVE: {
723760
TYPE: OBJECT,
724761
ADDITIONAL_PROPERTIES: False,

src/sagemaker/feature_store/feature_processor/_config_uploader.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from sagemaker.inputs import TrainingInput
2727
from sagemaker.remote_function.core.stored_function import StoredFunction
2828
from sagemaker.remote_function.job import (
29-
_prepare_and_upload_dependencies,
29+
_prepare_and_upload_workspace,
3030
_prepare_and_upload_runtime_scripts,
3131
_JobSettings,
3232
RUNTIME_SCRIPTS_CHANNEL_NAME,
@@ -38,6 +38,7 @@
3838
RuntimeEnvironmentManager,
3939
)
4040
from sagemaker.remote_function.spark_config import SparkConfig
41+
from sagemaker.remote_function.workdir_config import WorkdirConfig
4142
from sagemaker.s3 import s3_path_join
4243

4344

@@ -62,9 +63,10 @@ def prepare_step_input_channel_for_spark_mode(
6263
dependencies_list_path = self.runtime_env_manager.snapshot(
6364
self.remote_decorator_config.dependencies
6465
)
65-
user_dependencies_s3uri = self._prepare_and_upload_dependencies(
66+
user_workspace_s3uri = self._prepare_and_upload_workspace(
6667
dependencies_list_path,
6768
self.remote_decorator_config.include_local_workdir,
69+
self.remote_decorator_config.workdir_config,
6870
self.remote_decorator_config.pre_execution_commands,
6971
self.remote_decorator_config.pre_execution_script,
7072
s3_base_uri,
@@ -92,7 +94,7 @@ def prepare_step_input_channel_for_spark_mode(
9294
distribution=S3_DATA_DISTRIBUTION_TYPE,
9395
)
9496
}
95-
if user_dependencies_s3uri:
97+
if user_workspace_s3uri:
9698
input_data_config[REMOTE_FUNCTION_WORKSPACE] = TrainingInput(
9799
s3_data=s3_path_join(s3_base_uri, REMOTE_FUNCTION_WORKSPACE),
98100
s3_data_type="S3Prefix",
@@ -126,10 +128,11 @@ def _prepare_and_upload_callable(
126128
)
127129
stored_function.save(func)
128130

129-
def _prepare_and_upload_dependencies(
131+
def _prepare_and_upload_workspace(
130132
self,
131133
local_dependencies_path: str,
132134
include_local_workdir: bool,
135+
workdir_config: WorkdirConfig,
133136
pre_execution_commands: List[str],
134137
pre_execution_script_local_path: str,
135138
s3_base_uri: str,
@@ -138,9 +141,10 @@ def _prepare_and_upload_dependencies(
138141
custom_file_filter: Optional[Callable[[str, List], List]] = None,
139142
) -> str:
140143
"""Upload the training step dependencies to S3 if present"""
141-
return _prepare_and_upload_dependencies(
144+
return _prepare_and_upload_workspace(
142145
local_dependencies_path=local_dependencies_path,
143146
include_local_workdir=include_local_workdir,
147+
workdir_config=workdir_config,
144148
pre_execution_commands=pre_execution_commands,
145149
pre_execution_script_local_path=pre_execution_script_local_path,
146150
s3_base_uri=s3_base_uri,

src/sagemaker/local/entities.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,9 @@ def start(self, input_data_config, output_data_config, hyperparameters, environm
215215
"""
216216
for channel in input_data_config:
217217
if channel["DataSource"] and "S3DataSource" in channel["DataSource"]:
218-
data_distribution = channel["DataSource"]["S3DataSource"]["S3DataDistributionType"]
218+
data_distribution = channel["DataSource"]["S3DataSource"].get(
219+
"S3DataDistributionType", None
220+
)
219221
data_uri = channel["DataSource"]["S3DataSource"]["S3Uri"]
220222
elif channel["DataSource"] and "FileDataSource" in channel["DataSource"]:
221223
data_distribution = channel["DataSource"]["FileDataSource"][
@@ -230,7 +232,7 @@ def start(self, input_data_config, output_data_config, hyperparameters, environm
230232
# use a single Data URI - this makes handling S3 and File Data easier down the stack
231233
channel["DataUri"] = data_uri
232234

233-
if data_distribution != "FullyReplicated":
235+
if data_distribution and data_distribution != "FullyReplicated":
234236
raise RuntimeError(
235237
"DataDistribution: %s is not currently supported in Local Mode"
236238
% data_distribution

src/sagemaker/local/image.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -812,7 +812,13 @@ def _create_docker_host(
812812
"networks": {"sagemaker-local": {"aliases": [host]}},
813813
}
814814

815-
if command != "process":
815+
is_train_with_entrypoint = False
816+
if command == "train" and self.container_entrypoint:
817+
# Remote function or Pipeline function step is translated into a training job
818+
# with container_entrypoint configured
819+
is_train_with_entrypoint = True
820+
821+
if command != "process" and not is_train_with_entrypoint:
816822
host_config["command"] = command
817823
else:
818824
if self.container_entrypoint:

src/sagemaker/local/local_session.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,10 @@ def create_training_job(
197197
AlgorithmSpecification["TrainingImage"],
198198
sagemaker_session=self.sagemaker_session,
199199
)
200+
if AlgorithmSpecification.get("ContainerEntrypoint", None):
201+
container.container_entrypoint = AlgorithmSpecification["ContainerEntrypoint"]
202+
if AlgorithmSpecification.get("ContainerArguments", None):
203+
container.container_arguments = AlgorithmSpecification["ContainerArguments"]
200204
training_job = _LocalTrainingJob(container)
201205
hyperparameters = kwargs["HyperParameters"] if "HyperParameters" in kwargs else {}
202206
logger.info("Starting training job")
@@ -718,6 +722,8 @@ def _initialize(
718722
else load_sagemaker_config(s3_resource=self.s3_resource)
719723
)
720724
else:
725+
self.s3_resource = self.boto_session.resource("s3", region_name=self._region_name)
726+
self.s3_client = self.boto_session.client("s3", region_name=self._region_name)
721727
self.sagemaker_config = (
722728
sagemaker_config if sagemaker_config else load_sagemaker_config()
723729
)

src/sagemaker/local/pipeline.py

Lines changed: 54 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,12 @@
1515
from abc import ABC, abstractmethod
1616

1717
import json
18-
from copy import deepcopy
1918
from datetime import datetime
2019
from typing import Dict, List, Union
2120
from botocore.exceptions import ClientError
2221

2322
from sagemaker.workflow.conditions import ConditionTypeEnum
23+
from sagemaker.workflow.function_step import DelayedReturn
2424
from sagemaker.workflow.steps import StepTypeEnum, Step
2525
from sagemaker.workflow.step_collections import StepCollection
2626
from sagemaker.workflow.entities import PipelineVariable
@@ -87,7 +87,7 @@ def evaluate_step_arguments(self, step):
8787
def _parse_arguments(self, obj, step_name):
8888
"""Parse and evaluate arguments field"""
8989
if isinstance(obj, dict):
90-
obj_copy = deepcopy(obj)
90+
obj_copy = {}
9191
for k, v in obj.items():
9292
obj_copy[k] = self._parse_arguments(v, step_name)
9393
return obj_copy
@@ -108,16 +108,17 @@ def evaluate_pipeline_variable(self, pipeline_variable, step_name):
108108
elif isinstance(pipeline_variable, Parameter):
109109
value = self.execution.pipeline_parameters.get(pipeline_variable.name)
110110
elif isinstance(pipeline_variable, Join):
111-
evaluated = [
112-
str(self.evaluate_pipeline_variable(v, step_name)) for v in pipeline_variable.values
113-
]
114-
value = pipeline_variable.on.join(evaluated)
111+
value = self._evaluate_join_function(pipeline_variable, step_name)
115112
elif isinstance(pipeline_variable, Properties):
116113
value = self._evaluate_property_reference(pipeline_variable, step_name)
117114
elif isinstance(pipeline_variable, ExecutionVariable):
118115
value = self._evaluate_execution_variable(pipeline_variable)
119116
elif isinstance(pipeline_variable, JsonGet):
120117
value = self._evaluate_json_get_function(pipeline_variable, step_name)
118+
elif isinstance(pipeline_variable, DelayedReturn):
119+
# DelayedReturn showing up in arguments, meaning that it's data referenced
120+
# We should convert it to JsonGet and evaluate the JsonGet object
121+
value = self._evaluate_json_get_function(pipeline_variable._to_json_get(), step_name)
121122
else:
122123
self.execution.update_step_failure(
123124
step_name, f"Unrecognized pipeline variable {pipeline_variable.expr}."
@@ -127,6 +128,13 @@ def evaluate_pipeline_variable(self, pipeline_variable, step_name):
127128
self.execution.update_step_failure(step_name, f"{pipeline_variable.expr} is undefined.")
128129
return value
129130

131+
def _evaluate_join_function(self, pipeline_variable, step_name):
132+
"""Evaluate join function runtime value"""
133+
evaluated = [
134+
str(self.evaluate_pipeline_variable(v, step_name)) for v in pipeline_variable.values
135+
]
136+
return pipeline_variable.on.join(evaluated)
137+
130138
def _evaluate_property_reference(self, pipeline_variable, step_name):
131139
"""Evaluate property reference runtime value."""
132140
try:
@@ -156,6 +164,43 @@ def _evaluate_execution_variable(self, pipeline_variable):
156164

157165
def _evaluate_json_get_function(self, pipeline_variable, step_name):
158166
"""Evaluate join function runtime value."""
167+
s3_bucket = None
168+
s3_key = None
169+
try:
170+
if pipeline_variable.property_file:
171+
s3_bucket, s3_key = self._evaluate_json_get_property_file_reference(
172+
pipeline_variable=pipeline_variable, step_name=step_name
173+
)
174+
else:
175+
# JsonGet's s3_uri can only be a Join function
176+
# This has been validated in _validate_json_get_function
177+
s3_uri = self._evaluate_join_function(pipeline_variable.s3_uri, step_name)
178+
s3_bucket, s3_key = parse_s3_url(s3_uri)
179+
180+
file_content = self.sagemaker_session.read_s3_file(s3_bucket, s3_key)
181+
file_json = json.loads(file_content)
182+
return get_using_dot_notation(file_json, pipeline_variable.json_path)
183+
except ClientError as e:
184+
self.execution.update_step_failure(
185+
step_name,
186+
f"Received an error while reading file {s3_path_join('s3://', s3_bucket, s3_key)} "
187+
f"from S3: {e.response.get('Code')}: {e.response.get('Message')}",
188+
)
189+
except json.JSONDecodeError:
190+
self.execution.update_step_failure(
191+
step_name,
192+
f"Contents of file {s3_path_join('s3://', s3_bucket, s3_key)} are not "
193+
f"in valid JSON format.",
194+
)
195+
except ValueError:
196+
self.execution.update_step_failure(
197+
step_name, f"Invalid json path '{pipeline_variable.json_path}'"
198+
)
199+
200+
def _evaluate_json_get_property_file_reference(
201+
self, pipeline_variable: JsonGet, step_name: str
202+
):
203+
"""Evaluate JsonGet's property file reference to get s3 bucket and key"""
159204
property_file_reference = pipeline_variable.property_file
160205
property_file = None
161206
if isinstance(property_file_reference, str):
@@ -180,28 +225,9 @@ def _evaluate_json_get_function(self, pipeline_variable, step_name):
180225
processing_output_s3_bucket = processing_step_response["ProcessingOutputConfig"]["Outputs"][
181226
property_file.output_name
182227
]["S3Output"]["S3Uri"]
183-
try:
184-
s3_bucket, s3_key_prefix = parse_s3_url(processing_output_s3_bucket)
185-
file_content = self.sagemaker_session.read_s3_file(
186-
s3_bucket, s3_path_join(s3_key_prefix, property_file.path)
187-
)
188-
file_json = json.loads(file_content)
189-
return get_using_dot_notation(file_json, pipeline_variable.json_path)
190-
except ClientError as e:
191-
self.execution.update_step_failure(
192-
step_name,
193-
f"Received an error while file reading file '{property_file.path}' from S3: "
194-
f"{e.response.get('Code')}: {e.response.get('Message')}",
195-
)
196-
except json.JSONDecodeError:
197-
self.execution.update_step_failure(
198-
step_name,
199-
f"Contents of property file '{property_file.name}' are not in valid JSON format.",
200-
)
201-
except ValueError:
202-
self.execution.update_step_failure(
203-
step_name, f"Invalid json path '{pipeline_variable.json_path}'"
204-
)
228+
s3_bucket, s3_key_prefix = parse_s3_url(processing_output_s3_bucket)
229+
s3_key = s3_path_join(s3_key_prefix, property_file.path)
230+
return s3_bucket, s3_key
205231

206232

207233
class _StepExecutor(ABC):

src/sagemaker/remote_function/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,5 @@
1515

1616
from sagemaker.remote_function.client import remote, RemoteExecutor # noqa: F401
1717
from sagemaker.remote_function.checkpoint_location import CheckpointLocation # noqa: F401
18+
from sagemaker.remote_function.workdir_config import WorkdirConfig # noqa: F401
19+
from sagemaker.remote_function.spark_config import SparkConfig # noqa: F401

0 commit comments

Comments
 (0)