Skip to content

Commit c01a07d

Browse files
authored
Merge branch 'zwei' into xgboost-uri
2 parents 01b005a + 211f4e5 commit c01a07d

27 files changed

+290
-294
lines changed

src/sagemaker/chainer/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ def create_model(
214214
return ChainerModel(
215215
self.model_data,
216216
role or self.role,
217-
entry_point or self.entry_point,
217+
entry_point or self._model_entry_point(),
218218
source_dir=(source_dir or self._model_source_dir()),
219219
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
220220
container_log_level=self.container_log_level,

src/sagemaker/estimator.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1734,17 +1734,28 @@ def _stage_user_code_in_s3(self):
17341734
)
17351735

17361736
def _model_source_dir(self):
1737-
"""Get the appropriate value to pass as source_dir to model constructor
1738-
on deploying
1737+
"""Get the appropriate value to pass as ``source_dir`` to a model constructor.
17391738
17401739
Returns:
1741-
str: Either a local or an S3 path pointing to the source_dir to be
1742-
used for code by the model to be deployed
1740+
str: Either a local or an S3 path pointing to the ``source_dir`` to be
1741+
used for code by the model to be deployed
17431742
"""
17441743
return (
17451744
self.source_dir if self.sagemaker_session.local_mode else self.uploaded_code.s3_prefix
17461745
)
17471746

1747+
def _model_entry_point(self):
1748+
"""Get the appropriate value to pass as ``entry_point`` to a model constructor.
1749+
1750+
Returns:
1751+
str: The path to the entry point script. This can be either an absolute path or
1752+
a path relative to ``self._model_source_dir()``.
1753+
"""
1754+
if self.sagemaker_session.local_mode or (self._model_source_dir() is None):
1755+
return self.entry_point
1756+
1757+
return self.uploaded_code.script_name
1758+
17481759
def hyperparameters(self):
17491760
"""Return the hyperparameters as a dictionary to use for training.
17501761

src/sagemaker/fw_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,7 @@ def tar_and_upload_dir(
447447
script name.
448448
"""
449449
if directory and directory.lower().startswith("s3://"):
450-
return UploadedCode(s3_prefix=directory, script_name=os.path.basename(script))
450+
return UploadedCode(s3_prefix=directory, script_name=script)
451451

452452
script_name = script if directory else os.path.basename(script)
453453
dependencies = dependencies or []

src/sagemaker/mxnet/estimator.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -218,10 +218,10 @@ def create_model(
218218

219219
kwargs["name"] = self._get_or_create_name(kwargs.get("name"))
220220

221-
return MXNetModel(
221+
model = MXNetModel(
222222
self.model_data,
223223
role or self.role,
224-
entry_point or self.entry_point,
224+
entry_point,
225225
framework_version=self.framework_version,
226226
py_version=self.py_version,
227227
source_dir=(source_dir or self._model_source_dir()),
@@ -235,6 +235,13 @@ def create_model(
235235
**kwargs
236236
)
237237

238+
if entry_point is None:
239+
model.entry_point = (
240+
self.entry_point if model._is_mms_version() else self._model_entry_point()
241+
)
242+
243+
return model
244+
238245
@classmethod
239246
def _prepare_init_params_from_job_description(cls, job_details, model_channel_name=None):
240247
"""Convert the job description to init params that can be handled by the

src/sagemaker/pytorch/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ def create_model(
175175
return PyTorchModel(
176176
self.model_data,
177177
role or self.role,
178-
entry_point or self.entry_point,
178+
entry_point or self._model_entry_point(),
179179
framework_version=self.framework_version,
180180
py_version=self.py_version,
181181
source_dir=(source_dir or self._model_source_dir()),

src/sagemaker/rl/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ def create_model(
232232
if not entry_point and (source_dir or dependencies):
233233
raise AttributeError("Please provide an `entry_point`.")
234234

235-
entry_point = entry_point or self.entry_point
235+
entry_point = entry_point or self._model_entry_point()
236236
source_dir = source_dir or self._model_source_dir()
237237
dependencies = dependencies or self.dependencies
238238

src/sagemaker/sklearn/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ def create_model(
196196
return SKLearnModel(
197197
self.model_data,
198198
role,
199-
entry_point or self.entry_point,
199+
entry_point or self._model_entry_point(),
200200
source_dir=(source_dir or self._model_source_dir()),
201201
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
202202
container_log_level=self.container_log_level,

src/sagemaker/xgboost/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ def create_model(
172172
return XGBoostModel(
173173
self.model_data,
174174
role,
175-
entry_point or self.entry_point,
175+
entry_point or self._model_entry_point(),
176176
framework_version=self.framework_version,
177177
source_dir=(source_dir or self._model_source_dir()),
178178
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,

tests/conftest.py

Lines changed: 16 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -166,17 +166,18 @@ def xgboost_framework_version(xgboost_version):
166166
return xgboost_version
167167

168168

169-
@pytest.fixture(scope="module")
170-
def tf_version(tensorflow_training_version):
171-
# TODO: remove this fixture and update tests
172-
if tensorflow_training_version in ("1.13.1", "2.2", "2.2.0"):
173-
pytest.skip("version isn't compatible with both training and inference.")
174-
return tensorflow_training_version
169+
@pytest.fixture(scope="module", params=["py2", "py3"])
170+
def tensorflow_training_py_version(tensorflow_training_version, request):
171+
return _tf_py_version(tensorflow_training_version, request)
175172

176173

177174
@pytest.fixture(scope="module", params=["py2", "py3"])
178-
def tf_py_version(tensorflow_training_version, request):
179-
version = Version(tensorflow_training_version)
175+
def tensorflow_inference_py_version(tensorflow_inference_version, request):
176+
return _tf_py_version(tensorflow_inference_version, request)
177+
178+
179+
def _tf_py_version(tf_version, request):
180+
version = Version(tf_version)
180181
if version < Version("1.11"):
181182
return "py2"
182183
if version < Version("2.2"):
@@ -255,28 +256,18 @@ def sklearn_full_py_version():
255256

256257

257258
@pytest.fixture(scope="module")
258-
def tf_training_latest_version():
259-
return "2.2.0"
260-
261-
262-
@pytest.fixture(scope="module")
263-
def tf_training_latest_py_version():
264-
return "py37"
265-
266-
267-
@pytest.fixture(scope="module")
268-
def tf_serving_latest_version():
269-
return "2.1.0"
270-
271-
272-
@pytest.fixture(scope="module")
273-
def tf_full_version(tf_training_latest_version, tf_serving_latest_version):
259+
def tf_full_version(tensorflow_training_latest_version, tensorflow_inference_latest_version):
274260
"""Fixture for TF tests that test both training and inference.
275261
276262
Fixture exists as such, since TF training and TFS have different latest versions.
277263
Otherwise, this would simply be a single latest version.
278264
"""
279-
return str(min(Version(tf_training_latest_version), Version(tf_serving_latest_version)))
265+
return str(
266+
min(
267+
Version(tensorflow_training_latest_version),
268+
Version(tensorflow_inference_latest_version),
269+
)
270+
)
280271

281272

282273
@pytest.fixture(scope="module")
@@ -294,11 +285,6 @@ def tf_full_py_version(tf_full_version):
294285
return "py37"
295286

296287

297-
@pytest.fixture(scope="module")
298-
def ei_tf_full_version():
299-
return "2.0.0"
300-
301-
302288
@pytest.fixture(scope="module")
303289
def xgboost_full_version():
304290
return "1.0-1"
2.15 KB
Binary file not shown.

tests/integ/test_airflow_config.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -512,7 +512,10 @@ def test_sklearn_airflow_config_uploads_data_source_to_s3(
512512

513513
@pytest.mark.canary_quick
514514
def test_tf_airflow_config_uploads_data_source_to_s3(
515-
sagemaker_session, cpu_instance_type, tf_training_latest_version, tf_training_latest_py_version
515+
sagemaker_session,
516+
cpu_instance_type,
517+
tensorflow_training_latest_version,
518+
tensorflow_training_latest_py_version,
516519
):
517520
with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS):
518521
tf = TensorFlow(
@@ -524,8 +527,8 @@ def test_tf_airflow_config_uploads_data_source_to_s3(
524527
instance_count=SINGLE_INSTANCE_COUNT,
525528
instance_type=cpu_instance_type,
526529
sagemaker_session=sagemaker_session,
527-
framework_version=tf_training_latest_version,
528-
py_version=tf_training_latest_py_version,
530+
framework_version=tensorflow_training_latest_version,
531+
py_version=tensorflow_training_latest_py_version,
529532
metric_definitions=[
530533
{"Name": "train:global_steps", "Regex": r"global_step\/sec:\s(.*)"}
531534
],

tests/integ/test_data_capture_config.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141

4242

4343
def test_enabling_data_capture_on_endpoint_shows_correct_data_capture_status(
44-
sagemaker_session, tf_serving_latest_version
44+
sagemaker_session, tensorflow_inference_latest_version
4545
):
4646
endpoint_name = unique_name_from_base("sagemaker-tensorflow-serving")
4747
model_data = sagemaker_session.upload_data(
@@ -52,7 +52,7 @@ def test_enabling_data_capture_on_endpoint_shows_correct_data_capture_status(
5252
model = TensorFlowModel(
5353
model_data=model_data,
5454
role=ROLE,
55-
framework_version=tf_serving_latest_version,
55+
framework_version=tensorflow_inference_latest_version,
5656
sagemaker_session=sagemaker_session,
5757
)
5858
predictor = model.deploy(
@@ -98,7 +98,7 @@ def test_enabling_data_capture_on_endpoint_shows_correct_data_capture_status(
9898

9999

100100
def test_disabling_data_capture_on_endpoint_shows_correct_data_capture_status(
101-
sagemaker_session, tf_serving_latest_version
101+
sagemaker_session, tensorflow_inference_latest_version
102102
):
103103
endpoint_name = unique_name_from_base("sagemaker-tensorflow-serving")
104104
model_data = sagemaker_session.upload_data(
@@ -109,7 +109,7 @@ def test_disabling_data_capture_on_endpoint_shows_correct_data_capture_status(
109109
model = TensorFlowModel(
110110
model_data=model_data,
111111
role=ROLE,
112-
framework_version=tf_serving_latest_version,
112+
framework_version=tensorflow_inference_latest_version,
113113
sagemaker_session=sagemaker_session,
114114
)
115115
destination_s3_uri = os.path.join(
@@ -184,7 +184,7 @@ def test_disabling_data_capture_on_endpoint_shows_correct_data_capture_status(
184184

185185

186186
def test_updating_data_capture_on_endpoint_shows_correct_data_capture_status(
187-
sagemaker_session, tf_serving_latest_version
187+
sagemaker_session, tensorflow_inference_latest_version
188188
):
189189
endpoint_name = sagemaker.utils.unique_name_from_base("sagemaker-tensorflow-serving")
190190
model_data = sagemaker_session.upload_data(
@@ -195,7 +195,7 @@ def test_updating_data_capture_on_endpoint_shows_correct_data_capture_status(
195195
model = TensorFlowModel(
196196
model_data=model_data,
197197
role=ROLE,
198-
framework_version=tf_serving_latest_version,
198+
framework_version=tensorflow_inference_latest_version,
199199
sagemaker_session=sagemaker_session,
200200
)
201201
destination_s3_uri = os.path.join(

tests/integ/test_horovod.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,15 @@
3131
@pytest.mark.canary_quick
3232
def test_hvd_cpu(
3333
sagemaker_session,
34-
tf_training_latest_version,
35-
tf_training_latest_py_version,
34+
tensorflow_training_latest_version,
35+
tensorflow_training_latest_py_version,
3636
cpu_instance_type,
3737
tmpdir,
3838
):
3939
_create_and_fit_estimator(
4040
sagemaker_session,
41-
tf_training_latest_version,
42-
tf_training_latest_py_version,
41+
tensorflow_training_latest_version,
42+
tensorflow_training_latest_py_version,
4343
cpu_instance_type,
4444
tmpdir,
4545
)
@@ -50,12 +50,15 @@ def test_hvd_cpu(
5050
integ.test_region() in integ.TRAINING_NO_P2_REGIONS, reason="no ml.p2 instances in this region"
5151
)
5252
def test_hvd_gpu(
53-
sagemaker_session, tf_training_latest_version, tf_training_latest_py_version, tmpdir
53+
sagemaker_session,
54+
tensorflow_training_latest_version,
55+
tensorflow_training_latest_py_version,
56+
tmpdir,
5457
):
5558
_create_and_fit_estimator(
5659
sagemaker_session,
57-
tf_training_latest_version,
58-
tf_training_latest_py_version,
60+
tensorflow_training_latest_version,
61+
tensorflow_training_latest_py_version,
5962
"ml.p2.xlarge",
6063
tmpdir,
6164
)
@@ -65,8 +68,8 @@ def test_hvd_gpu(
6568
@pytest.mark.parametrize("instances, processes", [[1, 2], (2, 1), (2, 2)])
6669
def test_horovod_local_mode(
6770
sagemaker_local_session,
68-
tf_training_latest_version,
69-
tf_training_latest_py_version,
71+
tensorflow_training_latest_version,
72+
tensorflow_training_latest_py_version,
7073
instances,
7174
processes,
7275
tmpdir,
@@ -80,8 +83,8 @@ def test_horovod_local_mode(
8083
instance_type="local",
8184
sagemaker_session=sagemaker_local_session,
8285
output_path=output_path,
83-
framework_version=tf_training_latest_version,
84-
py_version=tf_training_latest_py_version,
86+
framework_version=tensorflow_training_latest_version,
87+
py_version=tensorflow_training_latest_py_version,
8588
distribution={"mpi": {"enabled": True, "processes_per_host": processes}},
8689
)
8790

tests/integ/test_model_monitor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@
8888

8989

9090
@pytest.fixture(scope="module")
91-
def predictor(sagemaker_session, tf_serving_latest_version):
91+
def predictor(sagemaker_session, tensorflow_inference_latest_version):
9292
endpoint_name = unique_name_from_base("sagemaker-tensorflow-serving")
9393
model_data = sagemaker_session.upload_data(
9494
path=os.path.join(tests.integ.DATA_DIR, "tensorflow-serving-test-model.tar.gz"),
@@ -100,7 +100,7 @@ def predictor(sagemaker_session, tf_serving_latest_version):
100100
model = TensorFlowModel(
101101
model_data=model_data,
102102
role=ROLE,
103-
framework_version=tf_serving_latest_version,
103+
framework_version=tensorflow_inference_latest_version,
104104
sagemaker_session=sagemaker_session,
105105
)
106106
predictor = model.deploy(

tests/integ/test_mxnet.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,16 @@ def mxnet_training_job(
3232
sagemaker_session, mxnet_full_version, mxnet_full_py_version, cpu_instance_type
3333
):
3434
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
35-
script_path = os.path.join(DATA_DIR, "mxnet_mnist", "mnist.py")
35+
s3_prefix = "integ-test-data/mxnet_mnist"
3636
data_path = os.path.join(DATA_DIR, "mxnet_mnist")
3737

38+
s3_source = sagemaker_session.upload_data(
39+
path=os.path.join(data_path, "sourcedir.tar.gz"), key_prefix="{}/src".format(s3_prefix)
40+
)
41+
3842
mx = MXNet(
39-
entry_point=script_path,
43+
entry_point="mxnet_mnist/mnist.py",
44+
source_dir=s3_source,
4045
role="SageMakerRole",
4146
framework_version=mxnet_full_version,
4247
py_version=mxnet_full_py_version,
@@ -46,10 +51,10 @@ def mxnet_training_job(
4651
)
4752

4853
train_input = mx.sagemaker_session.upload_data(
49-
path=os.path.join(data_path, "train"), key_prefix="integ-test-data/mxnet_mnist/train"
54+
path=os.path.join(data_path, "train"), key_prefix="{}/train".format(s3_prefix)
5055
)
5156
test_input = mx.sagemaker_session.upload_data(
52-
path=os.path.join(data_path, "test"), key_prefix="integ-test-data/mxnet_mnist/test"
57+
path=os.path.join(data_path, "test"), key_prefix="{}/test".format(s3_prefix)
5358
)
5459

5560
mx.fit({"train": train_input, "test": test_input})
@@ -62,7 +67,13 @@ def test_attach_deploy(mxnet_training_job, sagemaker_session, cpu_instance_type)
6267

6368
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
6469
estimator = MXNet.attach(mxnet_training_job, sagemaker_session=sagemaker_session)
65-
predictor = estimator.deploy(1, cpu_instance_type, endpoint_name=endpoint_name)
70+
predictor = estimator.deploy(
71+
1,
72+
cpu_instance_type,
73+
entry_point="mnist.py",
74+
source_dir=os.path.join(DATA_DIR, "mxnet_mnist"),
75+
endpoint_name=endpoint_name,
76+
)
6677
data = numpy.zeros(shape=(1, 1, 28, 28))
6778
result = predictor.predict(data)
6879
assert result is not None

0 commit comments

Comments
 (0)