Skip to content

fix: enable kms support for repack_model #1061

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Sep 25, 2019
4 changes: 3 additions & 1 deletion src/sagemaker/amazon/kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def __init__(
self.center_factor = center_factor
self.eval_metrics = eval_metrics

def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT, **kwargs):
"""Return a :class:`~sagemaker.amazon.kmeans.KMeansModel` referencing
the latest s3 model data produced by this Estimator.

Expand All @@ -158,12 +158,14 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
Default: use subnets and security groups from this Estimator.
* 'Subnets' (list[str]): List of subnet ids.
* 'SecurityGroupIds' (list[str]): List of security group ids.
**kwargs: Additional kwargs passed to the KMeansModel constructor.
"""
return KMeansModel(
self.model_data,
self.role,
self.sagemaker_session,
vpc_config=self.get_vpc_config(vpc_config_override),
**kwargs
)

def _prepare_for_training(self, records, mini_batch_size=5000, job_name=None):
Expand Down
4 changes: 3 additions & 1 deletion src/sagemaker/amazon/lda.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def __init__(
self.max_iterations = max_iterations
self.tol = tol

def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT, **kwargs):
"""Return a :class:`~sagemaker.amazon.LDAModel` referencing the latest
s3 model data produced by this Estimator.

Expand All @@ -132,12 +132,14 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
Default: use subnets and security groups from this Estimator.
* 'Subnets' (list[str]): List of subnet ids.
* 'SecurityGroupIds' (list[str]): List of security group ids.
**kwargs: Additional kwargs passed to the LDAModel constructor.
"""
return LDAModel(
self.model_data,
self.role,
sagemaker_session=self.sagemaker_session,
vpc_config=self.get_vpc_config(vpc_config_override),
**kwargs
)

def _prepare_for_training( # pylint: disable=signature-differs
Expand Down
4 changes: 3 additions & 1 deletion src/sagemaker/amazon/linear_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ def __init__(
"value greater than 2."
)

def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT, **kwargs):
"""Return a :class:`~sagemaker.amazon.LinearLearnerModel` referencing
the latest s3 model data produced by this Estimator.

Expand All @@ -382,12 +382,14 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
the model. Default: use subnets and security groups from this Estimator.
* 'Subnets' (list[str]): List of subnet ids.
* 'SecurityGroupIds' (list[str]): List of security group ids.
**kwargs: Additional kwargs passed to the LinearLearnerModel constructor.
"""
return LinearLearnerModel(
self.model_data,
self.role,
self.sagemaker_session,
vpc_config=self.get_vpc_config(vpc_config_override),
**kwargs
)

def _prepare_for_training(self, records, mini_batch_size=None, job_name=None):
Expand Down
2 changes: 2 additions & 0 deletions src/sagemaker/chainer/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ def create_model(
entry_point=None,
source_dir=None,
dependencies=None,
**kwargs
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there something here that should actually get used?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am only fixing TensorFlow and SKLearn for now. I think all other frameworks needs more fixes than this one(similar to the fixes I added to sklearn). That work should be properly scheduled and we should add integ tests as well.

):
"""Create a SageMaker ``ChainerModel`` object that can be deployed to an
``Endpoint``.
Expand All @@ -186,6 +187,7 @@ def create_model(
dependencies (list[str]): A list of paths to directories (absolute or relative) with
any additional libraries that will be exported to the container.
If not specified, the dependencies from training are used.
**kwargs: Additional kwargs passed to the ChainerModel constructor.

Returns:
sagemaker.chainer.model.ChainerModel: A SageMaker ``ChainerModel``
Expand Down
6 changes: 5 additions & 1 deletion src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,7 @@ def deploy(
)
model = self._compiled_models[family]
else:
kwargs["model_kms_key"] = self.output_kms_key
model = self.create_model(**kwargs)
model.name = model_name
return model.deploy(
Expand Down Expand Up @@ -734,7 +735,9 @@ def transformer(
model_name = self._current_job_name
else:
model_name = self.latest_training_job.name
model = self.create_model(vpc_config_override=vpc_config_override)
model = self.create_model(
vpc_config_override=vpc_config_override, model_kms_key=self.output_kms_key
)

# not all create_model() implementations have the same kwargs
model.name = model_name
Expand Down Expand Up @@ -1716,6 +1719,7 @@ def transformer(
model_server_workers=model_server_workers,
entry_point=entry_point,
vpc_config_override=vpc_config_override,
model_kms_key=self.output_kms_key,
)
model._create_sagemaker_model(instance_type, tags=tags)

Expand Down
5 changes: 5 additions & 0 deletions src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def __init__(
vpc_config=None,
sagemaker_session=None,
enable_network_isolation=False,
model_kms_key=None,
):
"""Initialize an SageMaker ``Model``.

