Skip to content

Commit fe511b8

Browse files
authored
fix: prevent race condition in vpc tests (#863)
1 parent 91e658c commit fe511b8

File tree

5 files changed

+56
-35
lines changed

5 files changed

+56
-35
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,4 @@ venv/
2525
*~
2626
.pytest_cache/
2727
*.swp
28-
tests/data/local_mode_lock
28+
.docker/

tests/integ/local_mode_utils.py renamed to tests/integ/lock.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,24 +15,25 @@
1515
import fcntl
1616
import os
1717
import time
18+
import tempfile
1819
from contextlib import contextmanager
1920

20-
import tests.integ
21-
22-
LOCK_PATH = os.path.join(tests.integ.DATA_DIR, 'local_mode_lock')
21+
DEFAULT_LOCK_PATH = os.path.join(tempfile.gettempdir(), 'sagemaker_test_lock')
2322

2423

2524
@contextmanager
26-
def lock():
27-
# Since Local Mode uses the same port for serving, we need a lock in order
28-
# to allow concurrent test execution.
29-
local_mode_lock_fd = open(LOCK_PATH, 'w')
30-
local_mode_lock = local_mode_lock_fd.fileno()
25+
def lock(path=DEFAULT_LOCK_PATH):
26+
"""Create a file lock to control concurrent test execution. Certain tests or
27+
test operations need to limit concurrency to work reliably. Examples include
28+
local mode endpoint tests and vpc creation tests.
29+
"""
30+
f = open(path, 'w')
31+
fd = f.fileno()
3132

32-
fcntl.lockf(local_mode_lock, fcntl.LOCK_EX)
33+
fcntl.lockf(fd, fcntl.LOCK_EX)
3334

3435
try:
3536
yield
3637
finally:
3738
time.sleep(5)
38-
fcntl.lockf(local_mode_lock, fcntl.LOCK_UN)
39+
fcntl.lockf(fd, fcntl.LOCK_UN)

tests/integ/test_local_mode.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,18 @@
1818
import boto3
1919
import numpy
2020
import pytest
21-
import tests.integ.local_mode_utils as local_mode_utils
21+
import tempfile
22+
23+
import tests.integ.lock as lock
2224
from tests.integ import DATA_DIR, PYTHON_VERSION
2325
from tests.integ.timeout import timeout
2426

2527
from sagemaker.local import LocalSession, LocalSagemakerRuntimeClient, LocalSagemakerClient
2628
from sagemaker.mxnet import MXNet
2729
from sagemaker.tensorflow import TensorFlow
2830

31+
# endpoint tests all use the same port, so we use this lock to prevent concurrent execution
32+
LOCK_PATH = os.path.join(tempfile.gettempdir(), 'sagemaker_test_local_mode_lock')
2933
DATA_PATH = os.path.join(DATA_DIR, 'iris', 'data')
3034
DEFAULT_REGION = 'us-west-2'
3135

@@ -101,7 +105,7 @@ def test_tf_local_mode(tf_full_version, sagemaker_local_session):
101105
print('job succeeded: {}'.format(estimator.latest_training_job.name))
102106

103107
endpoint_name = estimator.latest_training_job.name
104-
with local_mode_utils.lock():
108+
with lock.lock(LOCK_PATH):
105109
try:
106110
json_predictor = estimator.deploy(initial_instance_count=1,
107111
instance_type='local',
@@ -140,7 +144,7 @@ def test_tf_distributed_local_mode(sagemaker_local_session):
140144

141145
endpoint_name = estimator.latest_training_job.name
142146

143-
with local_mode_utils.lock():
147+
with lock.lock(LOCK_PATH):
144148
try:
145149
json_predictor = estimator.deploy(initial_instance_count=1,
146150
instance_type='local',
@@ -178,7 +182,7 @@ def test_tf_local_data(sagemaker_local_session):
178182
print('job succeeded: {}'.format(estimator.latest_training_job.name))
179183

180184
endpoint_name = estimator.latest_training_job.name
181-
with local_mode_utils.lock():
185+
with lock.lock(LOCK_PATH):
182186
try:
183187
json_predictor = estimator.deploy(initial_instance_count=1,
184188
instance_type='local',
@@ -217,7 +221,7 @@ def test_tf_local_data_local_script():
217221
print('job succeeded: {}'.format(estimator.latest_training_job.name))
218222

219223
endpoint_name = estimator.latest_training_job.name
220-
with local_mode_utils.lock():
224+
with lock.lock(LOCK_PATH):
221225
try:
222226
json_predictor = estimator.deploy(initial_instance_count=1,
223227
instance_type='local',
@@ -241,7 +245,7 @@ def test_local_mode_serving_from_s3_model(sagemaker_local_session, mxnet_model,
241245
s3_model.sagemaker_session = sagemaker_local_session
242246

243247
predictor = None
244-
with local_mode_utils.lock():
248+
with lock.lock(LOCK_PATH):
245249
try:
246250
predictor = s3_model.deploy(initial_instance_count=1, instance_type='local')
247251
data = numpy.zeros(shape=(1, 1, 28, 28))
@@ -255,7 +259,7 @@ def test_local_mode_serving_from_s3_model(sagemaker_local_session, mxnet_model,
255259
def test_local_mode_serving_from_local_model(tmpdir, sagemaker_local_session, mxnet_model):
256260
predictor = None
257261

258-
with local_mode_utils.lock():
262+
with lock.lock(LOCK_PATH):
259263
try:
260264
path = 'file://%s' % (str(tmpdir))
261265
model = mxnet_model(path)
@@ -285,7 +289,7 @@ def test_mxnet_local_mode(sagemaker_local_session, mxnet_full_version):
285289
mx.fit({'train': train_input, 'test': test_input})
286290
endpoint_name = mx.latest_training_job.name
287291

288-
with local_mode_utils.lock():
292+
with lock.lock(LOCK_PATH):
289293
try:
290294
predictor = mx.deploy(1, 'local', endpoint_name=endpoint_name)
291295
data = numpy.zeros(shape=(1, 1, 28, 28))
@@ -310,7 +314,7 @@ def test_mxnet_local_data_local_script(mxnet_full_version):
310314
mx.fit({'train': train_input, 'test': test_input})
311315
endpoint_name = mx.latest_training_job.name
312316

313-
with local_mode_utils.lock():
317+
with lock.lock(LOCK_PATH):
314318
try:
315319
predictor = mx.deploy(1, 'local', endpoint_name=endpoint_name)
316320
data = numpy.zeros(shape=(1, 1, 28, 28))
@@ -365,7 +369,7 @@ def test_local_transform_mxnet(sagemaker_local_session, tmpdir, mxnet_full_versi
365369
transformer = mx.transformer(1, 'local', assemble_with='Line', max_payload=1,
366370
strategy='SingleRecord', output_path=output_path)
367371

368-
with local_mode_utils.lock():
372+
with lock.lock(LOCK_PATH):
369373
transformer.transform(transform_input, content_type='text/csv', split_type='Line')
370374
transformer.wait()
371375

tests/integ/test_source_dirs.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import pytest
1818

19-
import tests.integ.local_mode_utils as local_mode_utils
19+
import tests.integ.lock as lock
2020
from tests.integ import DATA_DIR, PYTHON_VERSION
2121

2222
from sagemaker.pytorch.estimator import PyTorch
@@ -37,7 +37,8 @@ def test_source_dirs(tmpdir, sagemaker_local_session):
3737
sagemaker_session=sagemaker_local_session)
3838
estimator.fit()
3939

40-
with local_mode_utils.lock():
40+
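# endpoint tests all use the same port, so we use this lock to prevent concurrent execution. NOTE(review): lock.lock() with no argument locks DEFAULT_LOCK_PATH ('sagemaker_test_lock'), not the 'sagemaker_test_local_mode_lock' file used by test_local_mode.py, so this does not exclude those endpoint tests -- confirm whether the same LOCK_PATH should be passed here.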
41+
with lock.lock():
4142
try:
4243
predictor = estimator.deploy(initial_instance_count=1, instance_type='local')
4344
predict_response = predictor.predict([7])

tests/integ/vpc_test_utils.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,13 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15+
import os
16+
import tempfile
17+
18+
import tests.integ.lock as lock
19+
1520
VPC_NAME = 'sagemaker-python-sdk-test-vpc'
21+
LOCK_PATH = os.path.join(tempfile.gettempdir(), 'sagemaker_test_vpc_lock')
1622

1723

1824
def _get_subnet_ids_by_name(ec2_client, name):
@@ -61,20 +67,24 @@ def _create_vpc_with_name(ec2_client, region, name):
6167
AvailabilityZone=(region + 'b'))['Subnet']['SubnetId']
6268
print('created subnet: {}'.format(subnet_id_b))
6369

64-
s3_service = [s for s in ec2_client.describe_vpc_endpoint_services()['ServiceNames'] if s.endswith('s3')][0]
70+
s3_service = \
71+
[s for s in ec2_client.describe_vpc_endpoint_services()['ServiceNames'] if
72+
s.endswith('s3')][0]
6573
ec2_client.create_vpc_endpoint(VpcId=vpc_id, ServiceName=s3_service,
6674
RouteTableIds=[_get_route_table_id(ec2_client, vpc_id)])
6775
print('created s3 vpc endpoint')
6876

69-
security_group_id = ec2_client.create_security_group(VpcId=vpc_id, GroupName=name, Description=name)['GroupId']
77+
security_group_id = \
78+
ec2_client.create_security_group(VpcId=vpc_id, GroupName=name, Description=name)['GroupId']
7079
print('created security group: {}'.format(security_group_id))
7180

7281
# multi-host vpc jobs require communication among hosts
7382
ec2_client.authorize_security_group_ingress(GroupId=security_group_id,
7483
IpPermissions=[{'IpProtocol': 'tcp',
7584
'FromPort': 0,
7685
'ToPort': 65535,
77-
'UserIdGroupPairs': [{'GroupId': security_group_id}]}])
86+
'UserIdGroupPairs': [{
87+
'GroupId': security_group_id}]}])
7888

7989
ec2_client.create_tags(Resources=[vpc_id, subnet_id_a, subnet_id_b, security_group_id],
8090
Tags=[{'Key': 'Name', 'Value': name}])
@@ -83,23 +93,28 @@ def _create_vpc_with_name(ec2_client, region, name):
8393

8494

8595
def get_or_create_vpc_resources(ec2_client, region, name=VPC_NAME):
86-
if _vpc_exists(ec2_client, name):
87-
print('using existing vpc: {}'.format(name))
88-
return _get_subnet_ids_by_name(ec2_client, name), _get_security_id_by_name(ec2_client, name)
89-
else:
90-
print('creating new vpc: {}'.format(name))
91-
return _create_vpc_with_name(ec2_client, region, name)
96+
# use lock to prevent race condition when tests are running concurrently
97+
with lock.lock(LOCK_PATH):
98+
if _vpc_exists(ec2_client, name):
99+
print('using existing vpc: {}'.format(name))
100+
return _get_subnet_ids_by_name(ec2_client, name), _get_security_id_by_name(ec2_client,
101+
name)
102+
else:
103+
print('creating new vpc: {}'.format(name))
104+
return _create_vpc_with_name(ec2_client, region, name)
92105

93106

94107
def setup_security_group_for_encryption(ec2_client, security_group_id):
95108
sg_desc = ec2_client.describe_security_groups(GroupIds=[security_group_id])
96109
ingress_perms = sg_desc['SecurityGroups'][0]['IpPermissions']
97110
if len(ingress_perms) == 1:
98-
ec2_client.\
111+
ec2_client. \
99112
authorize_security_group_ingress(GroupId=security_group_id,
100113
IpPermissions=[{'IpProtocol': '50',
101-
'UserIdGroupPairs': [{'GroupId': security_group_id}]},
114+
'UserIdGroupPairs': [
115+
{'GroupId': security_group_id}]},
102116
{'IpProtocol': 'udp',
103117
'FromPort': 500,
104118
'ToPort': 500,
105-
'UserIdGroupPairs': [{'GroupId': security_group_id}]}])
119+
'UserIdGroupPairs': [
120+
{'GroupId': security_group_id}]}])

0 commit comments

Comments
 (0)