Skip to content

Commit 429b8d2

Browse files
committed
fix: prevent race condition in vpc tests
1 parent 91e658c commit 429b8d2

File tree

4 files changed

+48
-34
lines changed

4 files changed

+48
-34
lines changed

tests/integ/local_mode_utils.py renamed to tests/integ/lock.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,24 +15,21 @@
1515
import fcntl
1616
import os
1717
import time
18+
import tempfile
1819
from contextlib import contextmanager
1920

20-
import tests.integ
21-
22-
LOCK_PATH = os.path.join(tests.integ.DATA_DIR, 'local_mode_lock')
21+
DEFAULT_LOCK_PATH = os.path.join(tempfile.gettempdir(), 'sagemaker_test_lock')
2322

2423

2524
@contextmanager
26-
def lock():
27-
# Since Local Mode uses the same port for serving, we need a lock in order
28-
# to allow concurrent test execution.
29-
local_mode_lock_fd = open(LOCK_PATH, 'w')
30-
local_mode_lock = local_mode_lock_fd.fileno()
25+
def lock(path=DEFAULT_LOCK_PATH):
26+
f = open(path, 'w')
27+
fd = f.fileno()
3128

32-
fcntl.lockf(local_mode_lock, fcntl.LOCK_EX)
29+
fcntl.lockf(fd, fcntl.LOCK_EX)
3330

3431
try:
3532
yield
3633
finally:
3734
time.sleep(5)
38-
fcntl.lockf(local_mode_lock, fcntl.LOCK_UN)
35+
fcntl.lockf(fd, fcntl.LOCK_UN)

tests/integ/test_local_mode.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,17 @@
1818
import boto3
1919
import numpy
2020
import pytest
21-
import tests.integ.local_mode_utils as local_mode_utils
21+
import tempfile
22+
23+
import tests.integ.lock as lock
2224
from tests.integ import DATA_DIR, PYTHON_VERSION
2325
from tests.integ.timeout import timeout
2426

2527
from sagemaker.local import LocalSession, LocalSagemakerRuntimeClient, LocalSagemakerClient
2628
from sagemaker.mxnet import MXNet
2729
from sagemaker.tensorflow import TensorFlow
2830

31+
LOCK_PATH = os.path.join(tempfile.gettempdir(), 'sagemaker_test_local_mode_lock')
2932
DATA_PATH = os.path.join(DATA_DIR, 'iris', 'data')
3033
DEFAULT_REGION = 'us-west-2'
3134

@@ -101,7 +104,7 @@ def test_tf_local_mode(tf_full_version, sagemaker_local_session):
101104
print('job succeeded: {}'.format(estimator.latest_training_job.name))
102105

