Skip to content

Commit 7534416

Browse files
authored
Merge branch 'master' into framework-versioning
2 parents 5c12d5a + 56352f3 commit 7534416

File tree

9 files changed

+47
-3
lines changed

9 files changed

+47
-3
lines changed

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ def read_version():
6666
"pytest<6.1.0",
6767
"pytest-cov",
6868
"pytest-rerunfailures",
69+
"pytest-timeout",
6970
"pytest-xdist",
7071
"mock",
7172
"contextlib2",

src/sagemaker/image_uris.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,8 @@ def retrieve(
9898
"mxnet-1.8.0-gpu-py37": "cu110-ubuntu16.04",
9999
"pytorch-1.6-gpu-py36": "cu110-ubuntu18.04-v3",
100100
"pytorch-1.6.0-gpu-py36": "cu110-ubuntu18.04",
101+
"pytorch-1.6-gpu-py3": "cu110-ubuntu18.04-v3",
102+
"pytorch-1.6.0-gpu-py3": "cu110-ubuntu18.04",
101103
}
102104
key = "-".join([framework, tag])
103105
if key in container_versions:

tests/integ/sagemaker/lineage/conftest.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,8 @@ def artifact_obj_with_association(sagemaker_session, artifact_obj):
155155
@pytest.fixture
156156
def trial_component_obj(sagemaker_session):
157157
trial_component_obj = trial_component.TrialComponent.create(
158-
trial_component_name=name(), sagemaker_boto_client=sagemaker_session.sagemaker_client
158+
trial_component_name=name(),
159+
sagemaker_boto_client=sagemaker_session.sagemaker_client,
159160
)
160161
yield trial_component_obj
161162
time.sleep(0.5)

tests/integ/sagemaker/lineage/test_action.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515

1616
import datetime
1717
import logging
18+
import time
19+
20+
import pytest
1821

1922
from sagemaker.lineage import action
2023

@@ -80,6 +83,7 @@ def test_list(action_objs, sagemaker_session):
8083
assert action_names
8184

8285

86+
@pytest.mark.timeout(30)
8387
def test_tag(action_obj, sagemaker_session):
8488
tag = {"Key": "foo", "Value": "bar"}
8589
action_obj.set_tag(tag)
@@ -90,12 +94,14 @@ def test_tag(action_obj, sagemaker_session):
9094
)["Tags"]
9195
if actual_tags:
9296
break
97+
time.sleep(5)
9398
# When sagemaker-client-config endpoint-url is passed as argument to hit some endpoints,
9499
# length of actual tags will be greater than 1
95100
assert len(actual_tags) > 0
96101
assert actual_tags[0] == tag
97102

98103

104+
@pytest.mark.timeout(30)
99105
def test_tags(action_obj, sagemaker_session):
100106
tags = [{"Key": "foo1", "Value": "bar1"}]
101107
action_obj.set_tags(tags)
@@ -106,6 +112,7 @@ def test_tags(action_obj, sagemaker_session):
106112
)["Tags"]
107113
if actual_tags:
108114
break
115+
time.sleep(5)
109116
# When sagemaker-client-config endpoint-url is passed as argument to hit some endpoints,
110117
# length of actual tags will be greater than 1
111118
assert len(actual_tags) > 0

tests/integ/sagemaker/lineage/test_artifact.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
import logging
1818
import time
1919

20+
import pytest
21+
2022
from sagemaker.lineage import artifact
2123

2224

@@ -111,6 +113,7 @@ def test_downstream_trials(trial_associated_artifact, trial_obj, sagemaker_sessi
111113
assert trial_obj.trial_name in trials
112114

113115

116+
@pytest.mark.timeout(30)
114117
def test_tag(artifact_obj, sagemaker_session):
115118
tag = {"Key": "foo", "Value": "bar"}
116119
artifact_obj.set_tag(tag)
@@ -121,12 +124,14 @@ def test_tag(artifact_obj, sagemaker_session):
121124
)["Tags"]
122125
if actual_tags:
123126
break
127+
time.sleep(5)
124128
# When sagemaker-client-config endpoint-url is passed as argument to hit some endpoints,
125129
# length of actual tags will be greater than 1
126130
assert len(actual_tags) > 0
127131
assert actual_tags[0] == tag
128132

129133

