Add EI support to TFS container. #10
Changes from 5 commits
78b9479
4257a56
0a11fd3
66bfc3e
cdf51e5
9b7ff35
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
FROM ubuntu:16.04 | ||
LABEL com.amazonaws.sagemaker.capabilities.accept-bind-to-port=true | ||
|
||
ARG TFS_SHORT_VERSION | ||
|
||
COPY AmazonEI_TensorFlow_Serving_v${TFS_SHORT_VERSION}_v1 /usr/bin/tensorflow_model_server | ||
|
||
# downloaded 1.12 version is not executable | ||
RUN chmod +x /usr/bin/tensorflow_model_server | ||
|
||
# nginx + njs | ||
RUN \ | ||
apt-get update && \ | ||
apt-get -y install --no-install-recommends curl && \ | ||
curl -s http://nginx.org/keys/nginx_signing.key | apt-key add - && \ | ||
echo 'deb http://nginx.org/packages/ubuntu/ xenial nginx' >> /etc/apt/sources.list && \ | ||
apt-get update && \ | ||
apt-get -y install --no-install-recommends nginx nginx-module-njs python3 python3-pip && \ | ||
apt-get clean | ||
|
||
COPY ./ / | ||
RUN rm AmazonEI_TensorFlow_Serving_v${TFS_SHORT_VERSION}_v1 | ||
|
||
ENV SAGEMAKER_TFS_VERSION "${TFS_SHORT_VERSION}" | ||
ENV PATH "$PATH:/sagemaker" |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,7 +4,7 @@ | |
|
||
function error() { | ||
>&2 echo $1 | ||
>&2 echo "usage: $0 [--version <major-version>] [--arch (cpu*|gpu)] [--region <aws-region>]" | ||
>&2 echo "usage: $0 [--version <major-version>] [--arch (cpu*|gpu|ei)] [--region <aws-region>]" | ||
exit 1 | ||
} | ||
|
||
|
@@ -28,6 +28,17 @@ function get_aws_account() { | |
aws sts get-caller-identity --query 'Account' --output text | ||
} | ||
|
||
function get_tfs_executable() { | ||
Comment: Should this function utilize the parsed args below for the TF version? It also looks like the naming scheme isn't consistent for the zip file for 1.11 and 1.12 already, which might require updating this file often.
Comment: It only utilizes $version here, because the naming for v1.11 and v1.12 are: Even the v1.12's zip name and unzipped directory name are different. So I believe we will need to update this file in the future.
Comment: It seems like in S3 it is always going to follow the pattern of: s3://amazonei-tensorflow/Tensorflow\ Serving/{version}/Ubuntu. We can do a discovery on the item for example with: Afterwards, I think we can probably also discover the unzipped name as well. For example: There are probably better commands than the ones I used above in my examples. |
||
zip_file=$(aws s3 ls 's3://amazonei-tensorflow/Tensorflow Serving/v'${version}'/Ubuntu/' | awk '{print $4}') | ||
aws s3 cp 's3://amazonei-tensorflow/Tensorflow Serving/v'${version}'/Ubuntu/'${zip_file} . | ||
|
||
mkdir exec_dir | ||
unzip ${zip_file} -d exec_dir | ||
|
||
find . -name "AmazonEI_TensorFlow_Serving_v${version}_v1*" -exec mv {} container/ \; | ||
rm ${zip_file} && rm -rf exec_dir | ||
} | ||
|
||
function parse_std_args() { | ||
# defaults | ||
arch='cpu' | ||
|
@@ -63,7 +74,7 @@ function parse_std_args() { | |
done | ||
|
||
[[ -z "${version// }" ]] && error 'missing version' | ||
[[ "$arch" =~ ^(cpu|gpu)$ ]] || error "invalid arch: $arch" | ||
[[ "$arch" =~ ^(cpu|gpu|ei)$ ]] || error "invalid arch: $arch" | ||
[[ -z "${aws_region// }" ]] && error 'missing aws region' | ||
|
||
full_version=$(get_full_version $version) | ||
|
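A minimal sketch of the discovery idea from the review thread on `get_tfs_executable`, written in Python rather than shell. It assumes that `aws s3 ls` rows follow the usual date, time, size, name layout and that exactly one zip sits under the version prefix. The function names are hypothetical and are not part of this PR.

```python
import zipfile

# Prefix taken from the COPY line in the Dockerfile of this PR.
EXECUTABLE_PREFIX = 'AmazonEI_TensorFlow_Serving'


def zip_name_from_listing(listing):
    """Return the single .zip object name found in `aws s3 ls` output."""
    names = []
    for line in listing.splitlines():
        # Object rows look like "<date> <time> <size> <name>"; the name may
        # contain spaces, so split at most three times. "PRE <prefix>/" rows
        # yield fewer than four fields and are skipped.
        parts = line.strip().split(None, 3)
        if len(parts) == 4 and parts[3].endswith('.zip'):
            names.append(parts[3])
    if len(names) != 1:
        raise ValueError('expected exactly one zip file, found {}'.format(names))
    return names[0]


def executable_member(archive, prefix=EXECUTABLE_PREFIX):
    """Return the path inside an open ZipFile whose basename starts with `prefix`."""
    matches = [name for name in archive.namelist()
               if not name.endswith('/')
               and name.rsplit('/', 1)[-1].startswith(prefix)]
    if len(matches) != 1:
        raise ValueError('expected exactly one executable, found {}'.format(matches))
    return matches[0]
```

Failing loudly when zero or several candidates are found is a deliberate choice here: it surfaces a naming change in the bucket at build time instead of copying the wrong file into the image.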
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
Comment: should this be under
Comment: I went with the option that involves fewer changes, but I'm also okay to put everything under test/integration/... 😺
Comment: if you do move it, do it in a different PR haha |
||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
|
||
import logging | ||
|
||
import boto3 | ||
import pytest | ||
from sagemaker import Session | ||
from sagemaker.tensorflow import TensorFlow | ||
|
||
logger = logging.getLogger(__name__) | ||
logging.getLogger('boto').setLevel(logging.INFO) | ||
logging.getLogger('botocore').setLevel(logging.INFO) | ||
logging.getLogger('factory.py').setLevel(logging.INFO) | ||
logging.getLogger('auth.py').setLevel(logging.INFO) | ||
logging.getLogger('connectionpool.py').setLevel(logging.INFO) | ||
|
||
|
||
def pytest_addoption(parser): | ||
parser.addoption('--aws-id') | ||
parser.addoption('--docker-base-name', default='functional-tensorflow-serving') | ||
Comment: is functional meant to correspond to test/functional?
Comment: It's meant to be sagemaker-tensorflow-serving ... I just accidentally changed it to 'functional' for some reason I don't remember. |
||
parser.addoption('--instance-type') | ||
parser.addoption('--accelerator-type', default=None) | ||
parser.addoption('--region', default='us-west-2') | ||
parser.addoption('--framework-version', default=TensorFlow.LATEST_VERSION) | ||
parser.addoption('--processor', default='cpu', choices=['gpu', 'cpu']) | ||
parser.addoption('--tag') | ||
|
||
|
||
@pytest.fixture(scope='session') | ||
def aws_id(request): | ||
return request.config.getoption('--aws-id') | ||
|
||
|
||
@pytest.fixture(scope='session') | ||
def docker_base_name(request): | ||
return request.config.getoption('--docker-base-name') | ||
|
||
|
||
@pytest.fixture(scope='session') | ||
def instance_type(request): | ||
return request.config.getoption('--instance-type') | ||
|
||
|
||
@pytest.fixture(scope='session') | ||
def accelerator_type(request): | ||
return request.config.getoption('--accelerator-type') | ||
|
||
|
||
@pytest.fixture(scope='session') | ||
def region(request): | ||
return request.config.getoption('--region') | ||
|
||
|
||
@pytest.fixture(scope='session') | ||
def framework_version(request): | ||
return request.config.getoption('--framework-version') | ||
|
||
|
||
@pytest.fixture(scope='session') | ||
def processor(request): | ||
return request.config.getoption('--processor') | ||
|
||
|
||
@pytest.fixture(scope='session') | ||
def tag(request, framework_version, processor): | ||
provided_tag = request.config.getoption('--tag') | ||
default_tag = '{}-{}-py2'.format(framework_version, processor) | ||
Comment: do we ever specify python within the tag for this specific container?
Comment: nope |
||
return provided_tag if provided_tag is not None else default_tag | ||
|
||
|
||
@pytest.fixture(scope='session') | ||
def docker_registry(aws_id, region): | ||
return '{}.dkr.ecr.{}.amazonaws.com'.format(aws_id, region) | ||
|
||
|
||
@pytest.fixture(scope='module') | ||
def docker_image(docker_base_name, tag): | ||
return '{}:{}'.format(docker_base_name, tag) | ||
|
||
|
||
@pytest.fixture(scope='module') | ||
def docker_image_uri(docker_registry, docker_image): | ||
uri = '{}/{}'.format(docker_registry, docker_image) | ||
return uri | ||
|
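The test module later in this PR hard-codes the IAM role name `SageMakerRole` in `_execution_role`, and a reviewer asks whether it should be configurable. One possible shape is sketched below. The `--role` option, the `role_name` fixture, and `execution_role_arn` are hypothetical names and are not part of this PR. The sketch leaves out the pytest import so it stays self-contained; in `conftest.py`, `role_name` would carry `@pytest.fixture(scope='session')` and the `addoption` call would sit next to the existing ones.

```python
# Default mirrors the name currently hard-coded in _execution_role.
DEFAULT_ROLE_NAME = 'SageMakerRole'


def pytest_addoption(parser):
    # Hypothetical extra option alongside --aws-id, --region, and the rest.
    parser.addoption('--role', default=DEFAULT_ROLE_NAME)


def role_name(request):
    # Would be decorated with @pytest.fixture(scope='session') in conftest.py.
    return request.config.getoption('--role')


def execution_role_arn(session, role_name):
    # Same IAM lookup as _execution_role, with the role name passed in
    # instead of being fixed in the test module.
    return session.resource('iam').Role(role_name).arn
```

Keeping `SageMakerRole` as the default means existing invocations would behave as before, while accounts that use a different role name could pass `--role` explicitly.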
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
import io | ||
import json | ||
import logging | ||
import time | ||
|
||
import boto3 | ||
import numpy as np | ||
|
||
import pytest | ||
|
||
EI_SUPPORTED_REGIONS = ['us-east-1', 'us-east-2', 'us-west-2', 'eu-west-1', 'ap-northeast-1', 'ap-northeast-2'] | ||
|
||
logger = logging.getLogger(__name__) | ||
logging.getLogger('boto3').setLevel(logging.INFO) | ||
logging.getLogger('botocore').setLevel(logging.INFO) | ||
logging.getLogger('factory.py').setLevel(logging.INFO) | ||
logging.getLogger('auth.py').setLevel(logging.INFO) | ||
logging.getLogger('connectionpool.py').setLevel(logging.INFO) | ||
logging.getLogger('session.py').setLevel(logging.DEBUG) | ||
logging.getLogger('functional').setLevel(logging.DEBUG) | ||
|
||
|
||
@pytest.fixture(autouse=True) | ||
def skip_if_no_accelerator(accelerator_type): | ||
if accelerator_type is None: | ||
pytest.skip('Skipping because accelerator type was not provided') | ||
|
||
|
||
@pytest.fixture(autouse=True) | ||
def skip_if_non_supported_ei_region(region): | ||
if region not in EI_SUPPORTED_REGIONS: | ||
pytest.skip('EI is not supported in {}'.format(region)) | ||
|
||
|
||
@pytest.fixture | ||
def pretrained_model_data(region): | ||
return 's3://sagemaker-sample-data-{}/tensorflow/model/resnet/resnet_50_v2_fp32_NCHW.tar.gz'.format(region) | ||
|
||
|
||
def _timestamp(): | ||
return time.strftime("%Y-%m-%d-%H-%M-%S") | ||
|
||
|
||
def _execution_role(session): | ||
return session.resource('iam').Role('SageMakerRole').arn | ||
Comment: should we consider adding an argparser for the role? |
||
|
||
|
||
def _production_variants(model_name, instance_type, accelerator_type): | ||
production_variants = [{ | ||
'VariantName': 'AllTraffic', | ||
'ModelName': model_name, | ||
'InitialInstanceCount': 1, | ||
'InstanceType': instance_type, | ||
'AcceleratorType': accelerator_type | ||
}] | ||
return production_variants | ||
|
||
|
||
@pytest.mark.skip_if_non_supported_ei_region | ||
@pytest.mark.skip_if_no_accelerator | ||
def test_deploy_elastic_inference_with_pretrained_model(pretrained_model_data, | ||
docker_image_uri, | ||
instance_type, | ||
accelerator_type): | ||
endpoint_name = 'test-tfs-ei-deploy-model-{}'.format(_timestamp()) | ||
endpoint_config_name = 'test-tfs-endpoint-config-{}'.format(_timestamp()) | ||
model_name = 'test-tfs-ei-model-{}'.format(_timestamp()) | ||
|
||
session = boto3.Session() | ||
client = session.client('sagemaker') | ||
runtime_client = session.client('runtime.sagemaker') | ||
client.create_model(ModelName=model_name, | ||
ExecutionRoleArn=_execution_role(session), | ||
PrimaryContainer={ | ||
'Image': docker_image_uri, | ||
'ModelDataUrl': pretrained_model_data | ||
}) | ||
|
||
logger.info('deploying model to endpoint: {}'.format(endpoint_name)) | ||
|
||
client.create_endpoint_config(EndpointConfigName=endpoint_config_name, | ||
ProductionVariants=_production_variants(model_name, instance_type, accelerator_type)) | ||
|
||
client.create_endpoint(EndpointName=endpoint_name, | ||
EndpointConfigName=endpoint_config_name) | ||
|
||
try: | ||
client.get_waiter('endpoint_in_service').wait(EndpointName=endpoint_name) | ||
Comment: You don't have to do this, but I can foresee us reusing this logic; maybe it would be better to refactor it into another file? |
||
finally: | ||
status = client.describe_endpoint(EndpointName=endpoint_name)['EndpointStatus'] | ||
if status != 'InService': | ||
raise Exception('Failed to create endpoint.') | ||
|
||
input_data = {'instances': np.random.rand(1, 1, 3, 3).tolist()} | ||
|
||
response = runtime_client.invoke_endpoint(EndpointName=endpoint_name, | ||
ContentType='application/json', | ||
Body=json.dumps(input_data)) | ||
result = json.loads(response['Body'].read().decode()) | ||
assert result['predictions'] is not None | ||
|
||
client.delete_endpoint(EndpointName=endpoint_name) | ||
Comment: delete_endpoint should clear endpoint_config too
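Two review comments on this file concern reuse of the wait logic and cleanup of the endpoint config. The sketch below shows one way both could be factored out, assuming `client` is the boto3 SageMaker client the test already creates. `wait_until_in_service` and `deployed_endpoint` are hypothetical helper names; the wait body is the same logic as in the test above, and the delete calls are the standard `delete_endpoint`, `delete_endpoint_config`, and `delete_model` client operations.

```python
import contextlib
import logging

logger = logging.getLogger(__name__)


def wait_until_in_service(client, endpoint_name):
    """Block until the endpoint is InService; raise if it ends in another state."""
    try:
        client.get_waiter('endpoint_in_service').wait(EndpointName=endpoint_name)
    finally:
        status = client.describe_endpoint(EndpointName=endpoint_name)['EndpointStatus']
        if status != 'InService':
            raise Exception('Failed to create endpoint.')


@contextlib.contextmanager
def deployed_endpoint(client, endpoint_name, endpoint_config_name, model_name):
    """Wait for an already-requested endpoint, then always delete test resources."""
    cleanup = (
        (client.delete_endpoint, {'EndpointName': endpoint_name}),
        (client.delete_endpoint_config, {'EndpointConfigName': endpoint_config_name}),
        (client.delete_model, {'ModelName': model_name}),
    )
    try:
        wait_until_in_service(client, endpoint_name)
        yield endpoint_name
    finally:
        # Attempt every delete, even if the test body or an earlier delete failed.
        for delete, kwargs in cleanup:
            try:
                delete(**kwargs)
            except Exception:
                logger.exception('cleanup call failed: %s', kwargs)
```

With a helper like this, the test body could wrap `invoke_endpoint` and its assertion in `with deployed_endpoint(client, endpoint_name, endpoint_config_name, model_name):`, so a failed assertion would no longer leave the endpoint, its config, or the model behind.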
Comment: can you provide an example of the command needed to run the test?
Comment: "To test against Elastic Inference, you will..."
Comment: I'll update these as well as the above comments.
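As a rough illustration of the command a reviewer asks for, the sketch below assembles a pytest invocation from the options that `pytest_addoption` registers in this PR (`--aws-id`, `--instance-type`, `--accelerator-type`, `--region`, `--tag`). The test path, the account ID, and the instance and accelerator types shown are placeholder assumptions, not values taken from the PR, and `build_pytest_command` is a hypothetical helper.

```python
import shlex


def build_pytest_command(test_path, aws_id, instance_type, accelerator_type,
                         region='us-west-2', tag=None):
    """Return the argv list for running the EI test with the conftest options."""
    cmd = ['pytest', test_path,
           '--aws-id', aws_id,
           '--instance-type', instance_type,
           '--accelerator-type', accelerator_type,
           '--region', region]
    if tag is not None:
        # Without --tag, conftest falls back to its default tag.
        cmd += ['--tag', tag]
    return cmd


def as_shell(cmd):
    """Render an argv list as a single shell-safe command line."""
    return ' '.join(shlex.quote(arg) for arg in cmd)


if __name__ == '__main__':
    # All values below are placeholders for illustration only.
    print(as_shell(build_pytest_command(
        'path/to/test_ei.py', '123456789012', 'ml.c5.xlarge', 'ml.eia1.medium')))
```

Because the test module skips itself when `--accelerator-type` is missing or the region is not in `EI_SUPPORTED_REGIONS`, both of those options matter for actually exercising the EI path.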