Skip to content

Commit 2bb37c5

Browse files
author
EC2 Default User
committed
Improved abstraction on Processor objects when handling local session. Other minor changes.
1 parent 14340ae commit 2bb37c5

File tree

6 files changed

+46
-74
lines changed

6 files changed

+46
-74
lines changed

src/sagemaker/local/entities.py

Lines changed: 22 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -75,41 +75,30 @@ def start(self, processing_inputs, processing_output_config, environment, proces
7575
if "DatasetDefinition" in item:
7676
raise RuntimeError("DatasetDefinition is not currently supported in Local Mode")
7777

78-
if "S3Input" in item and item["S3Input"]:
79-
data_uri = item["S3Input"]["S3Uri"]
80-
else:
78+
try:
79+
s3_input = item["S3Input"]
80+
except KeyError:
8181
raise ValueError("Processing input must have a valid ['S3Input']")
8282

83-
if item["S3Input"]["S3InputMode"]:
84-
input_mode = item["S3Input"]["S3InputMode"]
85-
else:
86-
raise ValueError("Processing input must have a valid ['S3InputMode']")
87-
88-
item["DataUri"] = data_uri
89-
90-
if (
91-
"S3DataDistributionType" in item["S3Input"]
92-
and item["S3Input"]["S3DataDistributionType"] != "FullyReplicated"
93-
):
83+
item["DataUri"] = s3_input["S3Uri"]
9484

85+
if "S3InputMode" in s3_input and s3_input["S3InputMode"] != "File":
9586
raise RuntimeError(
96-
"DataDistribution: %s is not currently supported in Local Mode"
97-
% item["S3Input"]["S3DataDistributionType"]
87+
"S3InputMode: %s is not currently supported in Local Mode"
88+
% s3_input["S3InputMode"]
9889
)
9990

100-
if input_mode != "File":
91+
if ("S3DataDistributionType" in s3_input
92+
and s3_input["S3DataDistributionType"] != "FullyReplicated"):
10193
raise RuntimeError(
102-
"S3InputMode: %s is not currently supported in Local Mode" % input_mode
94+
"DataDistribution: %s is not currently supported in Local Mode"
95+
% s3_input["S3DataDistributionType"]
10396
)
10497

105-
if (
106-
"S3CompressionType" in item["S3Input"]
107-
and item["S3Input"]["S3CompressionType"] != "None"
108-
):
109-
98+
if "S3CompressionType" in s3_input and s3_input["S3CompressionType"] != "None":
11099
raise RuntimeError(
111100
"CompressionType: %s is not currently supported in Local Mode"
112-
% item["S3Input"]["S3CompressionType"]
101+
% s3_input["S3CompressionType"]
113102
)
114103

115104
if processing_output_config and "Outputs" in processing_output_config:
@@ -121,14 +110,15 @@ def start(self, processing_inputs, processing_output_config, environment, proces
121110
"FeatureStoreOutput is not currently supported in Local Mode"
122111
)
123112

124-
if "S3Output" in item and item["S3Output"]:
125-
upload_mode = item["S3Output"]["S3UploadMode"]
126-
else:
127-
raise ValueError("Please specify a valid ['S3Output'] when using Local Mode.")
113+
try:
114+
s3_output = item["S3Output"]
115+
except KeyError:
116+
raise ValueError("Processing output must have a valid ['S3Output']")
128117

129-
if upload_mode != "EndOfJob":
118+
if s3_output["S3UploadMode"] != "EndOfJob":
130119
raise RuntimeError(
131-
"UploadMode: %s is not currently supported in Local Mode." % upload_mode
120+
"UploadMode: %s is not currently supported in Local Mode."
121+
% s3_output["S3UploadMode"]
132122
)
133123

134124
self.start_time = datetime.datetime.now()
@@ -149,7 +139,8 @@ def start(self, processing_inputs, processing_output_config, environment, proces
149139
def describe(self):
150140
"""Describes a local processing job.
151141
152-
Returns: An object describing the processing job.
142+
Returns:
143+
An object describing the processing job.
153144
"""
154145

155146
response = {

src/sagemaker/local/image.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def process(
117117
118118
Args:
119119
processing_inputs (dict): The processing input specification.
120-
processing_output_config: The processing output configuration specification.
120+
processing_output_config (dict): The processing output configuration specification.
121121
environment (dict): The environment collection for the processing job.
122122
processing_job_name (str): Name of the local processing job being run.
123123
"""
@@ -158,15 +158,18 @@ def process(
158158
except RuntimeError as e:
159159
# _stream_output() doesn't have the command line. We will handle the exception
160160
# which contains the exit code and append the command line to it.
161-
msg = "Failed to run: %s, %s" % (compose_command, str(e))
162-
raise RuntimeError(msg)
161+
msg = f"Failed to run: {compose_command}"
162+
raise RuntimeError(msg) from e
163163
finally:
164164
# Uploading processing outputs back to Amazon S3.
165165
self._upload_processing_outputs(data_dir, processing_output_config)
166166

167-
# Deleting temporary directories.
168-
dirs_to_delete = [shared_dir, data_dir]
169-
self._cleanup(dirs_to_delete)
167+
try:
168+
# Deleting temporary directories.
169+
dirs_to_delete = [shared_dir, data_dir]
170+
self._cleanup(dirs_to_delete)
171+
except OSError:
172+
pass
170173

171174
# Print our Job Complete line to have a similar experience to training on SageMaker where
172175
# you see this line at the end.
@@ -910,10 +913,11 @@ def _check_output(cmd, *popenargs, **kwargs):
910913

911914

912915
def _create_processing_config_file_directories(root, host):
913-
"""
916+
"""Creates the directory for the processing config files.
917+
914918
Args:
915-
root:
916-
host:
919+
root: The root path.
920+
host: The current host.
917921
"""
918922
for d in ["config"]:
919923
os.makedirs(os.path.join(root, host, d))

src/sagemaker/processing.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -115,14 +115,10 @@ def __init__(
115115
self.arguments = None
116116

117117
if self.instance_type in ("local", "local_gpu"):
118-
self.sagemaker_session = sagemaker_session or LocalSession()
119-
if not isinstance(self.sagemaker_session, LocalSession):
120-
raise RuntimeError(
121-
"instance_type local or local_gpu is only supported with an"
122-
"instance of LocalSession"
123-
)
124-
else:
125-
self.sagemaker_session = sagemaker_session or Session()
118+
if not isinstance(sagemaker_session, LocalSession):
119+
sagemaker_session = LocalSession()
120+
121+
self.sagemaker_session = sagemaker_session or Session()
126122

127123
def run(
128124
self,

src/sagemaker/sklearn/processing.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from __future__ import absolute_import
1919

2020
from sagemaker import image_uris, Session
21-
from sagemaker.local import LocalSession
2221
from sagemaker.processing import ScriptProcessor
2322
from sagemaker.sklearn import defaults
2423

@@ -84,17 +83,9 @@ def __init__(
8483
if not command:
8584
command = ["python3"]
8685

87-
if instance_type in ("local", "local_gpu"):
88-
session = sagemaker_session or LocalSession()
89-
if not isinstance(session, LocalSession):
90-
raise RuntimeError(
91-
"instance_type local or local_gpu is only supported with an"
92-
"instance of LocalSession"
93-
)
94-
else:
95-
session = sagemaker_session or Session()
96-
86+
session = sagemaker_session or Session()
9787
region = session.boto_region_name
88+
9889
image_uri = image_uris.retrieve(
9990
defaults.SKLEARN_NAME, region, version=framework_version, instance_type=instance_type
10091
)

src/sagemaker/spark/processing.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@
3535
from sagemaker.local.image import _ecr_login_if_needed, _pull_image
3636
from sagemaker.processing import ProcessingInput, ProcessingOutput, ScriptProcessor
3737
from sagemaker.s3 import S3Uploader
38-
from sagemaker.local import LocalSession
3938
from sagemaker.session import Session
4039
from sagemaker.spark import defaults
4140

@@ -145,16 +144,7 @@ def __init__(
145144
self.history_server = None
146145
self._spark_event_logs_s3_uri = None
147146

148-
if instance_type in ("local", "local_gpu"):
149-
session = sagemaker_session or LocalSession()
150-
if not isinstance(session, LocalSession):
151-
raise RuntimeError(
152-
"instance_type local or local_gpu is only supported with an"
153-
"instance of LocalSession"
154-
)
155-
else:
156-
session = sagemaker_session or Session()
157-
147+
session = sagemaker_session or Session()
158148
region = session.boto_region_name
159149

160150
self.image_uri = self._retrieve_image_uri(

tests/integ/test_local_mode.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client,
5959

6060

6161
@pytest.fixture(scope="module")
62-
def image_uri(
62+
def sklearn_image_uri(
6363
sklearn_latest_version,
6464
sklearn_latest_py_version,
6565
cpu_instance_type,
@@ -355,12 +355,12 @@ def test_local_processing_sklearn(sagemaker_local_session, sklearn_latest_versio
355355

356356

357357
@pytest.mark.local_mode
358-
def test_local_processing_script_processor(sagemaker_local_session, image_uri):
358+
def test_local_processing_script_processor(sagemaker_local_session, sklearn_image_uri):
359359
input_file_path = os.path.join(DATA_DIR, "dummy_input.txt")
360360

361361
script_processor = ScriptProcessor(
362362
role="SageMakerRole",
363-
image_uri=image_uri,
363+
image_uri=sklearn_image_uri,
364364
command=["python3"],
365365
instance_count=1,
366366
instance_type="local",
@@ -419,6 +419,6 @@ def test_local_processing_script_processor(sagemaker_local_session, image_uri):
419419
"python3",
420420
"/opt/ml/processing/input/code/dummy_script.py",
421421
]
422-
assert job_description["AppSpecification"]["ImageUri"] == image_uri
422+
assert job_description["AppSpecification"]["ImageUri"] == sklearn_image_uri
423423

424424
assert job_description["Environment"] == {"DUMMY_ENVIRONMENT_VARIABLE": "dummy-value"}

0 commit comments

Comments
 (0)