Skip to content

fix: enable kms support for repack_model #1061

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Sep 25, 2019
1 change: 1 addition & 0 deletions src/sagemaker/chainer/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ def create_model(
entry_point=None,
source_dir=None,
dependencies=None,
**kwargs
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there something here that should actually get used?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am only fixing TensorFlow and SKLearn for now. I think all other frameworks needs more fixes than this one(similar to the fixes I added to sklearn). That work should be properly scheduled and we should add integ tests as well.

):
"""Create a SageMaker ``ChainerModel`` object that can be deployed to an
``Endpoint``.
Expand Down
1 change: 1 addition & 0 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,7 @@ def deploy(
)
model = self._compiled_models[family]
else:
kwargs["output_kms_key"] = self.output_kms_key
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I realize that I'm probably the source of this problem (with one of my recent code changes), but should we distinguish more between output_kms_key and kms_key? also, is there a case where someone might want to specify output_kms_key in deploy?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When we repack the model the new model tar ball is uploaded to the same bucket. we need to use the same output_kms_key in the estimator, otherwise uploading the repacked tar ball will fail.

I can't really think of any use case where you would want a different kms key since we don't allow user to specify a new bucket when deploy is called.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe let's rename the Model's output_kms_key to something like model_data_kms_key? (model_kms_key?)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i will do model_kms_key

model = self.create_model(**kwargs)
model.name = model_name
return model.deploy(
Expand Down
3 changes: 3 additions & 0 deletions src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def __init__(
vpc_config=None,
sagemaker_session=None,
enable_network_isolation=False,
output_kms_key=None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

docstring? (same for the rest)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will add

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

**kwargs also need docstring entries - if possible, include a link to where the **kwargs are eventually sent to

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

k

):
"""Initialize an SageMaker ``Model``.

