
Commit 769c0ae

Merge branch 'master' of https://github.com/rahven14/sagemaker-python-sdk into feature/infer-framework-and-version

2 parents: 9b39f6a + 6d24b28

File tree: 13 files changed, +562 −230 lines

requirements/extras/test_requirements.txt

Lines changed: 2 additions & 1 deletion

```diff
@@ -12,9 +12,10 @@ awslogs==0.14.0
 black==22.3.0
 stopit==1.1.2
 apache-airflow==2.2.4
-apache-airflow-providers-amazon==3.0.0
+apache-airflow-providers-amazon==4.0.0
 attrs==20.3.0
 fabric==2.6.0
 requests==2.27.1
 sagemaker-experiments==0.1.35
 Jinja2==3.0.3
+pandas>=1.3.5,<1.5
```

src/sagemaker/spark/processing.py

Lines changed: 91 additions & 62 deletions

```diff
@@ -31,13 +31,20 @@
 from io import BytesIO
 from urllib.parse import urlparse
 
+from typing import Union, List, Dict, Optional
+
 from sagemaker import image_uris
 from sagemaker.local.image import _ecr_login_if_needed, _pull_image
 from sagemaker.processing import ProcessingInput, ProcessingOutput, ScriptProcessor
 from sagemaker.s3 import S3Uploader
 from sagemaker.session import Session
+from sagemaker.network import NetworkConfig
 from sagemaker.spark import defaults
 
+from sagemaker.workflow import is_pipeline_variable
+from sagemaker.workflow.entities import PipelineVariable
+from sagemaker.workflow.functions import Join
+
 logger = logging.getLogger(__name__)
 
 
@@ -249,6 +256,12 @@ def run(
         """
         self._current_job_name = self._generate_current_job_name(job_name=job_name)
 
+        if is_pipeline_variable(submit_app):
+            raise ValueError(
+                "submit_app argument has to be a valid S3 URI or local file path "
+                + "rather than a pipeline variable"
+            )
+
         return super().run(
             submit_app,
             inputs,
@@ -437,9 +450,14 @@ def _stage_submit_deps(self, submit_deps, input_channel_name):
 
         use_input_channel = False
         spark_opt_s3_uris = []
+        spark_opt_s3_uris_has_pipeline_var = False
 
         with tempfile.TemporaryDirectory() as tmpdir:
             for dep_path in submit_deps:
+                if is_pipeline_variable(dep_path):
+                    spark_opt_s3_uris.append(dep_path)
+                    spark_opt_s3_uris_has_pipeline_var = True
+                    continue
                 dep_url = urlparse(dep_path)
                 # S3 URIs are included as-is in the spark-submit argument
                 if dep_url.scheme in ["s3", "s3a"]:
@@ -482,11 +500,19 @@ def _stage_submit_deps(self, submit_deps, input_channel_name):
                 destination=f"{self._conf_container_base_path}{input_channel_name}",
                 input_name=input_channel_name,
             )
-            spark_opt = ",".join(spark_opt_s3_uris + [input_channel.destination])
+            spark_opt = (
+                Join(on=",", values=spark_opt_s3_uris + [input_channel.destination])
+                if spark_opt_s3_uris_has_pipeline_var
+                else ",".join(spark_opt_s3_uris + [input_channel.destination])
+            )
         # If no local files were uploaded, form the spark-submit option from a list of S3 URIs
         else:
             input_channel = None
-            spark_opt = ",".join(spark_opt_s3_uris)
+            spark_opt = (
+                Join(on=",", values=spark_opt_s3_uris)
+                if spark_opt_s3_uris_has_pipeline_var
+                else ",".join(spark_opt_s3_uris)
+            )
 
         return input_channel, spark_opt
 
@@ -592,6 +618,9 @@ def _validate_s3_uri(self, spark_output_s3_path):
         Args:
             spark_output_s3_path (str): The URI of the Spark output S3 Path.
         """
+        if is_pipeline_variable(spark_output_s3_path):
+            return
+
         if urlparse(spark_output_s3_path).scheme != "s3":
             raise ValueError(
                 f"Invalid s3 path: {spark_output_s3_path}. Please enter something like "
@@ -650,22 +679,22 @@ class PySparkProcessor(_SparkProcessorBase):
 
     def __init__(
         self,
-        role,
-        instance_type,
-        instance_count,
-        framework_version=None,
-        py_version=None,
-        container_version=None,
-        image_uri=None,
-        volume_size_in_gb=30,
-        volume_kms_key=None,
-        output_kms_key=None,
-        max_runtime_in_seconds=None,
-        base_job_name=None,
-        sagemaker_session=None,
-        env=None,
-        tags=None,
-        network_config=None,
+        role: str,
+        instance_type: Union[str, PipelineVariable],
+        instance_count: Union[int, PipelineVariable],
+        framework_version: Optional[str] = None,
+        py_version: Optional[str] = None,
+        container_version: Optional[str] = None,
+        image_uri: Optional[Union[str, PipelineVariable]] = None,
+        volume_size_in_gb: Union[int, PipelineVariable] = 30,
+        volume_kms_key: Optional[Union[str, PipelineVariable]] = None,
+        output_kms_key: Optional[Union[str, PipelineVariable]] = None,
+        max_runtime_in_seconds: Optional[Union[int, PipelineVariable]] = None,
+        base_job_name: Optional[str] = None,
+        sagemaker_session: Optional[Session] = None,
+        env: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
+        tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
+        network_config: Optional[NetworkConfig] = None,
     ):
         """Initialize a ``PySparkProcessor`` instance.
 
@@ -795,20 +824,20 @@ def get_run_args(
 
     def run(
         self,
-        submit_app,
-        submit_py_files=None,
-        submit_jars=None,
-        submit_files=None,
-        inputs=None,
-        outputs=None,
-        arguments=None,
-        wait=True,
-        logs=True,
-        job_name=None,
-        experiment_config=None,
-        configuration=None,
-        spark_event_logs_s3_uri=None,
-        kms_key=None,
+        submit_app: str,
+        submit_py_files: Optional[List[Union[str, PipelineVariable]]] = None,
+        submit_jars: Optional[List[Union[str, PipelineVariable]]] = None,
+        submit_files: Optional[List[Union[str, PipelineVariable]]] = None,
+        inputs: Optional[List[ProcessingInput]] = None,
+        outputs: Optional[List[ProcessingOutput]] = None,
+        arguments: Optional[List[Union[str, PipelineVariable]]] = None,
+        wait: bool = True,
+        logs: bool = True,
+        job_name: Optional[str] = None,
+        experiment_config: Optional[Dict[str, str]] = None,
+        configuration: Optional[Union[List[Dict], Dict]] = None,
+        spark_event_logs_s3_uri: Optional[Union[str, PipelineVariable]] = None,
+        kms_key: Optional[str] = None,
     ):
         """Runs a processing job.
 
@@ -907,22 +936,22 @@ class SparkJarProcessor(_SparkProcessorBase):
 
     def __init__(
         self,
-        role,
-        instance_type,
-        instance_count,
-        framework_version=None,
-        py_version=None,
-        container_version=None,
-        image_uri=None,
-        volume_size_in_gb=30,
-        volume_kms_key=None,
-        output_kms_key=None,
-        max_runtime_in_seconds=None,
-        base_job_name=None,
-        sagemaker_session=None,
-        env=None,
-        tags=None,
-        network_config=None,
+        role: str,
+        instance_type: Union[str, PipelineVariable],
+        instance_count: Union[int, PipelineVariable],
+        framework_version: Optional[str] = None,
+        py_version: Optional[str] = None,
+        container_version: Optional[str] = None,
+        image_uri: Optional[Union[str, PipelineVariable]] = None,
+        volume_size_in_gb: Union[int, PipelineVariable] = 30,
+        volume_kms_key: Optional[Union[str, PipelineVariable]] = None,
+        output_kms_key: Optional[Union[str, PipelineVariable]] = None,
+        max_runtime_in_seconds: Optional[Union[int, PipelineVariable]] = None,
+        base_job_name: Optional[str] = None,
+        sagemaker_session: Optional[Session] = None,
+        env: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
+        tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
+        network_config: Optional[NetworkConfig] = None,
     ):
         """Initialize a ``SparkJarProcessor`` instance.
 
@@ -1052,20 +1081,20 @@ def get_run_args(
 
     def run(
         self,
-        submit_app,
-        submit_class=None,
-        submit_jars=None,
-        submit_files=None,
-        inputs=None,
-        outputs=None,
-        arguments=None,
-        wait=True,
-        logs=True,
-        job_name=None,
-        experiment_config=None,
-        configuration=None,
-        spark_event_logs_s3_uri=None,
-        kms_key=None,
+        submit_app: str,
+        submit_class: Union[str, PipelineVariable],
+        submit_jars: Optional[List[Union[str, PipelineVariable]]] = None,
+        submit_files: Optional[List[Union[str, PipelineVariable]]] = None,
+        inputs: Optional[List[ProcessingInput]] = None,
+        outputs: Optional[List[ProcessingOutput]] = None,
+        arguments: Optional[List[Union[str, PipelineVariable]]] = None,
+        wait: bool = True,
+        logs: bool = True,
+        job_name: Optional[str] = None,
+        experiment_config: Optional[Dict[str, str]] = None,
+        configuration: Optional[Union[List[Dict], Dict]] = None,
+        spark_event_logs_s3_uri: Optional[Union[str, PipelineVariable]] = None,
+        kms_key: Optional[str] = None,
     ):
         """Runs a processing job.
```
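Net effect of the processing.py changes: the list-style spark-submit arguments (submit_py_files, submit_jars, submit_files) and spark_event_logs_s3_uri can now carry SageMaker Pipelines variables. _stage_submit_deps detects them with is_pipeline_variable() and defers the comma-joining to pipeline execution time via workflow.functions.Join, while submit_app itself must remain a concrete S3 URI or local path. A minimal sketch of the usage this enables, assuming AWS credentials and region are configured; the role ARN, bucket, and file names below are hypothetical:

```python
from sagemaker.spark.processing import PySparkProcessor
from sagemaker.workflow.parameters import ParameterString

# A pipeline parameter: its value is only resolved when the pipeline executes.
extra_jars = ParameterString(
    name="ExtraJarsUri", default_value="s3://my-bucket/deps/extra.jar"
)

processor = PySparkProcessor(
    role="arn:aws:iam::111122223333:role/SageMakerRole",  # hypothetical role
    instance_type="ml.m5.xlarge",
    instance_count=2,
    framework_version="3.1",
)

run_args = processor.get_run_args(
    # submit_app must stay a concrete S3 URI or local path; the new guard in
    # run() raises ValueError if a pipeline variable is passed here.
    submit_app="s3://my-bucket/code/app.py",
    # submit_jars may now mix plain URIs and pipeline variables; internally
    # _stage_submit_deps spots the variable via is_pipeline_variable() and
    # joins the list with workflow.functions.Join instead of ",".join.
    submit_jars=["s3://my-bucket/deps/common.jar", extra_jars],
)
```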
src/sagemaker/utils.py

Lines changed: 47 additions & 24 deletions

```diff
@@ -759,32 +759,55 @@ def update_container_with_inference_params(
         dict: dict with inference recommender params
     """
 
-    if (
-        framework is not None
-        and framework_version is not None
-        and nearest_model_name is not None
-        and data_input_configuration is not None
-    ):
+    if framework is not None and framework_version is not None and nearest_model_name is not None:
         if container_list is not None:
             for obj in container_list:
-                obj.update(
-                    {
-                        "Framework": framework,
-                        "FrameworkVersion": framework_version,
-                        "NearestModelName": nearest_model_name,
-                        "ModelInput": {
-                            "DataInputConfig": data_input_configuration,
-                        },
-                    }
+                construct_container_object(
+                    obj, data_input_configuration, framework, framework_version, nearest_model_name
                 )
         if container_obj is not None:
-            container_obj.update(
-                {
-                    "Framework": framework,
-                    "FrameworkVersion": framework_version,
-                    "NearestModelName": nearest_model_name,
-                    "ModelInput": {
-                        "DataInputConfig": data_input_configuration,
-                    },
-                }
+            construct_container_object(
+                container_obj,
+                data_input_configuration,
+                framework,
+                framework_version,
+                nearest_model_name,
             )
+
+
+def construct_container_object(
+    obj, data_input_configuration, framework, framework_version, nearest_model_name
+):
+    """Function to construct container object.
+
+    Args:
+        obj (dict): the container object (``container_obj`` or an element of
+            ``container_list``) to be updated.
+        framework (str): Machine learning framework of the model package container image
+            (default: None).
+        framework_version (str): Framework version of the Model Package Container Image
+            (default: None).
+        nearest_model_name (str): Name of a pre-trained machine learning model benchmarked by
+            Amazon SageMaker Inference Recommender (default: None).
+        data_input_configuration (str): Input object for the model (default: None).
+
+    Returns:
+        dict: container object
+    """
+
+    obj.update(
+        {
+            "Framework": framework,
+            "FrameworkVersion": framework_version,
+            "NearestModelName": nearest_model_name,
+        }
+    )
+
+    if data_input_configuration is not None:
+        obj.update(
+            {
+                "ModelInput": {
+                    "DataInputConfig": data_input_configuration,
+                },
+            }
+        )
```
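Besides deduplicating the two update blocks, the utils.py refactor changes behavior: the framework fields no longer require data_input_configuration to be present, and ModelInput is attached only when a configuration is actually supplied. A small sketch of that behavior on a plain dict; the image URI and model values are made up:

```python
from sagemaker.utils import construct_container_object

# Stand-in for a model-package container definition; the image URI is made up.
container = {"Image": "123456789012.dkr.ecr.us-west-2.amazonaws.com/my-image:latest"}

# Without a data input configuration only the framework fields are set ...
construct_container_object(
    container,
    data_input_configuration=None,
    framework="TENSORFLOW",
    framework_version="2.9",
    nearest_model_name="resnet50",
)
assert "ModelInput" not in container  # omitted because the config was None

# ... and ModelInput appears only once a configuration is supplied.
construct_container_object(
    container,
    data_input_configuration='{"input_1": [1, 224, 224, 3]}',
    framework="TENSORFLOW",
    framework_version="2.9",
    nearest_model_name="resnet50",
)
assert container["ModelInput"] == {"DataInputConfig": '{"input_1": [1, 224, 224, 3]}'}
```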

src/sagemaker/workflow/condition_step.py

Lines changed: 6 additions & 30 deletions

```diff
@@ -15,20 +15,17 @@
 
 from typing import List, Union, Optional
 
-import attr
 
 from sagemaker.deprecations import deprecated_class
 from sagemaker.workflow.conditions import Condition
 from sagemaker.workflow.step_collections import StepCollection
+from sagemaker.workflow.functions import JsonGet as NewJsonGet
 from sagemaker.workflow.steps import (
     Step,
     StepTypeEnum,
 )
 from sagemaker.workflow.utilities import list_to_request
-from sagemaker.workflow.entities import (
-    RequestType,
-    PipelineVariable,
-)
+from sagemaker.workflow.entities import RequestType
 from sagemaker.workflow.properties import (
     Properties,
     PropertyFile,
@@ -93,16 +90,15 @@ def arguments(self) -> RequestType:
     @property
     def step_only_arguments(self):
         """Argument dict pertaining to the step only, and not the `if_steps` or `else_steps`."""
-        return self.conditions
+        return [condition.to_request() for condition in self.conditions]
 
     @property
     def properties(self):
         """A simple Properties object with `Outcome` as the only property"""
         return self._properties
 
 
-@attr.s
-class JsonGet(PipelineVariable):  # pragma: no cover
+class JsonGet(NewJsonGet):  # pragma: no cover
     """Get JSON properties from PropertyFiles.
 
     Attributes:
@@ -112,28 +108,8 @@ class JsonGet(PipelineVariable):  # pragma: no cover
         json_path (str): The JSON path expression to the requested value.
     """
 
-    step: Step = attr.ib()
-    property_file: Union[PropertyFile, str] = attr.ib()
-    json_path: str = attr.ib()
-
-    @property
-    def expr(self):
-        """The expression dict for a `JsonGet` function."""
-        if isinstance(self.property_file, PropertyFile):
-            name = self.property_file.name
-        else:
-            name = self.property_file
-        return {
-            "Std:JsonGet": {
-                "PropertyFile": {"Get": f"Steps.{self.step.name}.PropertyFiles.{name}"},
-                "Path": self.json_path,
-            }
-        }
-
-    @property
-    def _referenced_steps(self) -> List[str]:
-        """List of step names that this function depends on."""
-        return [self.step.name]
+    def __init__(self, step: Step, property_file: Union[PropertyFile, str], json_path: str):
+        super().__init__(step_name=step.name, property_file=property_file, json_path=json_path)
 
 
 JsonGet = deprecated_class(JsonGet, "JsonGet")
```
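After this change, condition_step.JsonGet is a thin deprecated shim that forwards the old Step-based constructor to sagemaker.workflow.functions.JsonGet, which takes a step_name string instead. New code should import the new class directly; a minimal sketch with hypothetical step and property-file names:

```python
from sagemaker.workflow.condition_step import ConditionStep
from sagemaker.workflow.conditions import ConditionGreaterThanOrEqualTo
from sagemaker.workflow.functions import JsonGet
from sagemaker.workflow.properties import PropertyFile

# Property file assumed to be registered on an evaluation step named "EvalStep".
evaluation_report = PropertyFile(
    name="EvaluationReport", output_name="evaluation", path="evaluation.json"
)

# The new JsonGet takes step_name (a string); the deprecated shim simply
# passes step.name through to this constructor.
accuracy = JsonGet(
    step_name="EvalStep",
    property_file=evaluation_report,
    json_path="metrics.accuracy.value",
)

cond_step = ConditionStep(
    name="CheckAccuracy",
    conditions=[ConditionGreaterThanOrEqualTo(left=accuracy, right=0.8)],
    if_steps=[],   # e.g. register/deploy steps in a real pipeline
    else_steps=[],
)
```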
