|
17 | 17 | """
|
18 | 18 | from __future__ import absolute_import
|
19 | 19 |
|
20 |
| -from sagemaker import image_uris, Session |
21 |
| -from sagemaker.processing import ScriptProcessor |
22 |
| -from sagemaker.sklearn import defaults |
| 20 | +from sagemaker.processing import FrameworkProcessor |
| 21 | +from sagemaker.sklearn.estimator import SKLearn |
23 | 22 |
|
24 | 23 |
|
25 |
| -class SKLearnProcessor(ScriptProcessor): |
26 |
| - """Handles Amazon SageMaker processing tasks for jobs using scikit-learn.""" |
| 24 | +class SKLearnProcessor(FrameworkProcessor): |
| 25 | + """Initialize an ``SKLearnProcessor`` instance. |
| 26 | +
|
| 27 | + The SKLearnProcessor handles Amazon SageMaker processing tasks for jobs using scikit-learn. |
| 28 | +
|
| 29 | + Unless ``image_uri`` is specified, the scikit-learn environment is an |
| 30 | + Amazon-built Docker container that executes functions defined in the supplied |
| 31 | + ``code`` Python script. |
| 32 | +
|
| 33 | + The arguments have the exact same meaning as in ``FrameworkProcessor``. |
| 34 | +
|
| 35 | + .. tip:: |
| 36 | +
|
| 37 | + You can find additional parameters for initializing this class at |
| 38 | + :class:`~sagemaker.processing.FrameworkProcessor`. |
| 39 | + """ |
| 40 | + |
| 41 | + estimator_cls = SKLearn |
27 | 42 |
|
28 | 43 | def __init__(
|
29 | 44 | self,
|
30 |
| - framework_version, |
| 45 | + framework_version, # New arg |
31 | 46 | role,
|
32 |
| - instance_type, |
33 | 47 | instance_count,
|
| 48 | + instance_type, |
| 49 | + py_version="py3", # New kwarg |
| 50 | + image_uri=None, |
34 | 51 | command=None,
|
35 | 52 | volume_size_in_gb=30,
|
36 | 53 | volume_kms_key=None,
|
37 | 54 | output_kms_key=None,
|
| 55 | + code_location=None, # New arg |
38 | 56 | max_runtime_in_seconds=None,
|
39 | 57 | base_job_name=None,
|
40 | 58 | sagemaker_session=None,
|
41 | 59 | env=None,
|
42 | 60 | tags=None,
|
43 | 61 | network_config=None,
|
44 | 62 | ):
|
45 |
| - """Initialize an ``SKLearnProcessor`` instance. |
46 |
| -
|
47 |
| - The SKLearnProcessor handles Amazon SageMaker processing tasks for jobs using scikit-learn. |
48 |
| -
|
49 |
| - Args: |
50 |
| - framework_version (str): The version of scikit-learn. |
51 |
| - role (str): An AWS IAM role name or ARN. The Amazon SageMaker training jobs |
52 |
| - and APIs that create Amazon SageMaker endpoints use this role |
53 |
| - to access training data and model artifacts. After the endpoint |
54 |
| - is created, the inference code might use the IAM role, if it |
55 |
| - needs to access an AWS resource. |
56 |
| - instance_type (str): Type of EC2 instance to use for |
57 |
| - processing, for example, 'ml.c4.xlarge'. |
58 |
| - instance_count (int): The number of instances to run |
59 |
| - the Processing job with. Defaults to 1. |
60 |
| - command ([str]): The command to run, along with any command-line flags. |
61 |
| - Example: ["python3", "-v"]. If not provided, ["python3"] or ["python2"] |
62 |
| - will be chosen based on the py_version parameter. |
63 |
| - volume_size_in_gb (int): Size in GB of the EBS volume to |
64 |
| - use for storing data during processing (default: 30). |
65 |
| - volume_kms_key (str): A KMS key for the processing |
66 |
| - volume. |
67 |
| - output_kms_key (str): The KMS key id for all ProcessingOutputs. |
68 |
| - max_runtime_in_seconds (int): Timeout in seconds. |
69 |
| - After this amount of time Amazon SageMaker terminates the job |
70 |
| - regardless of its current status. |
71 |
| - base_job_name (str): Prefix for processing name. If not specified, |
72 |
| - the processor generates a default job name, based on the |
73 |
| - training image name and current timestamp. |
74 |
| - sagemaker_session (sagemaker.session.Session): Session object which |
75 |
| - manages interactions with Amazon SageMaker APIs and any other |
76 |
| - AWS services needed. If not specified, the processor creates one |
77 |
| - using the default AWS configuration chain. |
78 |
| - env (dict): Environment variables to be passed to the processing job. |
79 |
| - tags ([dict]): List of tags to be passed to the processing job. |
80 |
| - network_config (sagemaker.network.NetworkConfig): A NetworkConfig |
81 |
| - object that configures network isolation, encryption of |
82 |
| - inter-container traffic, security group IDs, and subnets. |
83 |
| - """ |
84 |
| - if not command: |
85 |
| - command = ["python3"] |
86 |
| - |
87 |
| - session = sagemaker_session or Session() |
88 |
| - region = session.boto_region_name |
89 |
| - |
90 |
| - image_uri = image_uris.retrieve( |
91 |
| - defaults.SKLEARN_NAME, region, version=framework_version, instance_type=instance_type |
92 |
| - ) |
93 |
| - |
94 |
| - super(SKLearnProcessor, self).__init__( |
95 |
| - role=role, |
96 |
| - image_uri=image_uri, |
97 |
| - instance_count=instance_count, |
98 |
| - instance_type=instance_type, |
99 |
| - command=command, |
100 |
| - volume_size_in_gb=volume_size_in_gb, |
101 |
| - volume_kms_key=volume_kms_key, |
102 |
| - output_kms_key=output_kms_key, |
103 |
| - max_runtime_in_seconds=max_runtime_in_seconds, |
104 |
| - base_job_name=base_job_name, |
105 |
| - sagemaker_session=session, |
106 |
| - env=env, |
107 |
| - tags=tags, |
108 |
| - network_config=network_config, |
| 63 | + """This processor executes a Python script in a scikit-learn execution environment.""" |
| 64 | + super().__init__( |
| 65 | + self.estimator_cls, |
| 66 | + framework_version, |
| 67 | + role, |
| 68 | + instance_count, |
| 69 | + instance_type, |
| 70 | + py_version, |
| 71 | + image_uri, |
| 72 | + command, |
| 73 | + volume_size_in_gb, |
| 74 | + volume_kms_key, |
| 75 | + output_kms_key, |
| 76 | + code_location, |
| 77 | + max_runtime_in_seconds, |
| 78 | + base_job_name, |
| 79 | + sagemaker_session, |
| 80 | + env, |
| 81 | + tags, |
| 82 | + network_config, |
109 | 83 | )
|
0 commit comments