Skip to content

Commit 6879d31

Browse files
ericangelokim authored and laurenyu committed
Fix SKLearnModel default account in image uri. (#624)
1 parent 7d2e06d commit 6879d31

File tree

4 files changed

+33
-14
lines changed

4 files changed

+33
-14
lines changed

src/sagemaker/sklearn/defaults.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -12,4 +12,6 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15+
SKLEARN_NAME = 'scikit-learn'
16+
1517
SKLEARN_VERSION = '0.20.0'

src/sagemaker/sklearn/estimator.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@
1717
from sagemaker.estimator import Framework
1818
from sagemaker.fw_registry import default_framework_uri
1919
from sagemaker.fw_utils import framework_name_from_image, empty_framework_version_warning
20-
from sagemaker.sklearn.defaults import SKLEARN_VERSION
20+
from sagemaker.sklearn.defaults import SKLEARN_VERSION, SKLEARN_NAME
2121
from sagemaker.sklearn.model import SKLearnModel
2222
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
2323

@@ -28,7 +28,7 @@
2828
class SKLearn(Framework):
2929
"""Handle end-to-end training and deployment of custom Scikit-learn code."""
3030

31-
__framework_name__ = "scikit-learn"
31+
__framework_name__ = SKLEARN_NAME
3232

3333
def __init__(self, entry_point, framework_version=SKLEARN_VERSION, source_dir=None, hyperparameters=None,
3434
py_version='py3', image_name=None, **kwargs):
@@ -74,7 +74,7 @@ def __init__(self, entry_point, framework_version=SKLEARN_VERSION, source_dir=No
7474
train_instance_count = kwargs.get('train_instance_count')
7575
if train_instance_count:
7676
if train_instance_count != 1:
77-
raise AttributeError("SciKit-Learn does not support distributed training. "
77+
raise AttributeError("Scikit-Learn does not support distributed training. "
7878
"Please remove the 'train_instance_count' argument or set "
7979
"'train_instance_count=1' when initializing SKLearn.")
8080
super(SKLearn, self).__init__(entry_point, source_dir, hyperparameters, image_name=image_name,
@@ -154,6 +154,6 @@ def _validate_not_gpu_instance_type(training_instance_type):
154154
'ml.p3.xlarge', 'ml.p3.8xlarge', 'ml.p3.16xlarge']
155155

156156
if training_instance_type in gpu_instance_types:
157-
raise ValueError("GPU training in not supported for SciKit-Learn. "
157+
raise ValueError("GPU training in not supported for Scikit-Learn. "
158158
"Please pick a different instance type from here: "
159159
"https://aws.amazon.com/ec2/instance-types/")

src/sagemaker/sklearn/model.py

Lines changed: 14 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -13,10 +13,11 @@
1313
from __future__ import absolute_import
1414

1515
import sagemaker
16-
from sagemaker.fw_utils import create_image_uri, model_code_key_prefix
16+
from sagemaker.fw_utils import model_code_key_prefix
17+
from sagemaker.fw_registry import default_framework_uri
1718
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
1819
from sagemaker.predictor import RealTimePredictor, npy_serializer, numpy_deserializer
19-
from sagemaker.sklearn.defaults import SKLEARN_VERSION
20+
from sagemaker.sklearn.defaults import SKLEARN_VERSION, SKLEARN_NAME
2021

2122

2223
class SKLearnPredictor(RealTimePredictor):
@@ -40,7 +41,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
4041
class SKLearnModel(FrameworkModel):
4142
"""An Scikit-learn SageMaker ``Model`` that can be deployed to a SageMaker ``Endpoint``."""
4243

43-
__framework_name__ = 'scikit-learn'
44+
__framework_name__ = SKLEARN_NAME
4445

4546
def __init__(self, model_data, role, entry_point, image=None, py_version='py3', framework_version=SKLEARN_VERSION,
4647
predictor_cls=SKLearnPredictor, model_server_workers=None, **kwargs):
@@ -77,16 +78,22 @@ def prepare_container_def(self, instance_type, accelerator_type=None):
7778
Args:
7879
instance_type (str): The EC2 instance type to deploy this Model to. For example, 'ml.p2.xlarge'.
7980
accelerator_type (str): The Elastic Inference accelerator type to deploy to the instance for loading and
80-
making inferences to the model. For example, 'ml.eia1.medium'.
81+
making inferences to the model. For example, 'ml.eia1.medium'. Note: accelerator types are not
82+
supported by SKLearnModel.
8183
8284
Returns:
8385
dict[str, str]: A container definition object usable with the CreateModel API.
8486
"""
87+
if accelerator_type:
88+
raise ValueError("Accelerator types are not supported for Scikit-Learn.")
89+
8590
deploy_image = self.image
8691
if not deploy_image:
87-
region_name = self.sagemaker_session.boto_session.region_name
88-
deploy_image = create_image_uri(region_name, self.__framework_name__, instance_type,
89-
self.framework_version, self.py_version, accelerator_type=accelerator_type)
92+
image_tag = "{}-{}-{}".format(self.framework_version, "cpu", self.py_version)
93+
deploy_image = default_framework_uri(
94+
self.__framework_name__,
95+
self.sagemaker_session.boto_region_name,
96+
image_tag)
9097

9198
deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image)
9299
self._upload_code(deploy_key_prefix)

tests/unit/test_sklearn.py

Lines changed: 13 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -128,7 +128,17 @@ def test_train_image(sagemaker_session, sklearn_version):
128128
assert train_image == '246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-scikit-learn:0.20.0-cpu-py3'
129129

130130

131-
def test_create_model(sagemaker_session, sklearn_version):
131+
def test_create_model(sagemaker_session):
132+
source_dir = 's3://mybucket/source'
133+
134+
sklearn_model = SKLearnModel(model_data=source_dir, role=ROLE, sagemaker_session=sagemaker_session,
135+
entry_point=SCRIPT_PATH)
136+
default_image_uri = _get_full_cpu_image_uri('0.20.0')
137+
model_values = sklearn_model.prepare_container_def(CPU)
138+
assert model_values['Image'] == default_image_uri
139+
140+
141+
def test_create_model_from_estimator(sagemaker_session, sklearn_version):
132142
container_log_level = '"logging.INFO"'
133143
source_dir = 's3://mybucket/source'
134144
sklearn = SKLearn(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
@@ -231,15 +241,15 @@ def test_fail_distributed_training(sagemaker_session, sklearn_version):
231241
SKLearn(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
232242
train_instance_count=DIST_INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE,
233243
py_version=PYTHON_VERSION, framework_version=sklearn_version)
234-
assert "SciKit-Learn does not support distributed training." in str(error)
244+
assert "Scikit-Learn does not support distributed training." in str(error)
235245

236246

237247
def test_fail_GPU_training(sagemaker_session, sklearn_version):
238248
with pytest.raises(ValueError) as error:
239249
SKLearn(entry_point=SCRIPT_PATH, role=ROLE, sagemaker_session=sagemaker_session,
240250
train_instance_type=GPU_INSTANCE_TYPE, py_version=PYTHON_VERSION,
241251
framework_version=sklearn_version)
242-
assert "GPU training in not supported for SciKit-Learn." in str(error)
252+
assert "GPU training in not supported for Scikit-Learn." in str(error)
243253

244254

245255
def test_model(sagemaker_session):

0 commit comments

Comments
 (0)