Skip to content

Commit b2f4cb5

Browse files
author
Verdi March
committed
Re-enable the new SKLearnProcessor
1 parent 39158b8 commit b2f4cb5

File tree

5 files changed

+157
-118
lines changed

5 files changed

+157
-118
lines changed

src/sagemaker/sklearn/processing.py

Lines changed: 45 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -17,93 +17,67 @@
1717
"""
1818
from __future__ import absolute_import
1919

20-
from sagemaker import image_uris, Session
21-
from sagemaker.processing import ScriptProcessor
22-
from sagemaker.sklearn import defaults
20+
from sagemaker.processing import FrameworkProcessor
21+
from sagemaker.sklearn.estimator import SKLearn
2322

2423

25-
class SKLearnProcessor(ScriptProcessor):
26-
"""Handles Amazon SageMaker processing tasks for jobs using scikit-learn."""
24+
class SKLearnProcessor(FrameworkProcessor):
25+
"""Initialize an ``SKLearnProcessor`` instance.
26+
27+
The SKLearnProcessor handles Amazon SageMaker processing tasks for jobs using scikit-learn.
28+
29+
Unless ``image_uri`` is specified, the scikit-learn environment is an
30+
Amazon-built Docker container that executes functions defined in the supplied
31+
``code`` Python script.
32+
33+
The arguments have the exact same meaning as in ``FrameworkProcessor``.
34+
35+
.. tip::
36+
37+
You can find additional parameters for initializing this class at
38+
:class:`~sagemaker.processing.FrameworkProcessor`.
39+
"""
40+
41+
estimator_cls = SKLearn
2742

2843
def __init__(
2944
self,
30-
framework_version,
45+
framework_version, # New arg
3146
role,
32-
instance_type,
3347
instance_count,
48+
instance_type,
49+
py_version="py3", # New kwarg
50+
image_uri=None,
3451
command=None,
3552
volume_size_in_gb=30,
3653
volume_kms_key=None,
3754
output_kms_key=None,
55+
code_location=None, # New arg
3856
max_runtime_in_seconds=None,
3957
base_job_name=None,
4058
sagemaker_session=None,
4159
env=None,
4260
tags=None,
4361
network_config=None,
4462
):
45-
"""Initialize an ``SKLearnProcessor`` instance.
46-
47-
The SKLearnProcessor handles Amazon SageMaker processing tasks for jobs using scikit-learn.
48-
49-
Args:
50-
framework_version (str): The version of scikit-learn.
51-
role (str): An AWS IAM role name or ARN. The Amazon SageMaker training jobs
52-
and APIs that create Amazon SageMaker endpoints use this role
53-
to access training data and model artifacts. After the endpoint
54-
is created, the inference code might use the IAM role, if it
55-
needs to access an AWS resource.
56-
instance_type (str): Type of EC2 instance to use for
57-
processing, for example, 'ml.c4.xlarge'.
58-
instance_count (int): The number of instances to run
59-
the Processing job with. Defaults to 1.
60-
command ([str]): The command to run, along with any command-line flags.
61-
Example: ["python3", "-v"]. If not provided, ["python3"] or ["python2"]
62-
will be chosen based on the py_version parameter.
63-
volume_size_in_gb (int): Size in GB of the EBS volume to
64-
use for storing data during processing (default: 30).
65-
volume_kms_key (str): A KMS key for the processing
66-
volume.
67-
output_kms_key (str): The KMS key id for all ProcessingOutputs.
68-
max_runtime_in_seconds (int): Timeout in seconds.
69-
After this amount of time Amazon SageMaker terminates the job
70-
regardless of its current status.
71-
base_job_name (str): Prefix for processing name. If not specified,
72-
the processor generates a default job name, based on the
73-
training image name and current timestamp.
74-
sagemaker_session (sagemaker.session.Session): Session object which
75-
manages interactions with Amazon SageMaker APIs and any other
76-
AWS services needed. If not specified, the processor creates one
77-
using the default AWS configuration chain.
78-
env (dict): Environment variables to be passed to the processing job.
79-
tags ([dict]): List of tags to be passed to the processing job.
80-
network_config (sagemaker.network.NetworkConfig): A NetworkConfig
81-
object that configures network isolation, encryption of
82-
inter-container traffic, security group IDs, and subnets.
83-
"""
84-
if not command:
85-
command = ["python3"]
86-
87-
session = sagemaker_session or Session()
88-
region = session.boto_region_name
89-
90-
image_uri = image_uris.retrieve(
91-
defaults.SKLEARN_NAME, region, version=framework_version, instance_type=instance_type
92-
)
93-
94-
super(SKLearnProcessor, self).__init__(
95-
role=role,
96-
image_uri=image_uri,
97-
instance_count=instance_count,
98-
instance_type=instance_type,
99-
command=command,
100-
volume_size_in_gb=volume_size_in_gb,
101-
volume_kms_key=volume_kms_key,
102-
output_kms_key=output_kms_key,
103-
max_runtime_in_seconds=max_runtime_in_seconds,
104-
base_job_name=base_job_name,
105-
sagemaker_session=session,
106-
env=env,
107-
tags=tags,
108-
network_config=network_config,
63+
"""This processor executes a Python script in a scikit-learn execution environment."""
64+
super().__init__(
65+
self.estimator_cls,
66+
framework_version,
67+
role,
68+
instance_count,
69+
instance_type,
70+
py_version,
71+
image_uri,
72+
command,
73+
volume_size_in_gb,
74+
volume_kms_key,
75+
output_kms_key,
76+
code_location,
77+
max_runtime_in_seconds,
78+
base_job_name,
79+
sagemaker_session,
80+
env,
81+
tags,
82+
network_config,
10983
)

tests/integ/test_local_mode.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -349,12 +349,12 @@ def test_local_processing_sklearn(sagemaker_local_session_no_local_code, sklearn
349349

350350
job_description = sklearn_processor.latest_job.describe()
351351

352-
assert len(job_description["ProcessingInputs"]) == 2
352+
assert len(job_description["ProcessingInputs"]) == 3
353353
assert job_description["ProcessingResources"]["ClusterConfig"]["InstanceCount"] == 1
354354
assert job_description["ProcessingResources"]["ClusterConfig"]["InstanceType"] == "local"
355355
assert job_description["AppSpecification"]["ContainerEntrypoint"] == [
356356
"python3",
357-
"/opt/ml/processing/input/code/dummy_script.py",
357+
"/opt/ml/processing/input/entrypoint/runproc.py",
358358
]
359359
assert job_description["RoleArn"] == "<no_role>"
360360

tests/integ/test_processing.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def test_sklearn(sagemaker_session, sklearn_latest_version, cpu_instance_type):
139139

140140
job_description = sklearn_processor.latest_job.describe()
141141

142-
assert len(job_description["ProcessingInputs"]) == 2
142+
assert len(job_description["ProcessingInputs"]) == 3
143143
assert job_description["ProcessingResources"]["ClusterConfig"]["InstanceCount"] == 1
144144
assert (
145145
job_description["ProcessingResources"]["ClusterConfig"]["InstanceType"] == cpu_instance_type
@@ -148,7 +148,7 @@ def test_sklearn(sagemaker_session, sklearn_latest_version, cpu_instance_type):
148148
assert job_description["StoppingCondition"] == {"MaxRuntimeInSeconds": 86400}
149149
assert job_description["AppSpecification"]["ContainerEntrypoint"] == [
150150
"python3",
151-
"/opt/ml/processing/input/code/dummy_script.py",
151+
"/opt/ml/processing/input/entrypoint/runproc.py",
152152
]
153153
assert ROLE in job_description["RoleArn"]
154154

@@ -204,6 +204,7 @@ def test_sklearn_with_customizations(
204204
assert job_description["ProcessingInputs"][0]["InputName"] == "dummy_input"
205205

206206
assert job_description["ProcessingInputs"][1]["InputName"] == "code"
207+
assert job_description["ProcessingInputs"][2]["InputName"] == "entrypoint"
207208

208209
assert job_description["ProcessingJobName"].startswith("test-sklearn-with-customizations")
209210

@@ -221,7 +222,7 @@ def test_sklearn_with_customizations(
221222
assert job_description["AppSpecification"]["ContainerArguments"] == ["-v"]
222223
assert job_description["AppSpecification"]["ContainerEntrypoint"] == [
223224
"python3",
224-
"/opt/ml/processing/input/code/dummy_script.py",
225+
"/opt/ml/processing/input/entrypoint/runproc.py",
225226
]
226227
assert job_description["AppSpecification"]["ImageUri"] == image_uri
227228

@@ -288,6 +289,9 @@ def test_sklearn_with_custom_default_bucket(
288289
assert job_description["ProcessingInputs"][0]["InputName"] == "dummy_input"
289290
assert custom_bucket_name in job_description["ProcessingInputs"][0]["S3Input"]["S3Uri"]
290291

292+
assert job_description["ProcessingInputs"][1]["InputName"] == "code"
293+
assert custom_bucket_name in job_description["ProcessingInputs"][1]["S3Input"]["S3Uri"]
294+
291295
assert job_description["ProcessingInputs"][2]["InputName"] == "entrypoint"
292296
assert custom_bucket_name in job_description["ProcessingInputs"][2]["S3Input"]["S3Uri"]
293297

@@ -307,7 +311,7 @@ def test_sklearn_with_custom_default_bucket(
307311
assert job_description["AppSpecification"]["ContainerArguments"] == ["-v"]
308312
assert job_description["AppSpecification"]["ContainerEntrypoint"] == [
309313
"python3",
310-
"/opt/ml/processing/input/code/dummy_script.py",
314+
"/opt/ml/processing/input/entrypoint/runproc.py",
311315
]
312316
assert job_description["AppSpecification"]["ImageUri"] == image_uri
313317

@@ -343,6 +347,7 @@ def test_sklearn_with_no_inputs_or_outputs(
343347
job_description = sklearn_processor.latest_job.describe()
344348

345349
assert job_description["ProcessingInputs"][0]["InputName"] == "code"
350+
assert job_description["ProcessingInputs"][1]["InputName"] == "entrypoint"
346351

347352
assert job_description["ProcessingJobName"].startswith("test-sklearn-with-no-inputs")
348353

@@ -357,7 +362,7 @@ def test_sklearn_with_no_inputs_or_outputs(
357362
assert job_description["AppSpecification"]["ContainerArguments"] == ["-v"]
358363
assert job_description["AppSpecification"]["ContainerEntrypoint"] == [
359364
"python3",
360-
"/opt/ml/processing/input/code/dummy_script.py",
365+
"/opt/ml/processing/input/entrypoint/runproc.py",
361366
]
362367
assert job_description["AppSpecification"]["ImageUri"] == image_uri
363368

tests/integ/test_sklearn.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,20 @@ def sklearn_training_job(
4646
sagemaker_session.boto_region_name
4747

4848

49+
def test_framework_processing_job_with_deps(
50+
sagemaker_session,
51+
sklearn_latest_version,
52+
sklearn_latest_py_version,
53+
cpu_instance_type,
54+
):
55+
return _run_processing_job(
56+
sagemaker_session,
57+
cpu_instance_type,
58+
sklearn_latest_version,
59+
sklearn_latest_py_version,
60+
)
61+
62+
4963
def test_training_with_additional_hyperparameters(
5064
sagemaker_session,
5165
sklearn_latest_version,

0 commit comments

Comments
 (0)