Expand Down Expand Up @@ -114,6 +115,8 @@ def __init__(
network isolation in the endpoint, isolating the model
container. No inbound or outbound network calls can be made to
or from the model container.
model_kms_key (str): KMS key ARN used to encrypt the repacked
model archive file if the model is repacked
"""
self.model_data = model_data
self.image = image
Expand All @@ -127,6 +130,7 @@ def __init__(
self.endpoint_name = None
self._is_compiled_model = False
self._enable_network_isolation = enable_network_isolation
self.model_kms_key = model_kms_key

def prepare_container_def(
self, instance_type, accelerator_type=None
Expand Down Expand Up @@ -799,6 +803,7 @@ def _upload_code(self, key_prefix, repack=False):
model_uri=self.model_data,
repacked_model_uri=repacked_model_data,
sagemaker_session=self.sagemaker_session,
kms_key=self.model_kms_key,
)

self.repacked_model_data = repacked_model_data
Expand Down
2 changes: 2 additions & 0 deletions src/sagemaker/mxnet/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ def create_model(
source_dir=None,
dependencies=None,
image_name=None,
**kwargs
):
"""Create a SageMaker ``MXNetModel`` object that can be deployed to an
``Endpoint``.
Expand Down Expand Up @@ -171,6 +172,7 @@ def create_model(
Examples:
123.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0
custom-image:latest.
**kwargs: Additional kwargs passed to the MXNetModel constructor.

Returns:
sagemaker.mxnet.model.MXNetModel: A SageMaker ``MXNetModel`` object.
Expand Down
2 changes: 2 additions & 0 deletions src/sagemaker/pytorch/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def create_model(
entry_point=None,
source_dir=None,
dependencies=None,
**kwargs
):
"""Create a SageMaker ``PyTorchModel`` object that can be deployed to an
``Endpoint``.
Expand All @@ -139,6 +140,7 @@ def create_model(
dependencies (list[str]): A list of paths to directories (absolute or relative) with
any additional libraries that will be exported to the container.
If not specified, the dependencies from training are used.
**kwargs: Additional kwargs passed to the PyTorchModel constructor.

Returns:
sagemaker.pytorch.model.PyTorchModel: A SageMaker ``PyTorchModel``
Expand Down
2 changes: 2 additions & 0 deletions src/sagemaker/rl/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ def create_model(
entry_point=None,
source_dir=None,
dependencies=None,
**kwargs
):
"""Create a SageMaker ``RLEstimatorModel`` object that can be deployed
to an Endpoint.
Expand All @@ -189,6 +190,7 @@ def create_model(
folders will be copied to SageMaker in the same folder where the
entry_point is copied. If the ```source_dir``` points to S3,
code will be uploaded and the S3 location will be used instead.
**kwargs: Additional kwargs passed to the FrameworkModel constructor.

Returns:
sagemaker.model.FrameworkModel: Depending on input parameters returns
Expand Down
9 changes: 9 additions & 0 deletions src/sagemaker/tensorflow/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,7 @@ def create_model(
entry_point=None,
source_dir=None,
dependencies=None,
**kwargs
):
"""Create a ``Model`` object that can be used for creating SageMaker model entities,
deploying to a SageMaker endpoint, or starting SageMaker Batch Transform jobs.
Expand Down Expand Up @@ -537,6 +538,8 @@ def create_model(
If not specified and ``endpoint_type`` is 'tensorflow-serving', ``dependencies`` is
set to ``None``.
If ``endpoint_type`` is also ``None``, then the dependencies from training are used.
**kwargs: Additional kwargs passed to ``sagemaker.tensorflow.serving.Model`` constructor
and ``sagemaker.tensorflow.model.TensorFlowModel`` constructor.

Returns:
sagemaker.tensorflow.model.TensorFlowModel or sagemaker.tensorflow.serving.Model: A
Expand All @@ -552,6 +555,7 @@ def create_model(
entry_point=entry_point,
source_dir=source_dir,
dependencies=dependencies,
**kwargs
)

return self._create_default_model(
Expand All @@ -561,6 +565,7 @@ def create_model(
entry_point=entry_point,
source_dir=source_dir,
dependencies=dependencies,
**kwargs
)

def _create_tfs_model(
Expand All @@ -570,6 +575,7 @@ def _create_tfs_model(
entry_point=None,
source_dir=None,
dependencies=None,
**kwargs
):
"""Placeholder docstring"""
return Model(
Expand All @@ -585,6 +591,7 @@ def _create_tfs_model(
source_dir=source_dir,
dependencies=dependencies,
enable_network_isolation=self.enable_network_isolation(),
**kwargs
)

def _create_default_model(
Expand All @@ -595,6 +602,7 @@ def _create_default_model(
entry_point=None,
source_dir=None,
dependencies=None,
**kwargs
):
"""Placeholder docstring"""
return TensorFlowModel(
Expand All @@ -615,6 +623,7 @@ def _create_default_model(
vpc_config=self.get_vpc_config(vpc_config_override),
dependencies=dependencies or self.dependencies,
enable_network_isolation=self.enable_network_isolation(),
**kwargs
)

def hyperparameters(self):
Expand Down
1 change: 1 addition & 0 deletions src/sagemaker/tensorflow/serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ def prepare_container_def(self, instance_type, accelerator_type=None):
self.model_data,
model_data,
self.sagemaker_session,
kms_key=self.model_kms_key,
)
else:
model_data = self.model_data
Expand Down
12 changes: 9 additions & 3 deletions src/sagemaker/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,7 @@ def repack_model(
model_uri,
repacked_model_uri,
sagemaker_session,
kms_key=None,
):
"""Unpack model tarball and creates a new model tarball with the provided
code script.
Expand Down Expand Up @@ -400,6 +401,7 @@ def repack_model(
model will be saved
sagemaker_session (sagemaker.session.Session): a sagemaker session to
interact with S3.
kms_key (str): KMS key ARN for encrypting the repacked model file

Returns:
str: path to the new packed model
Expand All @@ -417,10 +419,10 @@ def repack_model(
with tarfile.open(tmp_model_path, mode="w:gz") as t:
t.add(model_dir, arcname=os.path.sep)

_save_model(repacked_model_uri, tmp_model_path, sagemaker_session)
_save_model(repacked_model_uri, tmp_model_path, sagemaker_session, kms_key=kms_key)


def _save_model(repacked_model_uri, tmp_model_path, sagemaker_session):
def _save_model(repacked_model_uri, tmp_model_path, sagemaker_session, kms_key):
"""
Args:
repacked_model_uri:
Expand All @@ -432,8 +434,12 @@ def _save_model(repacked_model_uri, tmp_model_path, sagemaker_session):
bucket, key = url.netloc, url.path.lstrip("/")
new_key = key.replace(os.path.basename(key), os.path.basename(repacked_model_uri))

if kms_key:
extra_args = {"ServerSideEncryption": "aws:kms", "SSEKMSKeyId": kms_key}
else:
extra_args = None
sagemaker_session.boto_session.resource("s3").Object(bucket, new_key).upload_file(
tmp_model_path
tmp_model_path, ExtraArgs=extra_args
)
else:
shutil.move(tmp_model_path, repacked_model_uri.replace("file://", ""))
Expand Down
18 changes: 13 additions & 5 deletions tests/integ/test_tf_script_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import tests.integ
from tests.integ import timeout
from tests.integ import kms_utils
from tests.integ.retry import retries
from tests.integ.s3_utils import assert_s3_files_exist

Expand Down Expand Up @@ -67,16 +68,14 @@ def test_mnist(sagemaker_session, instance_type):

def test_server_side_encryption(sagemaker_session):
boto_session = sagemaker_session.boto_session
with tests.integ.kms_utils.bucket_with_encryption(boto_session, ROLE) as (
bucket_with_kms,
kms_key,
):
with kms_utils.bucket_with_encryption(boto_session, ROLE) as (bucket_with_kms, kms_key):
output_path = os.path.join(
bucket_with_kms, "test-server-side-encryption", time.strftime("%y%m%d-%H%M")
)

estimator = TensorFlow(
entry_point=SCRIPT,
entry_point="training.py",
source_dir=TFS_RESOURCE_PATH,
role=ROLE,
train_instance_count=1,
train_instance_type="ml.c5.xlarge",
Expand All @@ -99,6 +98,15 @@ def test_server_side_encryption(sagemaker_session):
inputs=inputs, job_name=unique_name_from_base("test-server-side-encryption")
)

endpoint_name = unique_name_from_base("test-server-side-encryption")
with timeout.timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
estimator.deploy(
initial_instance_count=1,
instance_type="ml.c5.xlarge",
endpoint_name=endpoint_name,
entry_point=os.path.join(TFS_RESOURCE_PATH, "inference.py"),
)


@pytest.mark.canary_quick
def test_mnist_distributed(sagemaker_session, instance_type):
Expand Down
2 changes: 2 additions & 0 deletions tests/unit/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,13 +119,15 @@ def create_model(
model_server_workers=None,
entry_point=None,
vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT,
**kwargs
):
return DummyFrameworkModel(
self.sagemaker_session,
vpc_config=self.get_vpc_config(vpc_config_override),
entry_point=entry_point,
enable_network_isolation=self.enable_network_isolation(),
role=role,
**kwargs
)

@classmethod
Expand Down
3 changes: 3 additions & 0 deletions tests/unit/test_mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,13 +416,15 @@ def test_model(sagemaker_session):

@patch("sagemaker.utils.repack_model")
def test_model_mms_version(repack_model, sagemaker_session):
model_kms_key = "kms-key"
model = MXNetModel(
MODEL_DATA,
role=ROLE,
entry_point=SCRIPT_PATH,
framework_version=MXNetModel._LOWEST_MMS_VERSION,
sagemaker_session=sagemaker_session,
name="test-mxnet-model",
model_kms_key=model_kms_key,
)
predictor = model.deploy(1, GPU)

Expand All @@ -433,6 +435,7 @@ def test_model_mms_version(repack_model, sagemaker_session):
model_uri=MODEL_DATA,
repacked_model_uri="s3://mybucket/test-mxnet-model/model.tar.gz",
sagemaker_session=sagemaker_session,
kms_key=model_kms_key,
)

assert model.model_data == MODEL_DATA
Expand Down
Loading