Skip to content

fix: prevent race condition in vpc tests #863

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 19, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,4 @@ venv/
*~
.pytest_cache/
*.swp
tests/data/local_mode_lock
.docker/
21 changes: 11 additions & 10 deletions tests/integ/local_mode_utils.py → tests/integ/lock.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,25 @@
import fcntl
import os
import time
import tempfile
from contextlib import contextmanager

import tests.integ

DEFAULT_LOCK_PATH = os.path.join(tempfile.gettempdir(), 'sagemaker_test_lock')


@contextmanager
def lock(path=DEFAULT_LOCK_PATH):
    """Create a file lock to control concurrent test execution.

    Certain tests or test operations need to limit concurrency to work
    reliably. Examples include local mode endpoint tests (which all serve
    on the same port) and vpc creation tests.

    Args:
        path (str): path of the lock file. Tests that must not block each
            other (e.g. local mode vs. vpc tests) pass distinct paths.

    Yields:
        None. The body of the ``with`` statement runs while an exclusive
        lock on ``path`` is held.
    """
    f = open(path, 'w')
    try:
        fcntl.lockf(f.fileno(), fcntl.LOCK_EX)
        try:
            yield
        finally:
            # Brief cooldown before releasing so whatever the lock guarded
            # (e.g. a local mode container's port) has time to free up.
            time.sleep(5)
            fcntl.lockf(f.fileno(), fcntl.LOCK_UN)
    finally:
        # Fix: the original leaked the open file handle. Closing it here
        # also guarantees the OS lock is dropped even if the explicit
        # unlock call fails.
        f.close()
24 changes: 14 additions & 10 deletions tests/integ/test_local_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,18 @@
import boto3
import numpy
import pytest
import tests.integ.local_mode_utils as local_mode_utils
import tempfile

import tests.integ.lock as lock
from tests.integ import DATA_DIR, PYTHON_VERSION
from tests.integ.timeout import timeout

from sagemaker.local import LocalSession, LocalSagemakerRuntimeClient, LocalSagemakerClient
from sagemaker.mxnet import MXNet
from sagemaker.tensorflow import TensorFlow

# endpoint tests all use the same port, so we use this lock to prevent concurrent execution
LOCK_PATH = os.path.join(tempfile.gettempdir(), 'sagemaker_test_local_mode_lock')
DATA_PATH = os.path.join(DATA_DIR, 'iris', 'data')
DEFAULT_REGION = 'us-west-2'

Expand Down Expand Up @@ -101,7 +105,7 @@ def test_tf_local_mode(tf_full_version, sagemaker_local_session):
print('job succeeded: {}'.format(estimator.latest_training_job.name))