Expand Down Expand Up @@ -127,6 +128,7 @@ def __init__(
self.endpoint_name = None
self._is_compiled_model = False
self._enable_network_isolation = enable_network_isolation
self.output_kms_key = output_kms_key

def prepare_container_def(
self, instance_type, accelerator_type=None
Expand Down Expand Up @@ -799,6 +801,7 @@ def _upload_code(self, key_prefix, repack=False):
model_uri=self.model_data,
repacked_model_uri=repacked_model_data,
sagemaker_session=self.sagemaker_session,
kms_key=self.output_kms_key,
)

self.repacked_model_data = repacked_model_data
Expand Down
1 change: 1 addition & 0 deletions src/sagemaker/mxnet/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ def create_model(
source_dir=None,
dependencies=None,
image_name=None,
**kwargs
):
"""Create a SageMaker ``MXNetModel`` object that can be deployed to an
``Endpoint``.
Expand Down
1 change: 1 addition & 0 deletions src/sagemaker/pytorch/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def create_model(
entry_point=None,
source_dir=None,
dependencies=None,
**kwargs
):
"""Create a SageMaker ``PyTorchModel`` object that can be deployed to an
``Endpoint``.
Expand Down
1 change: 1 addition & 0 deletions src/sagemaker/rl/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ def create_model(
entry_point=None,
source_dir=None,
dependencies=None,
**kwargs
):
"""Create a SageMaker ``RLEstimatorModel`` object that can be deployed
to an Endpoint.
Expand Down
7 changes: 7 additions & 0 deletions src/sagemaker/tensorflow/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,7 @@ def create_model(
entry_point=None,
source_dir=None,
dependencies=None,
**kwargs
):
"""Create a ``Model`` object that can be used for creating SageMaker model entities,
deploying to a SageMaker endpoint, or starting SageMaker Batch Transform jobs.
Expand Down Expand Up @@ -552,6 +553,7 @@ def create_model(
entry_point=entry_point,
source_dir=source_dir,
dependencies=dependencies,
**kwargs
)

return self._create_default_model(
Expand All @@ -561,6 +563,7 @@ def create_model(
entry_point=entry_point,
source_dir=source_dir,
dependencies=dependencies,
**kwargs
)

def _create_tfs_model(
Expand All @@ -570,6 +573,7 @@ def _create_tfs_model(
entry_point=None,
source_dir=None,
dependencies=None,
**kwargs
):
"""Placeholder docstring"""
return Model(
Expand All @@ -585,6 +589,7 @@ def _create_tfs_model(
source_dir=source_dir,
dependencies=dependencies,
enable_network_isolation=self.enable_network_isolation(),
**kwargs
)

def _create_default_model(
Expand All @@ -595,6 +600,7 @@ def _create_default_model(
entry_point=None,
source_dir=None,
dependencies=None,
**kwargs
):
"""Placeholder docstring"""
return TensorFlowModel(
Expand All @@ -615,6 +621,7 @@ def _create_default_model(
vpc_config=self.get_vpc_config(vpc_config_override),
dependencies=dependencies or self.dependencies,
enable_network_isolation=self.enable_network_isolation(),
**kwargs
)

def hyperparameters(self):
Expand Down
1 change: 1 addition & 0 deletions src/sagemaker/tensorflow/serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ def prepare_container_def(self, instance_type, accelerator_type=None):
self.model_data,
model_data,
self.sagemaker_session,
kms_key=self.output_kms_key,
)
else:
model_data = self.model_data
Expand Down
12 changes: 9 additions & 3 deletions src/sagemaker/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,7 @@ def repack_model(
model_uri,
repacked_model_uri,
sagemaker_session,
kms_key=None,
):
"""Unpack model tarball and creates a new model tarball with the provided
code script.
Expand Down Expand Up @@ -400,6 +401,7 @@ def repack_model(
model will be saved
sagemaker_session (sagemaker.session.Session): a sagemaker session to
interact with S3.
kms_key (str): KMS key ARN for encrypting the repacked model file

Returns:
str: path to the new packed model
Expand All @@ -417,10 +419,10 @@ def repack_model(
with tarfile.open(tmp_model_path, mode="w:gz") as t:
t.add(model_dir, arcname=os.path.sep)

_save_model(repacked_model_uri, tmp_model_path, sagemaker_session)
_save_model(repacked_model_uri, tmp_model_path, sagemaker_session, kms_key=kms_key)


def _save_model(repacked_model_uri, tmp_model_path, sagemaker_session):
def _save_model(repacked_model_uri, tmp_model_path, sagemaker_session, kms_key):
"""
Args:
repacked_model_uri:
Expand All @@ -432,8 +434,12 @@ def _save_model(repacked_model_uri, tmp_model_path, sagemaker_session):
bucket, key = url.netloc, url.path.lstrip("/")
new_key = key.replace(os.path.basename(key), os.path.basename(repacked_model_uri))

if kms_key:
extra_args = {"ServerSideEncryption": "aws:kms", "SSEKMSKeyId": kms_key}
else:
extra_args = None
sagemaker_session.boto_session.resource("s3").Object(bucket, new_key).upload_file(
tmp_model_path
tmp_model_path, ExtraArgs=extra_args
)
else:
shutil.move(tmp_model_path, repacked_model_uri.replace("file://", ""))
Expand Down
18 changes: 13 additions & 5 deletions tests/integ/test_tf_script_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import tests.integ
from tests.integ import timeout
from tests.integ import kms_utils
from tests.integ.retry import retries
from tests.integ.s3_utils import assert_s3_files_exist

Expand Down Expand Up @@ -67,16 +68,14 @@ def test_mnist(sagemaker_session, instance_type):

def test_server_side_encryption(sagemaker_session):
boto_session = sagemaker_session.boto_session
with tests.integ.kms_utils.bucket_with_encryption(boto_session, ROLE) as (
bucket_with_kms,
kms_key,
):
with kms_utils.bucket_with_encryption(boto_session, ROLE) as (bucket_with_kms, kms_key):
output_path = os.path.join(
bucket_with_kms, "test-server-side-encryption", time.strftime("%y%m%d-%H%M")
)

estimator = TensorFlow(
entry_point=SCRIPT,
entry_point="training.py",
source_dir=TFS_RESOURCE_PATH,
role=ROLE,
train_instance_count=1,
train_instance_type="ml.c5.xlarge",
Expand All @@ -99,6 +98,15 @@ def test_server_side_encryption(sagemaker_session):
inputs=inputs, job_name=unique_name_from_base("test-server-side-encryption")
)

endpoint_name = unique_name_from_base("test-server-side-encryption")
with timeout.timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
estimator.deploy(
initial_instance_count=1,
instance_type="ml.c5.xlarge",
endpoint_name=endpoint_name,
entry_point=os.path.join(TFS_RESOURCE_PATH, "inference.py"),
)


@pytest.mark.canary_quick
def test_mnist_distributed(sagemaker_session, instance_type):
Expand Down
3 changes: 3 additions & 0 deletions tests/unit/test_mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,13 +416,15 @@ def test_model(sagemaker_session):

@patch("sagemaker.utils.repack_model")
def test_model_mms_version(repack_model, sagemaker_session):
output_kms_key = "kms-key"
model = MXNetModel(
MODEL_DATA,
role=ROLE,
entry_point=SCRIPT_PATH,
framework_version=MXNetModel._LOWEST_MMS_VERSION,
sagemaker_session=sagemaker_session,
name="test-mxnet-model",
output_kms_key=output_kms_key,
)
predictor = model.deploy(1, GPU)

Expand All @@ -433,6 +435,7 @@ def test_model_mms_version(repack_model, sagemaker_session):
model_uri=MODEL_DATA,
repacked_model_uri="s3://mybucket/test-mxnet-model/model.tar.gz",
sagemaker_session=sagemaker_session,
kms_key=output_kms_key,
)

assert model.model_data == MODEL_DATA
Expand Down
4 changes: 4 additions & 0 deletions tests/unit/test_tfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ def test_tfs_model_with_entry_point(
framework_version=tf_version,
image="my-image",
sagemaker_session=sagemaker_session,
output_kms_key="kms-key",
)

model.prepare_container_def(INSTANCE_TYPE)
Expand All @@ -180,6 +181,7 @@ def test_tfs_model_with_entry_point(
"s3://some/data.tar.gz",
"s3://my_bucket/key-prefix/model.tar.gz",
sagemaker_session,
kms_key="kms-key",
)


Expand Down Expand Up @@ -207,6 +209,7 @@ def test_tfs_model_with_source(repack_model, model_code_key_prefix, sagemaker_se
"s3://some/data.tar.gz",
"s3://my_bucket/key-prefix/model.tar.gz",
sagemaker_session,
kms_key=None,
)


Expand Down Expand Up @@ -236,6 +239,7 @@ def test_tfs_model_with_dependencies(
"s3://some/data.tar.gz",
"s3://my_bucket/key-prefix/model.tar.gz",
sagemaker_session,
kms_key=None,
)


Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,7 @@ def __init__(self, bucket, key):
self.bucket = bucket
self.key = key

def upload_file(self, target):
def upload_file(self, target, **kwargs):
if self.bucket in BUCKET_WITHOUT_WRITING_PERMISSION:
raise exceptions.S3UploadFailedError()
shutil.copy2(target, dst)
Expand Down