Skip to content

Commit f135d56

Browse files
guoqiao1992 authored and metrizable committed
feature: add spark processor
1 parent 1a9abd4 commit f135d56

File tree

18 files changed

+3963
-6
lines changed

18 files changed

+3963
-6
lines changed
Lines changed: 35 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,35 @@
1+
{
2+
"processing": {
3+
"processors": ["cpu"],
4+
"versions": {
5+
"2.4": {
6+
"py_versions": ["py37"],
7+
"registries": {
8+
"me-south-1": "750251592176",
9+
"ap-south-1": "105495057255",
10+
"eu-west-3": "136845547031",
11+
"us-east-2": "314815235551",
12+
"eu-west-1": "571004829621",
13+
"eu-central-1": "906073651304",
14+
"sa-east-1": "737130764395",
15+
"ap-east-1": "732049463269",
16+
"us-east-1": "173754725891",
17+
"ap-northeast-2": "860869212795",
18+
"eu-west-2": "836651553127",
19+
"ap-northeast-1": "411782140378",
20+
"us-west-2": "153931337802",
21+
"us-west-1": "667973535471",
22+
"ap-southeast-1": "759080221371",
23+
"ap-southeast-2": "440695851116",
24+
"ca-central-1": "446299261295",
25+
"cn-north-1": "671472414489",
26+
"cn-northwest-1": "844356804704",
27+
"eu-south-1": "753923664805",
28+
"af-south-1": "309385258863",
29+
"us-gov-west-1": "271483468897"
30+
},
31+
"repository": "sagemaker-spark-processing"
32+
}
33+
}
34+
}
35+
}

src/sagemaker/image_uris.py

Lines changed: 14 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@
1919
import re
2020

2121
from sagemaker import utils
22+
from sagemaker.spark import defaults
2223

2324
logger = logging.getLogger(__name__)
2425

@@ -33,6 +34,7 @@ def retrieve(
3334
instance_type=None,
3435
accelerator_type=None,
3536
image_scope=None,
37+
container_version=None,
3638
):
3739
"""Retrieves the ECR URI for the Docker image matching the given arguments.
3840
@@ -51,6 +53,7 @@ def retrieve(
5153
image_scope (str): The image type, i.e. what it is used for.
5254
Valid values: "training", "inference", "eia". If ``accelerator_type`` is set,
5355
``image_scope`` is ignored.
56+
container_version (str): the version of docker image
5457
5558
Returns:
5659
str: the ECR URI for the corresponding SageMaker Docker image.
@@ -63,7 +66,7 @@ def retrieve(
6366
version = _validate_version_and_set_if_needed(version, config, framework)
6467
version_config = config["versions"][_version_for_config(version, config)]
6568

66-
py_version = _validate_py_version_and_set_if_needed(py_version, version_config)
69+
py_version = _validate_py_version_and_set_if_needed(py_version, version_config, framework)
6770
version_config = version_config.get(py_version) or version_config
6871

6972
registry = _registry_from_region(region, version_config["registries"])
@@ -74,7 +77,9 @@ def retrieve(
7477
processor = _processor(
7578
instance_type, config.get("processors") or version_config.get("processors")
7679
)
77-
tag = _format_tag(version_config.get("tag_prefix", version), processor, py_version)
80+
tag = _format_tag(
81+
version_config.get("tag_prefix", version), processor, py_version, container_version
82+
)
7883

7984
if tag:
8085
repo += ":{}".format(tag)
@@ -103,7 +108,7 @@ def _config_for_framework_and_scope(framework, image_scope, accelerator_type=Non
103108
available_scopes[0],
104109
image_scope,
105110
)
106-
image_scope = available_scopes[0]
111+
image_scope = list(available_scopes)[0]
107112

108113
if not image_scope and "scope" in config and set(available_scopes) == {"training", "inference"}:
109114
logger.info(
@@ -212,7 +217,7 @@ def _processor(instance_type, available_processors):
212217
return processor
213218

214219

215-
def _validate_py_version_and_set_if_needed(py_version, version_config):
220+
def _validate_py_version_and_set_if_needed(py_version, version_config, framework):
216221
"""Checks if the Python version is one of the supported versions."""
217222
if "repository" in version_config:
218223
available_versions = version_config.get("py_versions")
@@ -224,6 +229,9 @@ def _validate_py_version_and_set_if_needed(py_version, version_config):
224229
logger.info("Ignoring unnecessary Python version: %s.", py_version)
225230
return None
226231

232+
if py_version is None and defaults.SPARK_NAME == framework:
233+
return None
234+
227235
if py_version is None and len(available_versions) == 1:
228236
logger.info("Defaulting to only available Python version: %s", available_versions[0])
229237
return available_versions[0]
@@ -242,6 +250,6 @@ def _validate_arg(arg, available_options, arg_name):
242250
)
243251

244252

245-
def _format_tag(tag_prefix, processor, py_version):
253+
def _format_tag(tag_prefix, processor, py_version, container_version):
246254
"""Creates a tag for the image URI."""
247-
return "-".join([x for x in (tag_prefix, processor, py_version) if x])
255+
return "-".join([x for x in (tag_prefix, processor, py_version, container_version) if x])

src/sagemaker/processing.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -166,6 +166,10 @@ def run(
166166
if wait:
167167
self.latest_job.wait(logs=logs)
168168

169+
def _extend_processing_args(self, inputs, outputs, **kwargs): # pylint: disable=W0613
170+
"""Extend inputs and outputs based on extra parameters"""
171+
return inputs, outputs
172+
169173
def _normalize_args(self, job_name=None, arguments=None, inputs=None, outputs=None, code=None):
170174
"""Normalizes the arguments so that they can be passed to the job run
171175

src/sagemaker/spark/__init__.py

Lines changed: 16 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,16 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Placeholder docstring"""
14+
from __future__ import absolute_import
15+
16+
from sagemaker.spark.processing import PySparkProcessor, SparkJarProcessor # noqa: F401

src/sagemaker/spark/defaults.py

Lines changed: 16 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,16 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Default constants used by Spark processing."""
14+
from __future__ import absolute_import
15+
16+
SPARK_NAME = "spark"

0 commit comments

Comments
 (0)