Skip to content

Commit 575292a

Browse files
committed
Merge branch 'zwei' into require-framework-version-cleanup
2 parents c7bd228 + 66178b4 commit 575292a

File tree

6 files changed

+76
-39
lines changed

6 files changed

+76
-39
lines changed

doc/frameworks/chainer/using_chainer.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ directories ('train' and 'test').
141141
train_instance_type='ml.p3.2xlarge',
142142
train_instance_count=1,
143143
framework_version='5.0.0',
144+
py_version='py3',
144145
hyperparameters = {'epochs': 20, 'batch-size': 64, 'learning-rate': 0.1})
145146
chainer_estimator.fit({'train': 's3://my-data-bucket/path/to/my/training/data',
146147
'test': 's3://my-data-bucket/path/to/my/test/data'})
@@ -222,7 +223,8 @@ operation.
222223
chainer_estimator = Chainer(entry_point='train_and_deploy.py',
223224
train_instance_type='ml.p3.2xlarge',
224225
train_instance_count=1,
225-
framework_version='5.0.0')
226+
framework_version='5.0.0',
227+
py_version='py3')
226228
chainer_estimator.fit('s3://my_bucket/my_training_data/')
227229
228230
# Deploy my estimator to a SageMaker Endpoint and get a Predictor

src/sagemaker/chainer/estimator.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
framework_name_from_image,
2121
framework_version_from_tag,
2222
python_deprecation_warning,
23+
validate_version_or_image_args,
2324
)
2425
from sagemaker.chainer import defaults
2526
from sagemaker.chainer.model import ChainerModel
@@ -48,8 +49,8 @@ def __init__(
4849
additional_mpi_options=None,
4950
source_dir=None,
5051
hyperparameters=None,
51-
py_version="py3",
5252
framework_version=None,
53+
py_version=None,
5354
image_name=None,
5455
**kwargs
5556
):
@@ -100,11 +101,12 @@ def __init__(
100101
and values, but ``str()`` will be called to convert them before
101102
training.
102103
py_version (str): Python version you want to use for executing your
103-
model training code (default: 'py2'). One of 'py2' or 'py3'.
104+
model training code. Defaults to ``None``. Required unless ``image_name``
105+
is provided.
104106
framework_version (str): Chainer version you want to use for
105-
executing your model training code. List of supported versions
107+
executing your model training code. Defaults to ``None``. Required unless
108+
``image_name`` is provided. List of supported versions:
106109
https://github.com/aws/sagemaker-python-sdk#chainer-sagemaker-estimators.
107-
If not specified, this will default to 4.1.
108110
image_name (str): If specified, the estimator will use this image
109111
for training and hosting, instead of selecting the appropriate
110112
SageMaker official image based on framework_version and
@@ -114,6 +116,9 @@ def __init__(
114116
* ``123412341234.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0``
115117
* ``custom-image:latest``
116118
119+
If ``framework_version`` or ``py_version`` are ``None``, then
120+
``image_name`` is required. If also ``None``, then a ``ValueError``
121+
will be raised.
117122
**kwargs: Additional kwargs passed to the
118123
:class:`~sagemaker.estimator.Framework` constructor.
119124
@@ -123,18 +128,18 @@ def __init__(
123128
:class:`~sagemaker.estimator.Framework` and
124129
:class:`~sagemaker.estimator.EstimatorBase`.
125130
"""
131+
validate_version_or_image_args(framework_version, py_version, image_name)
132+
if py_version == "py2":
133+
logger.warning(
134+
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
135+
)
126136
self.framework_version = framework_version
137+
self.py_version = py_version
127138

128139
super(Chainer, self).__init__(
129140
entry_point, source_dir, hyperparameters, image_name=image_name, **kwargs
130141
)
131142

132-
if py_version == "py2":
133-
logger.warning(
134-
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
135-
)
136-
137-
self.py_version = py_version
138143
self.use_mpi = use_mpi
139144
self.num_processes = num_processes
140145
self.process_slots_per_host = process_slots_per_host
@@ -255,15 +260,19 @@ class constructor
255260
image_name = init_params.pop("image")
256261
framework, py_version, tag, _ = framework_name_from_image(image_name)
257262

263+
if tag is None:
264+
framework_version = None
265+
else:
266+
framework_version = framework_version_from_tag(tag)
267+
init_params["framework_version"] = framework_version
268+
init_params["py_version"] = py_version
269+
258270
if not framework:
259271
# If we were unable to parse the framework name from the image it is not one of our
260272
# officially supported images, in this case just add the image to the init params.
261273
init_params["image_name"] = image_name
262274
return init_params
263275

264-
init_params["py_version"] = py_version
265-
init_params["framework_version"] = framework_version_from_tag(tag)
266-
267276
training_job_name = init_params["base_job_name"]
268277

269278
if framework != cls.__framework_name__:

src/sagemaker/chainer/model.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,12 @@
1818
from sagemaker import fw_utils
1919

2020
import sagemaker
21-
from sagemaker.fw_utils import create_image_uri, model_code_key_prefix, python_deprecation_warning
21+
from sagemaker.fw_utils import (
22+
create_image_uri,
23+
model_code_key_prefix,
24+
python_deprecation_warning,
25+
validate_version_or_image_args,
26+
)
2227
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
2328
from sagemaker.chainer import defaults
2429
from sagemaker.predictor import RealTimePredictor, npy_serializer, numpy_deserializer
@@ -62,8 +67,8 @@ def __init__(
6267
role,
6368
entry_point,
6469
image=None,
65-
py_version="py3",
6670
framework_version=None,
71+
py_version=None,
6772
predictor_cls=ChainerPredictor,
6873
model_server_workers=None,
6974
**kwargs
@@ -83,11 +88,15 @@ def __init__(
8388
hosting. If ``source_dir`` is specified, then ``entry_point``
8489
must point to a file located at the root of ``source_dir``.
8590
image (str): A Docker image URI (default: None). If not specified, a
86-
default image for Chainer will be used.
87-
py_version (str): Python version you want to use for executing your
88-
model training code (default: 'py2').
91+
default image for Chainer will be used. If ``framework_version``
92+
or ``py_version`` are ``None``, then ``image`` is required. If
93+
also ``None``, then a ``ValueError`` will be raised.
8994
framework_version (str): Chainer version you want to use for
90-
executing your model training code.
95+
executing your model training code. Defaults to ``None``. Required
96+
unless ``image`` is provided.
97+
py_version (str): Python version you want to use for executing your
98+
model training code. Defaults to ``None``. Required unless
99+
``image`` is provided.
91100
predictor_cls (callable[str, sagemaker.session.Session]): A function
92101
to call to create a predictor with an endpoint name and
93102
SageMaker ``Session``. If specified, ``deploy()`` returns the
@@ -104,16 +113,18 @@ def __init__(
104113
:class:`~sagemaker.model.FrameworkModel` and
105114
:class:`~sagemaker.model.Model`.
106115
"""
107-
super(ChainerModel, self).__init__(
108-
model_data, image, role, entry_point, predictor_cls=predictor_cls, **kwargs
109-
)
116+
validate_version_or_image_args(framework_version, py_version, image)
110117
if py_version == "py2":
111118
logger.warning(
112119
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
113120
)
114-
115-
self.py_version = py_version
116121
self.framework_version = framework_version
122+
self.py_version = py_version
123+
124+
super(ChainerModel, self).__init__(
125+
model_data, image, role, entry_point, predictor_cls=predictor_cls, **kwargs
126+
)
127+
117128
self.model_server_workers = model_server_workers
118129

119130
def prepare_container_def(self, instance_type, accelerator_type=None):

tests/integ/test_chainer_train.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,14 +100,16 @@ def test_attach_deploy(sagemaker_session, chainer_full_version, cpu_instance_typ
100100

101101

102102
@pytest.mark.local_mode
103-
def test_deploy_model(chainer_local_training_job, sagemaker_local_session):
103+
def test_deploy_model(chainer_local_training_job, sagemaker_local_session, chainer_full_version):
104104
script_path = os.path.join(DATA_DIR, "chainer_mnist", "mnist.py")
105105

106106
model = ChainerModel(
107107
chainer_local_training_job.model_data,
108108
"SageMakerRole",
109109
entry_point=script_path,
110110
sagemaker_session=sagemaker_local_session,
111+
framework_version=chainer_full_version,
112+
py_version=PYTHON_VERSION,
111113
)
112114

113115
predictor = model.deploy(1, "local")

tests/integ/test_tuner.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -689,14 +689,15 @@ def test_tuning_tf_vpc_multi(sagemaker_session, cpu_instance_type):
689689

690690

691691
@pytest.mark.canary_quick
692-
def test_tuning_chainer(sagemaker_session, cpu_instance_type):
692+
def test_tuning_chainer(sagemaker_session, chainer_full_version, cpu_instance_type):
693693
with timeout(minutes=TUNING_DEFAULT_TIMEOUT_MINUTES):
694694
script_path = os.path.join(DATA_DIR, "chainer_mnist", "mnist.py")
695695
data_path = os.path.join(DATA_DIR, "chainer_mnist")
696696

697697
estimator = Chainer(
698698
entry_point=script_path,
699699
role="SageMakerRole",
700+
framework_version=chainer_full_version,
700701
py_version=PYTHON_VERSION,
701702
train_instance_count=1,
702703
train_instance_type=cpu_instance_type,

tests/unit/test_chainer.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ def test_create_model(sagemaker_session, chainer_version):
301301
assert model.vpc_config is None
302302

303303

304-
def test_create_model_with_optional_params(sagemaker_session):
304+
def test_create_model_with_optional_params(sagemaker_session, chainer_version):
305305
container_log_level = '"logging.INFO"'
306306
source_dir = "s3://mybucket/source"
307307
enable_cloudwatch_metrics = "true"
@@ -312,6 +312,7 @@ def test_create_model_with_optional_params(sagemaker_session):
312312
train_instance_count=INSTANCE_COUNT,
313313
train_instance_type=INSTANCE_TYPE,
314314
container_log_level=container_log_level,
315+
framework_version=chainer_version,
315316
py_version=PYTHON_VERSION,
316317
base_job_name="job",
317318
source_dir=source_dir,
@@ -373,8 +374,8 @@ def test_chainer(strftime, sagemaker_session, chainer_version):
373374
sagemaker_session=sagemaker_session,
374375
train_instance_count=INSTANCE_COUNT,
375376
train_instance_type=INSTANCE_TYPE,
376-
py_version=PYTHON_VERSION,
377377
framework_version=chainer_version,
378+
py_version=PYTHON_VERSION,
378379
)
379380

380381
inputs = "s3://mybucket/train"
@@ -415,21 +416,28 @@ def test_chainer(strftime, sagemaker_session, chainer_version):
415416

416417

417418
@patch("sagemaker.utils.create_tar_file", MagicMock())
418-
def test_model(sagemaker_session):
419+
def test_model(sagemaker_session, chainer_version):
419420
model = ChainerModel(
420421
"s3://some/data.tar.gz",
421422
role=ROLE,
422423
entry_point=SCRIPT_PATH,
423424
sagemaker_session=sagemaker_session,
425+
framework_version=chainer_version,
426+
py_version=PYTHON_VERSION,
424427
)
425428
predictor = model.deploy(1, GPU)
426429
assert isinstance(predictor, ChainerPredictor)
427430

428431

429432
@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock())
430-
def test_model_prepare_container_def_accelerator_error(sagemaker_session):
433+
def test_model_prepare_container_def_accelerator_error(sagemaker_session, chainer_version):
431434
model = ChainerModel(
432-
MODEL_DATA, role=ROLE, entry_point=SCRIPT_PATH, sagemaker_session=sagemaker_session
435+
MODEL_DATA,
436+
role=ROLE,
437+
entry_point=SCRIPT_PATH,
438+
sagemaker_session=sagemaker_session,
439+
framework_version=chainer_version,
440+
py_version=PYTHON_VERSION,
433441
)
434442
with pytest.raises(ValueError):
435443
model.prepare_container_def(INSTANCE_TYPE, accelerator_type=ACCELERATOR_TYPE)
@@ -451,27 +459,29 @@ def test_train_image_default(sagemaker_session, chainer_version):
451459

452460
def test_train_image_cpu_instances(sagemaker_session, chainer_version):
453461
chainer = _chainer_estimator(
454-
sagemaker_session, chainer_version, train_instance_type="ml.c2.2xlarge"
462+
sagemaker_session, framework_version=chainer_version, train_instance_type="ml.c2.2xlarge"
455463
)
456464
assert chainer.train_image() == _get_full_cpu_image_uri(chainer_version)
457465

458466
chainer = _chainer_estimator(
459-
sagemaker_session, chainer_version, train_instance_type="ml.c4.2xlarge"
467+
sagemaker_session, framework_version=chainer_version, train_instance_type="ml.c4.2xlarge"
460468
)
461469
assert chainer.train_image() == _get_full_cpu_image_uri(chainer_version)
462470

463-
chainer = _chainer_estimator(sagemaker_session, chainer_version, train_instance_type="ml.m16")
471+
chainer = _chainer_estimator(
472+
sagemaker_session, framework_version=chainer_version, train_instance_type="ml.m16"
473+
)
464474
assert chainer.train_image() == _get_full_cpu_image_uri(chainer_version)
465475

466476

467477
def test_train_image_gpu_instances(sagemaker_session, chainer_version):
468478
chainer = _chainer_estimator(
469-
sagemaker_session, chainer_version, train_instance_type="ml.g2.2xlarge"
479+
sagemaker_session, framework_version=chainer_version, train_instance_type="ml.g2.2xlarge"
470480
)
471481
assert chainer.train_image() == _get_full_gpu_image_uri(chainer_version)
472482

473483
chainer = _chainer_estimator(
474-
sagemaker_session, chainer_version, train_instance_type="ml.p2.2xlarge"
484+
sagemaker_session, framework_version=chainer_version, train_instance_type="ml.p2.2xlarge"
475485
)
476486
assert chainer.train_image() == _get_full_gpu_image_uri(chainer_version)
477487

@@ -599,13 +609,14 @@ def test_attach_custom_image(sagemaker_session):
599609

600610

601611
@patch("sagemaker.chainer.estimator.python_deprecation_warning")
602-
def test_estimator_py2_warning(warning, sagemaker_session):
612+
def test_estimator_py2_warning(warning, sagemaker_session, chainer_version):
603613
estimator = Chainer(
604614
entry_point=SCRIPT_PATH,
605615
role=ROLE,
606616
sagemaker_session=sagemaker_session,
607617
train_instance_count=INSTANCE_COUNT,
608618
train_instance_type=INSTANCE_TYPE,
619+
framework_version=chainer_version,
609620
py_version="py2",
610621
)
611622

@@ -614,12 +625,13 @@ def test_estimator_py2_warning(warning, sagemaker_session):
614625

615626

616627
@patch("sagemaker.chainer.model.python_deprecation_warning")
617-
def test_model_py2_warning(warning, sagemaker_session):
628+
def test_model_py2_warning(warning, sagemaker_session, chainer_version):
618629
model = ChainerModel(
619630
MODEL_DATA,
620631
role=ROLE,
621632
entry_point=SCRIPT_PATH,
622633
sagemaker_session=sagemaker_session,
634+
framework_version=chainer_version,
623635
py_version="py2",
624636
)
625637
assert model.py_version == "py2"

0 commit comments

Comments
 (0)