-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Login to ECR if needed for Local Mode #121
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,13 +10,16 @@ | |
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
# ANY KIND, either express or implied. See the License for the specific | ||
# language governing permissions and limitations under the License. | ||
import base64 | ||
import errno | ||
import json | ||
import logging | ||
import os | ||
import platform | ||
import random | ||
import shlex | ||
import shutil | ||
import string | ||
import subprocess | ||
import sys | ||
import tempfile | ||
|
@@ -59,7 +62,10 @@ def __init__(self, instance_type, instance_count, image, sagemaker_session=None) | |
self.instance_type = instance_type | ||
self.instance_count = instance_count | ||
self.image = image | ||
self.hosts = ['{}-{}'.format(CONTAINER_PREFIX, i) for i in range(1, self.instance_count + 1)] | ||
# Since we are using a single docker network, Generate a random suffix to attach to the container names. | ||
# This way multiple jobs can run in parallel. | ||
suffix = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(5)) | ||
self.hosts = ['{}-{}-{}'.format(CONTAINER_PREFIX, i, suffix) for i in range(1, self.instance_count + 1)] | ||
self.container_root = None | ||
self.container = None | ||
# set the local config. This is optional and will use reasonable defaults | ||
|
@@ -110,6 +116,8 @@ def train(self, input_data_config, hyperparameters): | |
|
||
compose_data = self._generate_compose_file('train', additional_volumes=volumes) | ||
compose_command = self._compose() | ||
|
||
_ecr_login_if_needed(self.sagemaker_session.boto_session, self.image) | ||
_execute_and_stream_output(compose_command) | ||
|
||
s3_model_artifacts = self.retrieve_model_artifacts(compose_data) | ||
|
@@ -152,6 +160,8 @@ def serve(self, primary_container): | |
|
||
env_vars = ['{}={}'.format(k, v) for k, v in primary_container['Environment'].items()] | ||
|
||
_ecr_login_if_needed(self.sagemaker_session.boto_session, self.image) | ||
|
||
self._generate_compose_file('serve', additional_env_vars=env_vars) | ||
compose_command = self._compose() | ||
self.container = _HostingContainer(compose_command) | ||
|
@@ -296,7 +306,11 @@ def _generate_compose_file(self, command, additional_volumes=None, additional_en | |
content = { | ||
# Some legacy hosts only support the 2.1 format. | ||
'version': '2.1', | ||
'services': services | ||
'services': services, | ||
'networks': { | ||
'sagemaker-local': {'name': 'sagemaker-local'} | ||
} | ||
|
||
} | ||
|
||
docker_compose_path = os.path.join(self.container_root, DOCKER_COMPOSE_FILENAME) | ||
|
@@ -335,7 +349,12 @@ def _create_docker_host(self, host, environment, optml_subdirs, command, volumes | |
'tty': True, | ||
'volumes': [v.map for v in optml_volumes], | ||
'environment': environment, | ||
'command': command | ||
'command': command, | ||
'networks': { | ||
'sagemaker-local': { | ||
'aliases': [host] | ||
} | ||
} | ||
} | ||
|
||
serving_port = 8080 if self.local_config is None else self.local_config.get('serving_port', 8080) | ||
|
@@ -390,7 +409,8 @@ def _build_optml_volumes(self, host, subdirs): | |
return volumes | ||
|
||
def _cleanup(self): | ||
_check_output('docker network prune -f') | ||
# we don't need to cleanup anything at the moment | ||
pass | ||
|
||
|
||
class _HostingContainer(object): | ||
|
@@ -525,3 +545,24 @@ def _aws_credentials(session): | |
def _write_json_file(filename, content): | ||
with open(filename, 'w') as f: | ||
json.dump(content, f) | ||
|
||
|
||
def _ecr_login_if_needed(boto_session, image): | ||
# Only ECR images need login | ||
if not ('dkr.ecr' in image and 'amazonaws.com' in image): | ||
return | ||
|
||
# do we have the image? | ||
if _check_output('docker images -q %s' % image).strip(): | ||
return | ||
|
||
ecr = boto_session.client('ecr') | ||
auth = ecr.get_authorization_token(registryIds=[image.split('.')[0]]) | ||
authorization_data = auth['authorizationData'][0] | ||
|
||
raw_token = base64.b64decode(authorization_data['authorizationToken']) | ||
token = raw_token.decode('utf-8').strip('AWS:') | ||
ecr_url = auth['authorizationData'][0]['proxyEndpoint'] | ||
|
||
cmd = "docker login -u AWS -p %s %s" % (token, ecr_url) | ||
subprocess.check_output(cmd, shell=True) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sometimes we are using _check_output and sometimes we are using subprocess.checkoutput. Can we just use one of them? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There is a good reason for this. I will probably refactor this anyways. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add a TODO (rignacio) here to remove this function later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I will not remove this function, I will add to it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Them add a TODO (rignacio) here to write this function later :)