Skip to content

Commit 5fc1125

Browse files
committed
breaking: require framework_version, py_version for chainer
1 parent 9df3f5a commit 5fc1125

File tree

5 files changed

+73
-80
lines changed

5 files changed

+73
-80
lines changed

doc/frameworks/chainer/using_chainer.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ directories ('train' and 'test').
141141
train_instance_type='ml.p3.2xlarge',
142142
train_instance_count=1,
143143
framework_version='5.0.0',
144+
py_version='py3',
144145
hyperparameters = {'epochs': 20, 'batch-size': 64, 'learning-rate': 0.1})
145146
chainer_estimator.fit({'train': 's3://my-data-bucket/path/to/my/training/data',
146147
'test': 's3://my-data-bucket/path/to/my/test/data'})
@@ -222,7 +223,8 @@ operation.
222223
chainer_estimator = Chainer(entry_point='train_and_deploy.py',
223224
train_instance_type='ml.p3.2xlarge',
224225
train_instance_count=1,
225-
framework_version='5.0.0')
226+
framework_version='5.0.0',
227+
py_version='py3')
226228
chainer_estimator.fit('s3://my_bucket/my_training_data/')
227229
228230
# Deploy my estimator to a SageMaker Endpoint and get a Predictor

src/sagemaker/chainer/estimator.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
from sagemaker.fw_utils import (
2020
framework_name_from_image,
2121
framework_version_from_tag,
22-
empty_framework_version_warning,
2322
python_deprecation_warning,
23+
validate_version_or_image_args,
2424
)
2525
from sagemaker.chainer import defaults
2626
from sagemaker.chainer.model import ChainerModel
@@ -51,8 +51,8 @@ def __init__(
5151
additional_mpi_options=None,
5252
source_dir=None,
5353
hyperparameters=None,
54-
py_version="py3",
5554
framework_version=None,
55+
py_version=None,
5656
image_name=None,
5757
**kwargs
5858
):
@@ -103,11 +103,12 @@ def __init__(
103103
and values, but ``str()`` will be called to convert them before
104104
training.
105105
py_version (str): Python version you want to use for executing your
106-
model training code (default: 'py2'). One of 'py2' or 'py3'.
106+
model training code. Defaults to ``None``. Required unless ``image_name``
107+
is provided.
107108
framework_version (str): Chainer version you want to use for
108-
executing your model training code. List of supported versions
109+
executing your model training code. Defaults to ``None``. Required unless
110+
``image_name`` is provided. List of supported versions:
109111
https://github.com/aws/sagemaker-python-sdk#chainer-sagemaker-estimators.
110-
If not specified, this will default to 4.1.
111112
image_name (str): If specified, the estimator will use this image
112113
for training and hosting, instead of selecting the appropriate
113114
SageMaker official image based on framework_version and
@@ -117,6 +118,9 @@ def __init__(
117118
* ``123412341234.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0``
118119
* ``custom-image:latest``
119120
121+
If ``framework_version`` or ``py_version`` are ``None``, then
122+
``image_name`` is required. If also ``None``, then a ``ValueError``
123+
will be raised.
120124
**kwargs: Additional kwargs passed to the
121125
:class:`~sagemaker.estimator.Framework` constructor.
122126
@@ -126,22 +130,18 @@ def __init__(
126130
:class:`~sagemaker.estimator.Framework` and
127131
:class:`~sagemaker.estimator.EstimatorBase`.
128132
"""
129-
if framework_version is None:
133+
validate_version_or_image_args(framework_version, py_version, image_name)
134+
if py_version == "py2":
130135
logger.warning(
131-
empty_framework_version_warning(defaults.CHAINER_VERSION, self.LATEST_VERSION)
136+
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
132137
)
133-
self.framework_version = framework_version or defaults.CHAINER_VERSION
138+
self.framework_version = framework_version
139+
self.py_version = py_version
134140

135141
super(Chainer, self).__init__(
136142
entry_point, source_dir, hyperparameters, image_name=image_name, **kwargs
137143
)
138144

139-
if py_version == "py2":
140-
logger.warning(
141-
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
142-
)
143-
144-
self.py_version = py_version
145145
self.use_mpi = use_mpi
146146
self.num_processes = num_processes
147147
self.process_slots_per_host = process_slots_per_host
@@ -262,15 +262,19 @@ class constructor
262262
image_name = init_params.pop("image")
263263
framework, py_version, tag, _ = framework_name_from_image(image_name)
264264

265+
if tag is None:
266+
framework_version = None
267+
else:
268+
framework_version = framework_version_from_tag(tag)
269+
init_params["framework_version"] = framework_version
270+
init_params["py_version"] = py_version
271+
265272
if not framework:
266273
# If we were unable to parse the framework name from the image it is not one of our
267274
# officially supported images, in this case just add the image to the init params.
268275
init_params["image_name"] = image_name
269276
return init_params
270277

271-
init_params["py_version"] = py_version
272-
init_params["framework_version"] = framework_version_from_tag(tag)
273-
274278
training_job_name = init_params["base_job_name"]
275279

276280
if framework != cls.__framework_name__:

src/sagemaker/chainer/model.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
create_image_uri,
2323
model_code_key_prefix,
2424
python_deprecation_warning,
25-
empty_framework_version_warning,
25+
validate_version_or_image_args,
2626
)
2727
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
2828
from sagemaker.chainer import defaults
@@ -67,8 +67,8 @@ def __init__(
6767
role,
6868
entry_point,
6969
image=None,
70-
py_version="py3",
7170
framework_version=None,
71+
py_version=None,
7272
predictor_cls=ChainerPredictor,
7373
model_server_workers=None,
7474
**kwargs
@@ -88,11 +88,15 @@ def __init__(
8888
hosting. If ``source_dir`` is specified, then ``entry_point``
8989
must point to a file located at the root of ``source_dir``.
9090
image (str): A Docker image URI (default: None). If not specified, a
91-
default image for Chainer will be used.
92-
py_version (str): Python version you want to use for executing your
93-
model training code (default: 'py2').
91+
default image for Chainer will be used. If ``framework_version``
92+
or ``py_version`` are ``None``, then ``image`` is required. If
93+
also ``None``, then a ``ValueError`` will be raised.
9494
framework_version (str): Chainer version you want to use for
95-
executing your model training code.
95+
executing your model training code. Defaults to ``None``. Required
96+
unless ``image`` is provided.
97+
py_version (str): Python version you want to use for executing your
98+
model training code. Defaults to ``None``. Required unless
99+
``image`` is provided.
96100
predictor_cls (callable[str, sagemaker.session.Session]): A function
97101
to call to create a predictor with an endpoint name and
98102
SageMaker ``Session``. If specified, ``deploy()`` returns the
@@ -109,21 +113,18 @@ def __init__(
109113
:class:`~sagemaker.model.FrameworkModel` and
110114
:class:`~sagemaker.model.Model`.
111115
"""
112-
super(ChainerModel, self).__init__(
113-
model_data, image, role, entry_point, predictor_cls=predictor_cls, **kwargs
114-
)
116+
validate_version_or_image_args(framework_version, py_version, image)
115117
if py_version == "py2":
116118
logger.warning(
117119
python_deprecation_warning(self.__framework_name__, defaults.LATEST_PY2_VERSION)
118120
)
121+
self.framework_version = framework_version
122+
self.py_version = py_version
119123

120-
if framework_version is None:
121-
logger.warning(
122-
empty_framework_version_warning(defaults.CHAINER_VERSION, defaults.LATEST_VERSION)
123-
)
124+
super(ChainerModel, self).__init__(
125+
model_data, image, role, entry_point, predictor_cls=predictor_cls, **kwargs
126+
)
124127

125-
self.py_version = py_version
126-
self.framework_version = framework_version or defaults.CHAINER_VERSION
127128
self.model_server_workers = model_server_workers
128129

129130
def prepare_container_def(self, instance_type, accelerator_type=None):

tests/unit/test_chainer.py

Lines changed: 32 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def _get_full_cpu_image_uri_with_ei(version, py_version=PYTHON_VERSION):
9191

9292
def _chainer_estimator(
9393
sagemaker_session,
94-
framework_version=defaults.CHAINER_VERSION,
94+
framework_version,
9595
train_instance_type=None,
9696
base_job_name=None,
9797
use_mpi=None,
@@ -202,13 +202,14 @@ def _create_train_job_with_additional_hyperparameters(version):
202202
}
203203

204204

205-
def test_additional_hyperparameters(sagemaker_session):
205+
def test_additional_hyperparameters(sagemaker_session, chainer_version):
206206
chainer = _chainer_estimator(
207207
sagemaker_session,
208208
use_mpi=True,
209209
num_processes=4,
210210
process_slots_per_host=10,
211211
additional_mpi_options="-x MY_ENVIRONMENT_VARIABLE",
212+
framework_version=chainer_version,
212213
)
213214
assert bool(strtobool(chainer.hyperparameters()["sagemaker_use_mpi"]))
214215
assert int(chainer.hyperparameters()["sagemaker_num_processes"]) == 4
@@ -300,7 +301,7 @@ def test_create_model(sagemaker_session, chainer_version):
300301
assert model.vpc_config is None
301302

302303

303-
def test_create_model_with_optional_params(sagemaker_session):
304+
def test_create_model_with_optional_params(sagemaker_session, chainer_version):
304305
container_log_level = '"logging.INFO"'
305306
source_dir = "s3://mybucket/source"
306307
enable_cloudwatch_metrics = "true"
@@ -311,6 +312,7 @@ def test_create_model_with_optional_params(sagemaker_session):
311312
train_instance_count=INSTANCE_COUNT,
312313
train_instance_type=INSTANCE_TYPE,
313314
container_log_level=container_log_level,
315+
framework_version=chainer_version,
314316
py_version=PYTHON_VERSION,
315317
base_job_name="job",
316318
source_dir=source_dir,
@@ -372,8 +374,8 @@ def test_chainer(strftime, sagemaker_session, chainer_version):
372374
sagemaker_session=sagemaker_session,
373375
train_instance_count=INSTANCE_COUNT,
374376
train_instance_type=INSTANCE_TYPE,
375-
py_version=PYTHON_VERSION,
376377
framework_version=chainer_version,
378+
py_version=PYTHON_VERSION,
377379
)
378380

379381
inputs = "s3://mybucket/train"
@@ -414,62 +416,72 @@ def test_chainer(strftime, sagemaker_session, chainer_version):
414416

415417

416418
@patch("sagemaker.utils.create_tar_file", MagicMock())
417-
def test_model(sagemaker_session):
419+
def test_model(sagemaker_session, chainer_version):
418420
model = ChainerModel(
419421
"s3://some/data.tar.gz",
420422
role=ROLE,
421423
entry_point=SCRIPT_PATH,
422424
sagemaker_session=sagemaker_session,
425+
framework_version=chainer_version,
426+
py_version=PYTHON_VERSION,
423427
)
424428
predictor = model.deploy(1, GPU)
425429
assert isinstance(predictor, ChainerPredictor)
426430

427431

428432
@patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock())
429-
def test_model_prepare_container_def_accelerator_error(sagemaker_session):
433+
def test_model_prepare_container_def_accelerator_error(sagemaker_session, chainer_version):
430434
model = ChainerModel(
431-
MODEL_DATA, role=ROLE, entry_point=SCRIPT_PATH, sagemaker_session=sagemaker_session
435+
MODEL_DATA,
436+
role=ROLE,
437+
entry_point=SCRIPT_PATH,
438+
sagemaker_session=sagemaker_session,
439+
framework_version=chainer_version,
440+
py_version=PYTHON_VERSION,
432441
)
433442
with pytest.raises(ValueError):
434443
model.prepare_container_def(INSTANCE_TYPE, accelerator_type=ACCELERATOR_TYPE)
435444

436445

437-
def test_train_image_default(sagemaker_session):
446+
def test_train_image_default(sagemaker_session, chainer_version):
438447
chainer = Chainer(
439448
entry_point=SCRIPT_PATH,
440449
role=ROLE,
441450
sagemaker_session=sagemaker_session,
442451
train_instance_count=INSTANCE_COUNT,
443452
train_instance_type=INSTANCE_TYPE,
453+
framework_version=chainer_version,
444454
py_version=PYTHON_VERSION,
445455
)
446456

447-
assert _get_full_cpu_image_uri(defaults.CHAINER_VERSION) in chainer.train_image()
457+
assert _get_full_cpu_image_uri(chainer_version) in chainer.train_image()
448458

449459

450460
def test_train_image_cpu_instances(sagemaker_session, chainer_version):
451461
chainer = _chainer_estimator(
452-
sagemaker_session, chainer_version, train_instance_type="ml.c2.2xlarge"
462+
sagemaker_session, framework_version=chainer_version, train_instance_type="ml.c2.2xlarge"
453463
)
454464
assert chainer.train_image() == _get_full_cpu_image_uri(chainer_version)
455465

456466
chainer = _chainer_estimator(
457-
sagemaker_session, chainer_version, train_instance_type="ml.c4.2xlarge"
467+
sagemaker_session, framework_version=chainer_version, train_instance_type="ml.c4.2xlarge"
458468
)
459469
assert chainer.train_image() == _get_full_cpu_image_uri(chainer_version)
460470

461-
chainer = _chainer_estimator(sagemaker_session, chainer_version, train_instance_type="ml.m16")
471+
chainer = _chainer_estimator(
472+
sagemaker_session, framework_version=chainer_version, train_instance_type="ml.m16"
473+
)
462474
assert chainer.train_image() == _get_full_cpu_image_uri(chainer_version)
463475

464476

465477
def test_train_image_gpu_instances(sagemaker_session, chainer_version):
466478
chainer = _chainer_estimator(
467-
sagemaker_session, chainer_version, train_instance_type="ml.g2.2xlarge"
479+
sagemaker_session, framework_version=chainer_version, train_instance_type="ml.g2.2xlarge"
468480
)
469481
assert chainer.train_image() == _get_full_gpu_image_uri(chainer_version)
470482

471483
chainer = _chainer_estimator(
472-
sagemaker_session, chainer_version, train_instance_type="ml.p2.2xlarge"
484+
sagemaker_session, framework_version=chainer_version, train_instance_type="ml.p2.2xlarge"
473485
)
474486
assert chainer.train_image() == _get_full_gpu_image_uri(chainer_version)
475487

@@ -597,13 +609,14 @@ def test_attach_custom_image(sagemaker_session):
597609

598610

599611
@patch("sagemaker.chainer.estimator.python_deprecation_warning")
600-
def test_estimator_py2_warning(warning, sagemaker_session):
612+
def test_estimator_py2_warning(warning, sagemaker_session, chainer_version):
601613
estimator = Chainer(
602614
entry_point=SCRIPT_PATH,
603615
role=ROLE,
604616
sagemaker_session=sagemaker_session,
605617
train_instance_count=INSTANCE_COUNT,
606618
train_instance_type=INSTANCE_TYPE,
619+
framework_version=chainer_version,
607620
py_version="py2",
608621
)
609622

@@ -612,49 +625,22 @@ def test_estimator_py2_warning(warning, sagemaker_session):
612625

613626

614627
@patch("sagemaker.chainer.model.python_deprecation_warning")
615-
def test_model_py2_warning(warning, sagemaker_session):
628+
def test_model_py2_warning(warning, sagemaker_session, chainer_version):
616629
model = ChainerModel(
617630
MODEL_DATA,
618631
role=ROLE,
619632
entry_point=SCRIPT_PATH,
620633
sagemaker_session=sagemaker_session,
634+
framework_version=chainer_version,
621635
py_version="py2",
622636
)
623637
assert model.py_version == "py2"
624638
warning.assert_called_with(model.__framework_name__, defaults.LATEST_PY2_VERSION)
625639

626640

627-
@patch("sagemaker.chainer.estimator.empty_framework_version_warning")
628-
def test_empty_framework_version(warning, sagemaker_session):
629-
estimator = Chainer(
630-
entry_point=SCRIPT_PATH,
631-
role=ROLE,
632-
sagemaker_session=sagemaker_session,
633-
train_instance_count=INSTANCE_COUNT,
634-
train_instance_type=INSTANCE_TYPE,
635-
framework_version=None,
636-
)
637-
638-
assert estimator.framework_version == defaults.CHAINER_VERSION
639-
warning.assert_called_with(defaults.CHAINER_VERSION, Chainer.LATEST_VERSION)
640-
641-
642-
@patch("sagemaker.chainer.model.empty_framework_version_warning")
643-
def test_model_empty_framework_version(warning, sagemaker_session):
644-
model = ChainerModel(
645-
MODEL_DATA,
646-
role=ROLE,
647-
entry_point=SCRIPT_PATH,
648-
sagemaker_session=sagemaker_session,
649-
framework_version=None,
650-
)
651-
assert model.framework_version == defaults.CHAINER_VERSION
652-
warning.assert_called_with(defaults.CHAINER_VERSION, defaults.LATEST_VERSION)
653-
654-
655-
def test_custom_image_estimator_deploy(sagemaker_session):
641+
def test_custom_image_estimator_deploy(sagemaker_session, chainer_version):
656642
custom_image = "mycustomimage:latest"
657-
chainer = _chainer_estimator(sagemaker_session)
643+
chainer = _chainer_estimator(sagemaker_session, framework_version=chainer_version)
658644
chainer.fit(inputs="s3://mybucket/train", job_name="new_name")
659645
model = chainer.create_model(image=custom_image)
660646
assert model.image == custom_image

tox.ini

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ extras = test
6868

6969
[testenv:py27]
7070
setenv =
71-
IGNORE_COVERAGE = true
71+
IGNORE_COVERAGE = -
7272

7373
[testenv:flake8]
7474
basepython = python3

0 commit comments

Comments
 (0)