-
Notifications
You must be signed in to change notification settings - Fork 101
Add EI support to TFS container. #10
Changes from 1 commit
78b9479
4257a56
0a11fd3
66bfc3e
cdf51e5
9b7ff35
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,18 +29,14 @@ function get_aws_account() { | |
} | ||
|
||
function get_tfs_executable() { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Should this function utilize the parsed args below for the TF version? It also looks like the naming scheme isn't consistent for the zip file for 1.11 and 1.12 already, which might require updating this file often. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It only utilizes $version here, because the naming for v1.11 and v1.12 are: Even the v1.12's zip name and unzipped directory name are different. So I believe we will need to update this file in the future. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It seems like in S3 it is always going to follow the pattern of: s3://amazonei-tensorflow/Tensorflow\ Serving/{version}/Ubuntu. We can do a discovery on the item for example with: Afterwards, I think we can probably also discover the unzipped name as well. For example: There are probably better commands than the ones I used above in my examples. |
||
# default to v1.12 in accordance with defaults below | ||
s3_object='tfs_ei_v1_12_ubuntu' | ||
unzipped='v1_12_Ubuntu' | ||
zip_file=$(aws s3 ls 's3://amazonei-tensorflow/Tensorflow Serving/v'${version}'/Ubuntu/' | awk '{print $4}') | ||
aws s3 cp 's3://amazonei-tensorflow/Tensorflow Serving/v'${version}'/Ubuntu/'${zip_file} . | ||
|
||
if [ ${version} = '1.11' ]; then | ||
s3_object='Ubuntu' | ||
unzipped='Ubuntu' | ||
fi | ||
mkdir exec_dir | ||
unzip ${zip_file} -d exec_dir | ||
|
||
aws s3 cp 's3://amazonei-tensorflow/Tensorflow Serving/v'${version}'/Ubuntu/'${s3_object}'.zip' . | ||
unzip ${s3_object} && mv ${unzipped}/AmazonEI_Tensorflow_Serving_v${version}_v1 container/ | ||
rm ${s3_object}.zip && rm -rf ${unzipped} | ||
find . -name AmazonEI_TensorFlow_Serving_v${version}_v1* -exec mv {} container/ \; | ||
rm ${zip_file} && rm -rf exec_dir | ||
} | ||
|
||
function parse_std_args() { | ||
|
This file was deleted.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,7 +28,7 @@ | |
|
||
def pytest_addoption(parser): | ||
parser.addoption('--aws-id') | ||
parser.addoption('--docker-base-name', default='sagemaker-tensorflow-serving') | ||
parser.addoption('--docker-base-name', default='functional-tensorflow-serving') | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is 'functional' meant to correspond to test/functional? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It's meant to be sagemaker-tensorflow-serving ... I just accidentally changed it to 'functional' for some reason I don't remember. |
||
parser.addoption('--instance-type') | ||
parser.addoption('--accelerator-type', default=None) | ||
parser.addoption('--region', default='us-west-2') | ||
|
@@ -94,7 +94,3 @@ def docker_image_uri(docker_registry, docker_image): | |
uri = '{}/{}'.format(docker_registry, docker_image) | ||
return uri | ||
|
||
|
||
@pytest.fixture(scope='session') | ||
def sagemaker_session(region): | ||
return Session(boto_session=boto3.Session(region_name=region)) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"). You | ||
# may not use this file except in compliance with the License. A copy of | ||
# the License is located at | ||
# | ||
# http://aws.amazon.com/apache2.0/ | ||
# | ||
# or in the "license" file accompanying this file. This file is | ||
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
import io | ||
import json | ||
import logging | ||
import time | ||
|
||
import boto3 | ||
import numpy as np | ||
|
||
import pytest | ||
|
||
# Regions where Elastic Inference accelerators are available.
EI_SUPPORTED_REGIONS = ['us-east-1', 'us-east-2', 'us-west-2', 'eu-west-1', 'ap-northeast-1', 'ap-northeast-2']

logger = logging.getLogger(__name__)

# Reduce noise from the AWS SDK internals; note some of these logger names
# include '.py' — presumably matching how those modules name their loggers
# (TODO confirm against the library's logging setup).
for _quiet_logger in ('boto3', 'botocore', 'factory.py', 'auth.py', 'connectionpool.py'):
    logging.getLogger(_quiet_logger).setLevel(logging.INFO)

# Keep verbose output from the session machinery and these functional tests.
logging.getLogger('session.py').setLevel(logging.DEBUG)
logging.getLogger('functional').setLevel(logging.DEBUG)
|
||
|
||
@pytest.fixture(autouse=True)
def skip_if_no_accelerator(accelerator_type):
    """Automatically skip every test in this module when no EI accelerator type was given."""
    if accelerator_type is not None:
        return
    pytest.skip('Skipping because accelerator type was not provided')
|
||
|
||
@pytest.fixture(autouse=True)
def skip_if_non_supported_ei_region(region):
    """Automatically skip every test in this module when the region has no EI support."""
    if region in EI_SUPPORTED_REGIONS:
        return
    pytest.skip('EI is not supported in {}'.format(region))
|
||
|
||
@pytest.fixture
def pretrained_model_data(region):
    """Return the S3 URI of the pre-trained ResNet-50 v2 sample model for *region*."""
    model_uri = 's3://sagemaker-sample-data-{}/tensorflow/model/resnet/resnet_50_v2_fp32_NCHW.tar.gz'.format(region)
    return model_uri
|
||
|
||
def _timestamp(): | ||
return time.strftime("%Y-%m-%d-%H-%M-%S") | ||
|
||
|
||
def _execution_role(session):
    """Look up the ARN of the pre-existing 'SageMakerRole' IAM role via *session*."""
    role = session.resource('iam').Role('SageMakerRole')
    return role.arn
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should we consider adding an argparser for the role? |
||
|
||
|
||
def _production_variants(model_name, instance_type, accelerator_type): | ||
production_variants = [{ | ||
'VariantName': 'AllTraffic', | ||
'ModelName': model_name, | ||
'InitialInstanceCount': 1, | ||
'InstanceType': instance_type, | ||
'AcceleratorType': accelerator_type | ||
}] | ||
return production_variants | ||
|
||
|
||
# NOTE(review): these marks are no-ops — the skip logic already runs via the
# autouse fixtures of the same names; kept for readability/back-compat.
@pytest.mark.skip_if_non_supported_ei_region
@pytest.mark.skip_if_no_accelerator
def test_deploy_elastic_inference_with_pretrained_model(pretrained_model_data,
                                                        docker_image_uri,
                                                        instance_type,
                                                        accelerator_type):
    """Deploy a pre-trained model to an EI-backed SageMaker endpoint and invoke it.

    Creates a model, an endpoint config and an endpoint, waits until the
    endpoint is InService, sends a random NCHW-shaped JSON payload and checks
    that predictions come back. Fixes over the previous version: the endpoint,
    endpoint config AND model are always deleted (the config and model used to
    leak, and everything leaked when the invocation or assertion failed).
    """
    endpoint_name = 'test-tfs-ei-deploy-model-{}'.format(_timestamp())
    endpoint_config_name = 'test-tfs-endpoint-config-{}'.format(_timestamp())
    model_name = 'test-tfs-ei-model-{}'.format(_timestamp())

    session = boto3.Session()
    client = session.client('sagemaker')
    runtime_client = session.client('runtime.sagemaker')

    client.create_model(ModelName=model_name,
                        ExecutionRoleArn=_execution_role(session),
                        PrimaryContainer={
                            'Image': docker_image_uri,
                            'ModelDataUrl': pretrained_model_data
                        })

    logger.info('deploying model to endpoint: {}'.format(endpoint_name))

    client.create_endpoint_config(
        EndpointConfigName=endpoint_config_name,
        ProductionVariants=_production_variants(model_name, instance_type, accelerator_type))

    client.create_endpoint(EndpointName=endpoint_name,
                           EndpointConfigName=endpoint_config_name)

    try:
        try:
            client.get_waiter('endpoint_in_service').wait(EndpointName=endpoint_name)
        finally:
            # Surface a clear failure message when the endpoint never came up;
            # when the waiter raised but the endpoint IS in service, the
            # waiter's own error propagates (same semantics as before).
            status = client.describe_endpoint(EndpointName=endpoint_name)['EndpointStatus']
            if status != 'InService':
                raise Exception('Failed to create endpoint.')

        # Shape chosen to match the model's NCHW input; values are arbitrary.
        input_data = {'instances': np.random.rand(1, 1, 3, 3).tolist()}

        response = runtime_client.invoke_endpoint(EndpointName=endpoint_name,
                                                  ContentType='application/json',
                                                  Body=json.dumps(input_data))
        result = json.loads(response['Body'].read().decode())
        assert result['predictions'] is not None
    finally:
        # Always clean up, in dependency order: endpoint, then its config,
        # then the model (config/model deletion was missing before).
        client.delete_endpoint(EndpointName=endpoint_name)
        client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
        client.delete_model(ModelName=model_name)
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. delete_endpoint should clear endpoint_config too |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you provide an example of the command needed to run the test?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"To test against Elastic Inference, you will..."
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll update these as well as the above comments.