Skip to content

Commit 2af9d45

Browse files
authored
Create configurable sagemaker_session fixture for all integ tests (#104)
* Create configurable sagemaker_session fixture for all integ tests * Update changelog
1 parent 6e0047b commit 2af9d45

13 files changed

+93
-103
lines changed

CHANGELOG.rst

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,18 @@
22
CHANGELOG
33
=========
44

5+
1.1.dev3
6+
========
7+
8+
* feature: Tests: create configurable ``sagemaker_session`` pytest fixture for all integration tests
9+
510
1.1.2
6-
=======
11+
=====
712

813
* bug-fix: AmazonEstimators: do not call create bucket if data location is provided
914

1015
1.1.1
11-
========
16+
=====
1217

1318
* feature: Estimators: add ``requirements.txt`` support for TensorFlow
1419

tests/conftest.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,50 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13+
import json
14+
15+
import boto3
1316
import pytest
1417

18+
from sagemaker import Session
19+
20+
DEFAULT_REGION = 'us-west-2'
21+
22+
23+
def pytest_addoption(parser):
24+
parser.addoption('--sagemaker-client-config', action='store', default=None)
25+
parser.addoption('--sagemaker-runtime-config', action='store', default=None)
26+
parser.addoption('--boto-config', action='store', default=None)
27+
28+
29+
@pytest.fixture(scope='session')
30+
def sagemaker_client_config(request):
31+
config = request.config.getoption('--sagemaker-client-config')
32+
return json.loads(config) if config else None
33+
34+
35+
@pytest.fixture(scope='session')
36+
def sagemaker_runtime_config(request):
37+
config = request.config.getoption('--sagemaker-runtime-config')
38+
return json.loads(config) if config else None
39+
40+
41+
@pytest.fixture(scope='session')
42+
def boto_config(request):
43+
config = request.config.getoption('--boto-config')
44+
return json.loads(config) if config else None
45+
46+
47+
@pytest.fixture(scope='session')
48+
def sagemaker_session(sagemaker_client_config, sagemaker_runtime_config, boto_config):
49+
sagemaker_client = boto3.client('sagemaker', **sagemaker_client_config) if sagemaker_client_config else None
50+
runtime_client = boto3.client('sagemaker-runtime', **sagemaker_runtime_config) if sagemaker_runtime_config else None
51+
boto_session = boto3.Session(**boto_config) if boto_config else boto3.Session(region_name=DEFAULT_REGION)
52+
53+
return Session(boto_session=boto_session,
54+
sagemaker_client=sagemaker_client,
55+
sagemaker_runtime_client=runtime_client)
56+
1557

1658
@pytest.fixture(scope='module', params=["1.4", "1.4.1", "1.5", "1.5.0"])
1759
def tf_version(request):

tests/integ/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
# language governing permissions and limitations under the License.
1313
import logging
1414
import os
15+
1516
DATA_DIR = os.path.join(os.path.dirname(__file__), '..', 'data')
16-
REGION = 'us-west-2'
1717

1818
logging.getLogger('boto3').setLevel(logging.INFO)
1919
logging.getLogger('botocore').setLevel(logging.INFO)

tests/integ/test_byo_estimator.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,30 +13,36 @@
1313
import gzip
1414
import io
1515
import json
16-
import numpy as np
1716
import os
1817
import pickle
1918
import sys
2019

2120
import boto3
21+
import numpy as np
22+
import pytest
2223

2324
import sagemaker
24-
from sagemaker.estimator import Estimator
2525
from sagemaker.amazon.amazon_estimator import registry
2626
from sagemaker.amazon.common import write_numpy_to_dense_tensor
27+
from sagemaker.estimator import Estimator
2728
from sagemaker.utils import name_from_base
28-
from tests.integ import DATA_DIR, REGION
29+
from tests.integ import DATA_DIR
2930
from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name
3031

3132

33+
@pytest.fixture(scope='module')
34+
def region(sagemaker_session):
35+
return sagemaker_session.boto_session.region_name
36+
37+
3238
def fm_serializer(data):
3339
js = {'instances': []}
3440
for row in data:
3541
js['instances'].append({'features': row.tolist()})
3642
return json.dumps(js)
3743

3844

39-
def test_byo_estimator():
45+
def test_byo_estimator(sagemaker_session, region):
4046
"""Use Factorization Machines algorithm as an example here.
4147
4248
First we need to prepare data for training. We take standard data set, convert it to the
@@ -47,10 +53,9 @@ def test_byo_estimator():
4753
Default predictor is updated with json serializer and deserializer.
4854
4955
"""
50-
image_name = registry(REGION) + "/factorization-machines:1"
56+
image_name = registry(region) + "/factorization-machines:1"
5157

5258
with timeout(minutes=15):
53-
sagemaker_session = sagemaker.Session(boto_session=boto3.Session(region_name=REGION))
5459
data_path = os.path.join(DATA_DIR, 'one_p_mnist', 'mnist.pkl.gz')
5560
pickle_args = {} if sys.version_info.major == 2 else {'encoding': 'latin1'}
5661

@@ -100,13 +105,12 @@ def test_byo_estimator():
100105
assert prediction['score'] is not None
101106

102107

103-
def test_async_byo_estimator():
104-
image_name = registry(REGION) + "/factorization-machines:1"
108+
def test_async_byo_estimator(sagemaker_session, region):
109+
image_name = registry(region) + "/factorization-machines:1"
105110
endpoint_name = name_from_base('byo')
106111
training_job_name = ""
107112

108113
with timeout(minutes=5):
109-
sagemaker_session = sagemaker.Session(boto_session=boto3.Session(region_name=REGION))
110114
data_path = os.path.join(DATA_DIR, 'one_p_mnist', 'mnist.pkl.gz')
111115
pickle_args = {} if sys.version_info.major == 2 else {'encoding': 'latin1'}
112116

tests/integ/test_factorization_machines.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,24 +11,19 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313
import gzip
14+
import os
1415
import pickle
1516
import sys
1617
import time
1718

18-
import boto3
19-
import os
20-
21-
import sagemaker
2219
from sagemaker import FactorizationMachines, FactorizationMachinesModel
2320
from sagemaker.utils import name_from_base
24-
from tests.integ import DATA_DIR, REGION
21+
from tests.integ import DATA_DIR
2522
from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name
2623

2724

28-
def test_factorization_machines():
29-
25+
def test_factorization_machines(sagemaker_session):
3026
with timeout(minutes=15):
31-
sagemaker_session = sagemaker.Session(boto_session=boto3.Session(region_name=REGION))
3227
data_path = os.path.join(DATA_DIR, 'one_p_mnist', 'mnist.pkl.gz')
3328
pickle_args = {} if sys.version_info.major == 2 else {'encoding': 'latin1'}
3429

@@ -56,14 +51,11 @@ def test_factorization_machines():
5651
assert record.label["score"] is not None
5752

5853

59-
def test_async_factorization_machines():
60-
54+
def test_async_factorization_machines(sagemaker_session):
6155
training_job_name = ""
6256
endpoint_name = name_from_base('factorizationMachines')
63-
sagemaker_session = sagemaker.Session(boto_session=boto3.Session(region_name=REGION))
6457

6558
with timeout(minutes=5):
66-
6759
data_path = os.path.join(DATA_DIR, 'one_p_mnist', 'mnist.pkl.gz')
6860
pickle_args = {} if sys.version_info.major == 2 else {'encoding': 'latin1'}
6961

tests/integ/test_kmeans.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,24 +11,19 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313
import gzip
14+
import os
1415
import pickle
1516
import sys
16-
17-
import boto3
18-
import os
1917
import time
2018

21-
import sagemaker
2219
from sagemaker import KMeans, KMeansModel
2320
from sagemaker.utils import name_from_base
24-
from tests.integ import DATA_DIR, REGION
21+
from tests.integ import DATA_DIR
2522
from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name
2623

2724

28-
def test_kmeans():
29-
25+
def test_kmeans(sagemaker_session):
3026
with timeout(minutes=15):
31-
sagemaker_session = sagemaker.Session(boto_session=boto3.Session(region_name=REGION))
3227
data_path = os.path.join(DATA_DIR, 'one_p_mnist', 'mnist.pkl.gz')
3328
pickle_args = {} if sys.version_info.major == 2 else {'encoding': 'latin1'}
3429

@@ -63,13 +58,11 @@ def test_kmeans():
6358
assert record.label["distance_to_cluster"] is not None
6459

6560

66-
def test_async_kmeans():
67-
61+
def test_async_kmeans(sagemaker_session):
6862
training_job_name = ""
6963
endpoint_name = name_from_base('kmeans')
7064

7165
with timeout(minutes=5):
72-
sagemaker_session = sagemaker.Session(boto_session=boto3.Session(region_name=REGION))
7366
data_path = os.path.join(DATA_DIR, 'one_p_mnist', 'mnist.pkl.gz')
7467
pickle_args = {} if sys.version_info.major == 2 else {'encoding': 'latin1'}
7568

tests/integ/test_lda.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,24 +10,20 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13-
import boto3
14-
import numpy as np
1513
import os
1614

17-
import sagemaker
15+
import numpy as np
16+
1817
from sagemaker import LDA, LDAModel
1918
from sagemaker.amazon.common import read_records
2019
from sagemaker.utils import name_from_base
21-
22-
from tests.integ import DATA_DIR, REGION
20+
from tests.integ import DATA_DIR
2321
from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name
2422
from tests.integ.record_set import prepare_record_set_from_local_files
2523

2624

27-
def test_lda():
28-
25+
def test_lda(sagemaker_session):
2926
with timeout(minutes=15):
30-
sagemaker_session = sagemaker.Session(boto_session=boto3.Session(region_name=REGION))
3127
data_path = os.path.join(DATA_DIR, 'lda')
3228
data_filename = 'nips-train_1.pbr'
3329

tests/integ/test_linear_learner.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,21 +15,17 @@
1515
import pickle
1616
import sys
1717
import time
18-
import pytest # noqa
19-
import boto3
18+
2019
import numpy as np
2120

22-
import sagemaker
2321
from sagemaker.amazon.linear_learner import LinearLearner, LinearLearnerModel
2422
from sagemaker.utils import name_from_base, sagemaker_timestamp
25-
26-
from tests.integ import DATA_DIR, REGION
23+
from tests.integ import DATA_DIR
2724
from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name
2825

2926

30-
def test_linear_learner():
27+
def test_linear_learner(sagemaker_session):
3128
with timeout(minutes=15):
32-
sagemaker_session = sagemaker.Session(boto_session=boto3.Session(region_name=REGION))
3329
data_path = os.path.join(DATA_DIR, 'one_p_mnist', 'mnist.pkl.gz')
3430
pickle_args = {} if sys.version_info.major == 2 else {'encoding': 'latin1'}
3531

@@ -87,14 +83,11 @@ def test_linear_learner():
8783
assert record.label["score"] is not None
8884

8985

90-
def test_async_linear_learner():
91-
86+
def test_async_linear_learner(sagemaker_session):
9287
training_job_name = ""
9388
endpoint_name = 'test-linear-learner-async-{}'.format(sagemaker_timestamp())
94-
sagemaker_session = sagemaker.Session(boto_session=boto3.Session(region_name=REGION))
9589

9690
with timeout(minutes=5):
97-
9891
data_path = os.path.join(DATA_DIR, 'one_p_mnist', 'mnist.pkl.gz')
9992
pickle_args = {} if sys.version_info.major == 2 else {'encoding': 'latin1'}
10093

tests/integ/test_mxnet_train.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,23 +13,16 @@
1313
import os
1414
import time
1515

16-
import boto3
1716
import numpy
1817
import pytest
19-
from sagemaker import Session
18+
2019
from sagemaker.mxnet.estimator import MXNet
2120
from sagemaker.mxnet.model import MXNetModel
2221
from sagemaker.utils import sagemaker_timestamp
23-
24-
from tests.integ import DATA_DIR, REGION
22+
from tests.integ import DATA_DIR
2523
from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name
2624

2725

28-
@pytest.fixture(scope='module')
29-
def sagemaker_session():
30-
return Session(boto_session=boto3.Session(region_name=REGION))
31-
32-
3326
@pytest.fixture(scope='module')
3427
def mxnet_training_job(sagemaker_session, mxnet_full_version):
3528
with timeout(minutes=15):

tests/integ/test_ntm.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,24 +10,20 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13-
import boto3
14-
import numpy as np
1513
import os
1614

17-
import sagemaker
15+
import numpy as np
16+
1817
from sagemaker import NTM, NTMModel
1918
from sagemaker.amazon.common import read_records
2019
from sagemaker.utils import name_from_base
21-
22-
from tests.integ import DATA_DIR, REGION
20+
from tests.integ import DATA_DIR
2321
from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name
2422
from tests.integ.record_set import prepare_record_set_from_local_files
2523

2624

27-
def test_ntm():
28-
25+
def test_ntm(sagemaker_session):
2926
with timeout(minutes=15):
30-
sagemaker_session = sagemaker.Session(boto_session=boto3.Session(region_name=REGION))
3127
data_path = os.path.join(DATA_DIR, 'ntm')
3228
data_filename = 'nips-train_1.pbr'
3329

0 commit comments

Comments
 (0)