Commit 14340ae

Author: EC2 Default User (committed)

Improved docstrings, handled DatasetDefinition and FeatureStoreOutput, other minor changes.

1 parent a846b32 · commit 14340ae

File tree: 4 files changed, +137 −25 lines


src/sagemaker/local/entities.py

Lines changed: 21 additions & 8 deletions
@@ -38,16 +38,16 @@

 HEALTH_CHECK_TIMEOUT_LIMIT = 120


-class _LocalProcessingJob(object):
+class _LocalProcessingJob:
     """Defines and starts a local processing job."""

     _STARTING = "Starting"
     _PROCESSING = "Processing"
     _COMPLETED = "Completed"
-    _states = ["Starting", "Processing", "Completed"]

     def __init__(self, container):
-        """
+        """Creates a local processing job.
+
         Args:
             container: the local container object.
         """

@@ -61,16 +61,21 @@ def __init__(self, container):

         self.environment = None

     def start(self, processing_inputs, processing_output_config, environment, processing_job_name):
-        """
+        """Starts a local processing job.
+
         Args:
             processing_inputs: The processing input configuration.
             processing_output_config: The processing output configuration.
             environment: The collection of environment variables passed to the job.
             processing_job_name: The processing job name.
         """
+        self.state = self._STARTING

         for item in processing_inputs:
-            if item["S3Input"]:
+            if "DatasetDefinition" in item:
+                raise RuntimeError("DatasetDefinition is not currently supported in Local Mode")
+
+            if "S3Input" in item and item["S3Input"]:
                 data_uri = item["S3Input"]["S3Uri"]
             else:
                 raise ValueError("Processing input must have a valid ['S3Input']")

@@ -111,10 +116,15 @@ def start(self, processing_inputs, processing_output_config, environment, processing_job_name):

         processing_outputs = processing_output_config["Outputs"]

         for item in processing_outputs:
-            if item["S3Output"]:
+            if "FeatureStoreOutput" in item:
+                raise RuntimeError(
+                    "FeatureStoreOutput is not currently supported in Local Mode"
+                )
+
+            if "S3Output" in item and item["S3Output"]:
                 upload_mode = item["S3Output"]["S3UploadMode"]
             else:
-                raise ValueError("Processing output must have a valid ['S3Output']")
+                raise ValueError("Please specify a valid ['S3Output'] when using Local Mode.")

             if upload_mode != "EndOfJob":
                 raise RuntimeError(

@@ -137,7 +147,10 @@ def start(self, processing_inputs, processing_output_config, environment, processing_job_name):

         self.state = self._COMPLETED

     def describe(self):
-        """Describes a local processing job."""
+        """Describes a local processing job.
+
+        Returns: An object describing the processing job.
+        """

         response = {
             "ProcessingJobArn": self.processing_job_name,

src/sagemaker/local/image.py

Lines changed: 13 additions & 15 deletions
@@ -120,8 +120,6 @@ def process(

             processing_output_config: The processing output configuration specification.
             environment (dict): The environment collection for the processing job.
             processing_job_name (str): Name of the local processing job being run.
-
-        Returns:
         """

         self.container_root = self._create_tmp_folder()

@@ -383,8 +381,6 @@ def write_processing_config_files(

         This method writes the hyperparameters, resources and input data
         configuration files.

-        Returns: None
-
         Args:
             host (str): Host to write the configuration for
             environment (dict): Environment variable collection.

@@ -509,11 +505,15 @@ def _prepare_training_volumes(

         return volumes

     def _prepare_processing_volumes(self, data_dir, processing_inputs, processing_output_config):
-        """
+        """Prepares local container volumes for the processing job.
+
         Args:
-            data_dir:
-            processing_inputs:
-            processing_output_config:
+            data_dir: The local data directory.
+            processing_inputs: The configuration of processing inputs.
+            processing_output_config: The configuration of processing outputs.
+
+        Returns:
+            The volumes configuration.
         """
         shared_dir = os.path.join(self.container_root, "shared")
         volumes = []

@@ -524,9 +524,6 @@ def _prepare_processing_volumes(self, data_dir, processing_inputs, processing_output_config):

             uri = item["DataUri"]
             input_container_dir = item["S3Input"]["LocalPath"]

-            # input_dir = os.path.join(data_dir, "input", input_name)
-            # os.makedirs(input_dir)
-
             data_source = sagemaker.local.data.get_data_source_instance(uri, self.sagemaker_session)
             volumes.append(_Volume(data_source.get_root_dir(), input_container_dir))

@@ -545,10 +542,11 @@ def _prepare_processing_volumes(self, data_dir, processing_inputs, processing_output_config):

         return volumes

     def _upload_processing_outputs(self, data_dir, processing_output_config):
-        """
+        """Uploads processing outputs to Amazon S3.
+
         Args:
-            data_dir:
-            processing_output_config:
+            data_dir: The local data directory.
+            processing_output_config: The processing output configuration.
         """
         if processing_output_config and "Outputs" in processing_output_config:
             for item in processing_output_config["Outputs"]:

@@ -711,7 +709,7 @@ def _create_docker_host(self, host, environment, optml_subdirs, command, volumes):

         host_config = {
             "image": self.image,
-            "container_name": container_name_prefix + "-" + host,
+            "container_name": f"{container_name_prefix}-{host}",
             "stdin_open": True,
             "tty": True,
             "volumes": [v.map for v in optml_volumes],

src/sagemaker/local/local_session.py

Lines changed: 2 additions & 2 deletions
@@ -75,7 +75,7 @@ def create_processing_job(

         ProcessingOutputConfig=None,
         **kwargs
     ):
-        """Create a processing job in Local Mode
+        """Creates a processing job in Local Mode

         Args:
             ProcessingJobName(str): local processing job name.

@@ -128,7 +128,7 @@ def create_processing_job(

         LocalSagemakerClient._processing_jobs[ProcessingJobName] = processing_job

     def describe_processing_job(self, ProcessingJobName):
-        """Describe a local processing job.
+        """Describes a local processing job.

         Args:
             ProcessingJobName(str): Processing job name to describe.
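
Together, these two methods cover the create/describe round trip. A rough usage sketch, reusing the configuration dicts shown in the tests below and assuming a working local Docker setup:

import sagemaker.local.local_session

client = sagemaker.local.local_session.LocalSagemakerClient()
client.create_processing_job(
    "my-processing-job",
    app_spec,                   # {"ImageUri": ...}, as in the tests below
    resource_config,
    environment,
    processing_inputs,
    processing_output_config,
)
# describe_processing_job returns a dict that includes "ProcessingJobArn".
response = client.describe_processing_job("my-processing-job")
print(response["ProcessingJobArn"])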

tests/unit/test_local_session.py

Lines changed: 101 additions & 0 deletions
@@ -244,6 +244,107 @@ def test_create_processing_job_invalid_upload_mode(process, LocalSession):
     )


+@patch("sagemaker.local.image._SageMakerContainer.process")
+@patch("sagemaker.local.local_session.LocalSession")
+def test_create_processing_job_invalid_processing_input(process, LocalSession):
+    local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient()
+
+    instance_count = 2
+    image = "my-docker-image:1.0"
+
+    app_spec = {"ImageUri": image}
+    resource_config = {"ClusterConfig": {"InstanceCount": instance_count, "InstanceType": "local"}}
+    environment = {"Var1": "Value1"}
+    processing_inputs = [
+        {
+            "InputName": "input1",
+            "DatasetDefinition": {
+                "AthenaDatasetDefinition": {
+                    "Catalog": "cat1",
+                    "Database": "db1",
+                    "OutputS3Uri": "s3://bucket_name/prefix/",
+                    "QueryString": "SELECT * FROM SOMETHING",
+                },
+                "DataDistributionType": "FullyReplicated",
+                "InputMode": "File",
+                "LocalPath": "/opt/ml/processing/input/athena",
+            },
+        }
+    ]
+    processing_output_config = {
+        "Outputs": [
+            {
+                "OutputName": "output1",
+                "S3Output": {
+                    "LocalPath": "/opt/ml/processing/output/output1",
+                    "S3Uri": "s3://some-bucket/some-path/output1",
+                    "S3UploadMode": "Continuous",
+                },
+            }
+        ]
+    }
+    with pytest.raises(RuntimeError):
+        local_sagemaker_client.create_processing_job(
+            "my-processing-job",
+            app_spec,
+            resource_config,
+            environment,
+            processing_inputs,
+            processing_output_config,
+        )
+
+
+@patch("sagemaker.local.image._SageMakerContainer.process")
+@patch("sagemaker.local.local_session.LocalSession")
+def test_create_processing_job_invalid_processing_output(process, LocalSession):
+    local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient()
+
+    instance_count = 2
+    image = "my-docker-image:1.0"
+
+    app_spec = {"ImageUri": image}
+    resource_config = {"ClusterConfig": {"InstanceCount": instance_count, "InstanceType": "local"}}
+    environment = {"Var1": "Value1"}
+    processing_inputs = [
+        {
+            "InputName": "input1",
+            "S3Input": {
+                "LocalPath": "/opt/ml/processing/input/input1",
+                "S3Uri": "s3://some-bucket/some-path/input1",
+                "S3DataDistributionType": "FullyReplicated",
+                "S3InputMode": "File",
+            },
+        },
+        {
+            "InputName": "input2",
+            "S3Input": {
+                "LocalPath": "/opt/ml/processing/input/input2",
+                "S3Uri": "s3://some-bucket/some-path/input2",
+                "S3DataDistributionType": "FullyReplicated",
+                "S3CompressionType": "None",
+                "S3InputMode": "File",
+            },
+        },
+    ]
+    processing_output_config = {
+        "Outputs": [
+            {
+                "OutputName": "output1",
+                "FeatureStoreOutput": {"FeatureGroupName": "Group1"},
+            }
+        ]
+    }
+    with pytest.raises(RuntimeError):
+        local_sagemaker_client.create_processing_job(
+            "my-processing-job",
+            app_spec,
+            resource_config,
+            environment,
+            processing_inputs,
+            processing_output_config,
+        )
+
+
 @patch("sagemaker.local.local_session.LocalSession")
 def test_describe_invalid_processing_job(*args):
     local_sagemaker_client = sagemaker.local.local_session.LocalSagemakerClient()