134+
@pytest.mark.timeout(30)
130135
def test_tags(artifact_obj, sagemaker_session):
131136
tags = [{"Key": "foo1", "Value": "bar1"}]
132137
artifact_obj.set_tags(tags)
@@ -137,6 +142,7 @@ def test_tags(artifact_obj, sagemaker_session):
137142
)["Tags"]
138143
if actual_tags:
139144
break
145+
time.sleep(5)
140146
# When sagemaker-client-config endpoint-url is passed as argument to hit some endpoints,
141147
# length of actual tags will be greater than 1
142148
assert len(actual_tags) > 0

tests/integ/sagemaker/lineage/test_association.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
import datetime
1717
import time
1818

19+
import pytest
20+
1921
from sagemaker.lineage import association
2022

2123

@@ -55,6 +57,7 @@ def test_list(association_objs, sagemaker_session):
5557
assert association_keys_listed
5658

5759

60+
@pytest.mark.timeout(30)
5861
def test_set_tag(association_obj, sagemaker_session):
5962
tag = {"Key": "foo", "Value": "bar"}
6063
association_obj.set_tag(tag)
@@ -65,13 +68,14 @@ def test_set_tag(association_obj, sagemaker_session):
6568
)["Tags"]
6669
if actual_tags:
6770
break
68-
time.sleep(1)
71+
time.sleep(5)
6972
# When sagemaker-client-config endpoint-url is passed as argument to hit some endpoints,
7073
# length of actual tags will be greater than 1
7174
assert len(actual_tags) > 0
7275
assert actual_tags[0] == tag
7376

7477

78+
@pytest.mark.timeout(30)
7579
def test_tags(association_obj, sagemaker_session):
7680
tags = [{"Key": "foo1", "Value": "bar1"}]
7781
association_obj.set_tags(tags)
@@ -82,7 +86,7 @@ def test_tags(association_obj, sagemaker_session):
8286
)["Tags"]
8387
if actual_tags:
8488
break
85-
time.sleep(1)
89+
time.sleep(5)
8690
# When sagemaker-client-config endpoint-url is passed as argument to hit some endpoints,
8791
# length of actual tags will be greater than 1
8892
assert len(actual_tags) > 0

tests/integ/sagemaker/lineage/test_context.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515

1616
import datetime
1717
import logging
18+
import time
19+
20+
import pytest
1821

1922
from sagemaker.lineage import context
2023

@@ -78,6 +81,7 @@ def test_list(context_objs, sagemaker_session):
7881
assert context_names
7982

8083

84+
@pytest.mark.timeout(30)
8185
def test_tag(context_obj, sagemaker_session):
8286
tag = {"Key": "foo", "Value": "bar"}
8387
context_obj.set_tag(tag)
@@ -88,12 +92,14 @@ def test_tag(context_obj, sagemaker_session):
8892
)["Tags"]
8993
if actual_tags:
9094
break
95+
time.sleep(5)
9196
# When sagemaker-client-config endpoint-url is passed as argument to hit some endpoints,
9297
# length of actual tags will be greater than 1
9398
assert len(actual_tags) > 0
9499
assert actual_tags[0] == tag
95100

96101

102+
@pytest.mark.timeout(30)
97103
def test_tags(context_obj, sagemaker_session):
98104
tags = [{"Key": "foo1", "Value": "bar1"}]
99105
context_obj.set_tags(tags)
@@ -104,6 +110,7 @@ def test_tags(context_obj, sagemaker_session):
104110
)["Tags"]
105111
if actual_tags:
106112
break
113+
time.sleep(5)
107114
# When sagemaker-client-config endpoint-url is passed as argument to hit some endpoints,
108115
# length of actual tags will be greater than 1
109116
assert len(actual_tags) > 0

tests/unit/sagemaker/image_uris/test_retrieve.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -553,6 +553,21 @@ def test_retrieve_auto_selected_container_version():
553553
)
554554

555555

556+
def test_retrieve_pytorch_container_version():
557+
uri = image_uris.retrieve(
558+
framework="pytorch",
559+
region="us-west-2",
560+
version="1.6",
561+
py_version="py3",
562+
instance_type="ml.p4d.24xlarge",
563+
image_scope="training",
564+
)
565+
assert (
566+
"763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.6-gpu-py3-cu110-ubuntu18.04-v3"
567+
== uri
568+
)
569+
570+
556571
@patch("sagemaker.image_uris.config_for_framework", return_value=BASE_CONFIG)
557572
def test_retrieve_unsupported_processor_type(config_for_framework):
558573
with pytest.raises(ValueError) as e:

tox.ini

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ markers =
5555
canary_quick
5656
cron
5757
local_mode
58+
timeout: mark a test as a timeout.
5859

5960
[testenv]
6061
passenv =

0 commit comments

Comments
 (0)