103106
endpoint_name = estimator.latest_training_job.name
104-
with local_mode_utils.lock():
107+
with lock.lock(LOCK_PATH):
105108
try:
106109
json_predictor = estimator.deploy(initial_instance_count=1,
107110
instance_type='local',
@@ -140,7 +143,7 @@ def test_tf_distributed_local_mode(sagemaker_local_session):
140143

141144
endpoint_name = estimator.latest_training_job.name
142145

143-
with local_mode_utils.lock():
146+
with lock.lock(LOCK_PATH):
144147
try:
145148
json_predictor = estimator.deploy(initial_instance_count=1,
146149
instance_type='local',
@@ -178,7 +181,7 @@ def test_tf_local_data(sagemaker_local_session):
178181
print('job succeeded: {}'.format(estimator.latest_training_job.name))
179182

180183
endpoint_name = estimator.latest_training_job.name
181-
with local_mode_utils.lock():
184+
with lock.lock(LOCK_PATH):
182185
try:
183186
json_predictor = estimator.deploy(initial_instance_count=1,
184187
instance_type='local',
@@ -217,7 +220,7 @@ def test_tf_local_data_local_script():
217220
print('job succeeded: {}'.format(estimator.latest_training_job.name))
218221

219222
endpoint_name = estimator.latest_training_job.name
220-
with local_mode_utils.lock():
223+
with lock.lock(LOCK_PATH):
221224
try:
222225
json_predictor = estimator.deploy(initial_instance_count=1,
223226
instance_type='local',
@@ -241,7 +244,7 @@ def test_local_mode_serving_from_s3_model(sagemaker_local_session, mxnet_model,
241244
s3_model.sagemaker_session = sagemaker_local_session
242245

243246
predictor = None
244-
with local_mode_utils.lock():
247+
with lock.lock(LOCK_PATH):
245248
try:
246249
predictor = s3_model.deploy(initial_instance_count=1, instance_type='local')
247250
data = numpy.zeros(shape=(1, 1, 28, 28))
@@ -255,7 +258,7 @@ def test_local_mode_serving_from_s3_model(sagemaker_local_session, mxnet_model,
255258
def test_local_mode_serving_from_local_model(tmpdir, sagemaker_local_session, mxnet_model):
256259
predictor = None
257260

258-
with local_mode_utils.lock():
261+
with lock.lock(LOCK_PATH):
259262
try:
260263
path = 'file://%s' % (str(tmpdir))
261264
model = mxnet_model(path)
@@ -285,7 +288,7 @@ def test_mxnet_local_mode(sagemaker_local_session, mxnet_full_version):
285288
mx.fit({'train': train_input, 'test': test_input})
286289
endpoint_name = mx.latest_training_job.name
287290

288-
with local_mode_utils.lock():
291+
with lock.lock(LOCK_PATH):
289292
try:
290293
predictor = mx.deploy(1, 'local', endpoint_name=endpoint_name)
291294
data = numpy.zeros(shape=(1, 1, 28, 28))
@@ -310,7 +313,7 @@ def test_mxnet_local_data_local_script(mxnet_full_version):
310313
mx.fit({'train': train_input, 'test': test_input})
311314
endpoint_name = mx.latest_training_job.name
312315

313-
with local_mode_utils.lock():
316+
with lock.lock(LOCK_PATH):
314317
try:
315318
predictor = mx.deploy(1, 'local', endpoint_name=endpoint_name)
316319
data = numpy.zeros(shape=(1, 1, 28, 28))
@@ -365,7 +368,7 @@ def test_local_transform_mxnet(sagemaker_local_session, tmpdir, mxnet_full_versi
365368
transformer = mx.transformer(1, 'local', assemble_with='Line', max_payload=1,
366369
strategy='SingleRecord', output_path=output_path)
367370

368-
with local_mode_utils.lock():
371+
with lock.lock(LOCK_PATH):
369372
transformer.transform(transform_input, content_type='text/csv', split_type='Line')
370373
transformer.wait()
371374

tests/integ/test_source_dirs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import pytest
1818

19-
import tests.integ.local_mode_utils as local_mode_utils
19+
import tests.integ.lock as lock
2020
from tests.integ import DATA_DIR, PYTHON_VERSION
2121

2222
from sagemaker.pytorch.estimator import PyTorch
@@ -37,7 +37,7 @@ def test_source_dirs(tmpdir, sagemaker_local_session):
3737
sagemaker_session=sagemaker_local_session)
3838
estimator.fit()
3939

40-
with local_mode_utils.lock():
40+
with lock.lock():
4141
try:
4242
predictor = estimator.deploy(initial_instance_count=1, instance_type='local')
4343
predict_response = predictor.predict([7])

tests/integ/vpc_test_utils.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,13 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15+
import os
16+
import tempfile
17+
18+
import tests.integ.lock as lock
19+
1520
VPC_NAME = 'sagemaker-python-sdk-test-vpc'
21+
LOCK_PATH = os.path.join(tempfile.gettempdir(), 'sagemaker_test_vpc_lock')
1622

1723

1824
def _get_subnet_ids_by_name(ec2_client, name):
@@ -61,20 +67,24 @@ def _create_vpc_with_name(ec2_client, region, name):
6167
AvailabilityZone=(region + 'b'))['Subnet']['SubnetId']
6268
print('created subnet: {}'.format(subnet_id_b))
6369

64-
s3_service = [s for s in ec2_client.describe_vpc_endpoint_services()['ServiceNames'] if s.endswith('s3')][0]
70+
s3_service = \
71+
[s for s in ec2_client.describe_vpc_endpoint_services()['ServiceNames'] if
72+
s.endswith('s3')][0]
6573
ec2_client.create_vpc_endpoint(VpcId=vpc_id, ServiceName=s3_service,
6674
RouteTableIds=[_get_route_table_id(ec2_client, vpc_id)])
6775
print('created s3 vpc endpoint')
6876

69-
security_group_id = ec2_client.create_security_group(VpcId=vpc_id, GroupName=name, Description=name)['GroupId']
77+
security_group_id = \
78+
ec2_client.create_security_group(VpcId=vpc_id, GroupName=name, Description=name)['GroupId']
7079
print('created security group: {}'.format(security_group_id))
7180

7281
# multi-host vpc jobs require communication among hosts
7382
ec2_client.authorize_security_group_ingress(GroupId=security_group_id,
7483
IpPermissions=[{'IpProtocol': 'tcp',
7584
'FromPort': 0,
7685
'ToPort': 65535,
77-
'UserIdGroupPairs': [{'GroupId': security_group_id}]}])
86+
'UserIdGroupPairs': [{
87+
'GroupId': security_group_id}]}])
7888

7989
ec2_client.create_tags(Resources=[vpc_id, subnet_id_a, subnet_id_b, security_group_id],
8090
Tags=[{'Key': 'Name', 'Value': name}])
@@ -83,23 +93,27 @@ def _create_vpc_with_name(ec2_client, region, name):
8393

8494

8595
def get_or_create_vpc_resources(ec2_client, region, name=VPC_NAME):
86-
if _vpc_exists(ec2_client, name):
87-
print('using existing vpc: {}'.format(name))
88-
return _get_subnet_ids_by_name(ec2_client, name), _get_security_id_by_name(ec2_client, name)
89-
else:
90-
print('creating new vpc: {}'.format(name))
91-
return _create_vpc_with_name(ec2_client, region, name)
96+
with lock.lock(LOCK_PATH):
97+
if _vpc_exists(ec2_client, name):
98+
print('using existing vpc: {}'.format(name))
99+
return _get_subnet_ids_by_name(ec2_client, name), _get_security_id_by_name(ec2_client,
100+
name)
101+
else:
102+
print('creating new vpc: {}'.format(name))
103+
return _create_vpc_with_name(ec2_client, region, name)
92104

93105

94106
def setup_security_group_for_encryption(ec2_client, security_group_id):
95107
sg_desc = ec2_client.describe_security_groups(GroupIds=[security_group_id])
96108
ingress_perms = sg_desc['SecurityGroups'][0]['IpPermissions']
97109
if len(ingress_perms) == 1:
98-
ec2_client.\
110+
ec2_client. \
99111
authorize_security_group_ingress(GroupId=security_group_id,
100112
IpPermissions=[{'IpProtocol': '50',
101-
'UserIdGroupPairs': [{'GroupId': security_group_id}]},
113+
'UserIdGroupPairs': [
114+
{'GroupId': security_group_id}]},
102115
{'IpProtocol': 'udp',
103116
'FromPort': 500,
104117
'ToPort': 500,
105-
'UserIdGroupPairs': [{'GroupId': security_group_id}]}])
118+
'UserIdGroupPairs': [
119+
{'GroupId': security_group_id}]}])

0 commit comments

Comments
 (0)