Skip to content

Commit c91ab64

Browse files
authored
Merge pull request aws#36 from athewsey/feat/fw-processor
Fix XGBoostProcessor and add TF integration test
2 parents 69d791b + 072b84f commit c91ab64

File tree

9 files changed

+98
-9
lines changed

9 files changed

+98
-9
lines changed

src/sagemaker/mxnet/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13-
"""Placeholder docstring"""
13+
"""Classes for using MXNet with Amazon SageMaker."""
1414
from __future__ import absolute_import # noqa: F401
1515

1616
from sagemaker.mxnet.estimator import MXNet # noqa: F401
1717
from sagemaker.mxnet.model import MXNetModel, MXNetPredictor # noqa: F401
18+
from sagemaker.mxnet.processing import MXNetProcessor # noqa: F401

src/sagemaker/pytorch/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13-
"""Placeholder docstring"""
13+
"""Classes for using PyTorch with Amazon SageMaker."""
1414
from __future__ import absolute_import
1515

1616
from sagemaker.pytorch.estimator import PyTorch # noqa: F401
1717
from sagemaker.pytorch.model import PyTorchModel, PyTorchPredictor # noqa: F401
18+
from sagemaker.pytorch.processing import PyTorchProcessor # noqa: F401

src/sagemaker/sklearn/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13-
"""Placeholder docstring"""
13+
"""Classes for using Scikit-Learn with Amazon SageMaker."""
1414
from __future__ import absolute_import
1515

1616
from sagemaker.sklearn.estimator import SKLearn # noqa: F401

src/sagemaker/tensorflow/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,4 @@
1515

1616
from sagemaker.tensorflow.estimator import TensorFlow # noqa: F401 (imported but unused)
1717
from sagemaker.tensorflow.model import TensorFlowModel, TensorFlowPredictor # noqa: F401
18+
from sagemaker.tensorflow.processing import TensorFlowProcessor # noqa: F401

src/sagemaker/xgboost/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13-
"""Placeholder docstring"""
13+
"""Classes for using XGBoost with Amazon SageMaker."""
14+
from __future__ import absolute_import
15+
1416
from sagemaker.xgboost.defaults import XGBOOST_NAME # noqa: F401
1517
from sagemaker.xgboost.estimator import XGBoost # noqa: F401
1618
from sagemaker.xgboost.model import XGBoostModel, XGBoostPredictor # noqa: F401
19+
from sagemaker.xgboost.processing import XGBoostProcessor # noqa: F401

src/sagemaker/xgboost/processing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from sagemaker.xgboost.estimator import XGBoost
2222

2323

24-
class XGBoostEstimator(FrameworkProcessor):
24+
class XGBoostProcessor(FrameworkProcessor):
2525
"""Handles Amazon SageMaker processing tasks for jobs using XGBoost containers."""
2626

2727
estimator_cls = XGBoost

tests/conftest.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -411,4 +411,8 @@ def _parametrize_framework_version_fixtures(metafunc, fixture_prefix, config):
411411
if fixture_name in metafunc.fixturenames:
412412
config = config["versions"]
413413
py_versions = config[latest_version].get("py_versions", config[latest_version].keys())
414-
metafunc.parametrize(fixture_name, (sorted(py_versions)[-1],), scope="session")
414+
if "repository" in py_versions or "registries" in py_versions:
415+
# Config did not specify `py_versions` and is not arranged by py_version. Assume py3
416+
metafunc.parametrize(fixture_name, ("py3",), scope="session")
417+
else:
418+
metafunc.parametrize(fixture_name, (sorted(py_versions)[-1],), scope="session")

tests/integ/test_tf.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@
1818

1919
import pytest
2020

21-
from sagemaker.tensorflow import TensorFlow
21+
from sagemaker.tensorflow import TensorFlow, TensorFlowProcessor
2222
from sagemaker.utils import unique_name_from_base, sagemaker_timestamp
2323

2424
import tests.integ
25-
from tests.integ import kms_utils, timeout
25+
from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES, kms_utils, timeout
2626
from tests.integ.retry import retries
2727
from tests.integ.s3_utils import assert_s3_files_exist
2828

@@ -39,6 +39,35 @@
3939
ENV_INPUT = {"env_key1": "env_val1", "env_key2": "env_val2", "env_key3": "env_val3"}
4040

4141

42+
@pytest.mark.release
43+
def test_framework_processing_job_with_deps(
44+
sagemaker_session,
45+
instance_type,
46+
tensorflow_training_latest_version,
47+
tensorflow_training_latest_py_version,
48+
):
49+
with timeout.timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
50+
code_path = os.path.join(DATA_DIR, "dummy_code_bundle_with_reqs")
51+
entry_point = "main_script.py"
52+
53+
processor = TensorFlowProcessor(
54+
framework_version=tensorflow_training_latest_version,
55+
py_version=tensorflow_training_latest_py_version,
56+
role=ROLE,
57+
instance_count=1,
58+
instance_type=instance_type,
59+
sagemaker_session=sagemaker_session,
60+
base_job_name="test-tensorflow",
61+
)
62+
63+
processor.run(
64+
code=entry_point,
65+
source_dir=code_path,
66+
inputs=[],
67+
wait=True,
68+
)
69+
70+
4271
def test_mnist_with_checkpoint_config(
4372
sagemaker_session,
4473
instance_type,
@@ -51,7 +80,7 @@ def test_mnist_with_checkpoint_config(
5180
checkpoint_local_path = "/test/checkpoint/path"
5281
estimator = TensorFlow(
5382
entry_point=SCRIPT,
54-
role="SageMakerRole",
83+
role=ROLE,
5584
instance_count=1,
5685
instance_type=instance_type,
5786
sagemaker_session=sagemaker_session,

tests/integ/test_xgboost.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright 2018-2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import os
16+
import pytest
17+
from sagemaker.xgboost.processing import XGBoostProcessor
18+
from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES
19+
from tests.integ.timeout import timeout
20+
21+
ROLE = "SageMakerRole"
22+
23+
24+
@pytest.mark.release
25+
def test_framework_processing_job_with_deps(
26+
sagemaker_session,
27+
xgboost_latest_version,
28+
xgboost_latest_py_version,
29+
cpu_instance_type,
30+
):
31+
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
32+
code_path = os.path.join(DATA_DIR, "dummy_code_bundle_with_reqs")
33+
entry_point = "main_script.py"
34+
35+
processor = XGBoostProcessor(
36+
framework_version=xgboost_latest_version,
37+
py_version=xgboost_latest_py_version,
38+
role=ROLE,
39+
instance_count=1,
40+
instance_type=cpu_instance_type,
41+
sagemaker_session=sagemaker_session,
42+
base_job_name="test-xgboost",
43+
)
44+
45+
processor.run(
46+
code=entry_point,
47+
source_dir=code_path,
48+
inputs=[],
49+
wait=True,
50+
)

0 commit comments

Comments
 (0)