endpoint_name = estimator.latest_training_job.name
with local_mode_utils.lock():
with lock.lock(LOCK_PATH):
try:
json_predictor = estimator.deploy(initial_instance_count=1,
instance_type='local',
Expand Down Expand Up @@ -140,7 +144,7 @@ def test_tf_distributed_local_mode(sagemaker_local_session):

endpoint_name = estimator.latest_training_job.name

with local_mode_utils.lock():
with lock.lock(LOCK_PATH):
try:
json_predictor = estimator.deploy(initial_instance_count=1,
instance_type='local',
Expand Down Expand Up @@ -178,7 +182,7 @@ def test_tf_local_data(sagemaker_local_session):
print('job succeeded: {}'.format(estimator.latest_training_job.name))

endpoint_name = estimator.latest_training_job.name
with local_mode_utils.lock():
with lock.lock(LOCK_PATH):
try:
json_predictor = estimator.deploy(initial_instance_count=1,
instance_type='local',
Expand Down Expand Up @@ -217,7 +221,7 @@ def test_tf_local_data_local_script():
print('job succeeded: {}'.format(estimator.latest_training_job.name))

endpoint_name = estimator.latest_training_job.name
with local_mode_utils.lock():
with lock.lock(LOCK_PATH):
try:
json_predictor = estimator.deploy(initial_instance_count=1,
instance_type='local',
Expand All @@ -241,7 +245,7 @@ def test_local_mode_serving_from_s3_model(sagemaker_local_session, mxnet_model,
s3_model.sagemaker_session = sagemaker_local_session

predictor = None
with local_mode_utils.lock():
with lock.lock(LOCK_PATH):
try:
predictor = s3_model.deploy(initial_instance_count=1, instance_type='local')
data = numpy.zeros(shape=(1, 1, 28, 28))
Expand All @@ -255,7 +259,7 @@ def test_local_mode_serving_from_s3_model(sagemaker_local_session, mxnet_model,
def test_local_mode_serving_from_local_model(tmpdir, sagemaker_local_session, mxnet_model):
predictor = None

with local_mode_utils.lock():
with lock.lock(LOCK_PATH):
try:
path = 'file://%s' % (str(tmpdir))
model = mxnet_model(path)
Expand Down Expand Up @@ -285,7 +289,7 @@ def test_mxnet_local_mode(sagemaker_local_session, mxnet_full_version):
mx.fit({'train': train_input, 'test': test_input})
endpoint_name = mx.latest_training_job.name

with local_mode_utils.lock():
with lock.lock(LOCK_PATH):
try:
predictor = mx.deploy(1, 'local', endpoint_name=endpoint_name)
data = numpy.zeros(shape=(1, 1, 28, 28))
Expand All @@ -310,7 +314,7 @@ def test_mxnet_local_data_local_script(mxnet_full_version):
mx.fit({'train': train_input, 'test': test_input})
endpoint_name = mx.latest_training_job.name

with local_mode_utils.lock():
with lock.lock(LOCK_PATH):
try:
predictor = mx.deploy(1, 'local', endpoint_name=endpoint_name)
data = numpy.zeros(shape=(1, 1, 28, 28))
Expand Down Expand Up @@ -365,7 +369,7 @@ def test_local_transform_mxnet(sagemaker_local_session, tmpdir, mxnet_full_versi
transformer = mx.transformer(1, 'local', assemble_with='Line', max_payload=1,
strategy='SingleRecord', output_path=output_path)

with local_mode_utils.lock():
with lock.lock(LOCK_PATH):
transformer.transform(transform_input, content_type='text/csv', split_type='Line')
transformer.wait()

Expand Down
5 changes: 3 additions & 2 deletions tests/integ/test_source_dirs.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import pytest

import tests.integ.local_mode_utils as local_mode_utils
import tests.integ.lock as lock
from tests.integ import DATA_DIR, PYTHON_VERSION

from sagemaker.pytorch.estimator import PyTorch
Expand All @@ -37,7 +37,8 @@ def test_source_dirs(tmpdir, sagemaker_local_session):
sagemaker_session=sagemaker_local_session)
estimator.fit()

with local_mode_utils.lock():
# endpoint tests all use the same port, so we use this lock to prevent concurrent execution
with lock.lock():
try:
predictor = estimator.deploy(initial_instance_count=1, instance_type='local')
predict_response = predictor.predict([7])
Expand Down
39 changes: 27 additions & 12 deletions tests/integ/vpc_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,13 @@
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import os
import tempfile

import tests.integ.lock as lock

VPC_NAME = 'sagemaker-python-sdk-test-vpc'
LOCK_PATH = os.path.join(tempfile.gettempdir(), 'sagemaker_test_vpc_lock')


def _get_subnet_ids_by_name(ec2_client, name):
Expand Down Expand Up @@ -61,20 +67,24 @@ def _create_vpc_with_name(ec2_client, region, name):
AvailabilityZone=(region + 'b'))['Subnet']['SubnetId']
print('created subnet: {}'.format(subnet_id_b))

s3_service = [s for s in ec2_client.describe_vpc_endpoint_services()['ServiceNames'] if s.endswith('s3')][0]
s3_service = \
[s for s in ec2_client.describe_vpc_endpoint_services()['ServiceNames'] if
s.endswith('s3')][0]
ec2_client.create_vpc_endpoint(VpcId=vpc_id, ServiceName=s3_service,
RouteTableIds=[_get_route_table_id(ec2_client, vpc_id)])
print('created s3 vpc endpoint')

security_group_id = ec2_client.create_security_group(VpcId=vpc_id, GroupName=name, Description=name)['GroupId']
security_group_id = \
ec2_client.create_security_group(VpcId=vpc_id, GroupName=name, Description=name)['GroupId']
print('created security group: {}'.format(security_group_id))

# multi-host vpc jobs require communication among hosts
ec2_client.authorize_security_group_ingress(GroupId=security_group_id,
IpPermissions=[{'IpProtocol': 'tcp',
'FromPort': 0,
'ToPort': 65535,
'UserIdGroupPairs': [{'GroupId': security_group_id}]}])
'UserIdGroupPairs': [{
'GroupId': security_group_id}]}])

ec2_client.create_tags(Resources=[vpc_id, subnet_id_a, subnet_id_b, security_group_id],
Tags=[{'Key': 'Name', 'Value': name}])
Expand All @@ -83,23 +93,28 @@ def _create_vpc_with_name(ec2_client, region, name):


def get_or_create_vpc_resources(ec2_client, region, name=VPC_NAME):
    """Return ``(subnet_ids, security_group_id)`` for the named test VPC,
    creating the VPC and its resources first if they do not already exist.
    """
    # Serialize on a file lock so concurrent test runs do not race to
    # create the same VPC.
    with lock.lock(LOCK_PATH):
        if _vpc_exists(ec2_client, name):
            print('using existing vpc: {}'.format(name))
            subnet_ids = _get_subnet_ids_by_name(ec2_client, name)
            security_group_id = _get_security_id_by_name(ec2_client, name)
            return subnet_ids, security_group_id

        print('creating new vpc: {}'.format(name))
        return _create_vpc_with_name(ec2_client, region, name)


def setup_security_group_for_encryption(ec2_client, security_group_id):
    """Open the inbound protocols required for inter-host traffic encryption
    (IPsec: protocol 50/ESP plus UDP 500 for IKE) on the given security group.

    Only adds the rules when the group still has exactly one ingress rule,
    i.e. the encryption rules have not been added yet.
    """
    description = ec2_client.describe_security_groups(GroupIds=[security_group_id])
    current_perms = description['SecurityGroups'][0]['IpPermissions']
    if len(current_perms) != 1:
        return

    esp_rule = {'IpProtocol': '50',
                'UserIdGroupPairs': [{'GroupId': security_group_id}]}
    ike_rule = {'IpProtocol': 'udp',
                'FromPort': 500,
                'ToPort': 500,
                'UserIdGroupPairs': [{'GroupId': security_group_id}]}
    ec2_client.authorize_security_group_ingress(GroupId=security_group_id,
                                                IpPermissions=[esp_rule, ike_rule])