Skip to content

Commit 76d46d0

Browse files
authored
fix: enable kms support for repack_model (#1061)
* fix: enable kms support for repack_model Currently repack_model doesn't accept a kms key. This change added a kms_key argument to the fucntion. In addition repack_model will always use the output_kms_key inside the Estimator if it's set.
1 parent d368524 commit 76d46d0

File tree

17 files changed

+69
-13
lines changed

17 files changed

+69
-13
lines changed

src/sagemaker/amazon/kmeans.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def __init__(
148148
self.center_factor = center_factor
149149
self.eval_metrics = eval_metrics
150150

151-
def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
151+
def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT, **kwargs):
152152
"""Return a :class:`~sagemaker.amazon.kmeans.KMeansModel` referencing
153153
the latest s3 model data produced by this Estimator.
154154
@@ -158,12 +158,14 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
158158
Default: use subnets and security groups from this Estimator.
159159
* 'Subnets' (list[str]): List of subnet ids.
160160
* 'SecurityGroupIds' (list[str]): List of security group ids.
161+
**kwargs: Additional kwargs passed to the KMeansModel constructor.
161162
"""
162163
return KMeansModel(
163164
self.model_data,
164165
self.role,
165166
self.sagemaker_session,
166167
vpc_config=self.get_vpc_config(vpc_config_override),
168+
**kwargs
167169
)
168170

169171
def _prepare_for_training(self, records, mini_batch_size=5000, job_name=None):

src/sagemaker/amazon/lda.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def __init__(
122122
self.max_iterations = max_iterations
123123
self.tol = tol
124124

125-
def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
125+
def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT, **kwargs):
126126
"""Return a :class:`~sagemaker.amazon.LDAModel` referencing the latest
127127
s3 model data produced by this Estimator.
128128
@@ -132,12 +132,14 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
132132
Default: use subnets and security groups from this Estimator.
133133
* 'Subnets' (list[str]): List of subnet ids.
134134
* 'SecurityGroupIds' (list[str]): List of security group ids.
135+
**kwargs: Additional kwargs passed to the LDAModel constructor.
135136
"""
136137
return LDAModel(
137138
self.model_data,
138139
self.role,
139140
sagemaker_session=self.sagemaker_session,
140141
vpc_config=self.get_vpc_config(vpc_config_override),
142+
**kwargs
141143
)
142144

143145
def _prepare_for_training( # pylint: disable=signature-differs

src/sagemaker/amazon/linear_learner.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,7 @@ def __init__(
373373
"value greater than 2."
374374
)
375375

376-
def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
376+
def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT, **kwargs):
377377
"""Return a :class:`~sagemaker.amazon.LinearLearnerModel` referencing
378378
the latest s3 model data produced by this Estimator.
379379
@@ -382,12 +382,14 @@ def create_model(self, vpc_config_override=VPC_CONFIG_DEFAULT):
382382
the model. Default: use subnets and security groups from this Estimator.
383383
* 'Subnets' (list[str]): List of subnet ids.
384384
* 'SecurityGroupIds' (list[str]): List of security group ids.
385+
**kwargs: Additional kwargs passed to the LinearLearnerModel constructor.
385386
"""
386387
return LinearLearnerModel(
387388
self.model_data,
388389
self.role,
389390
self.sagemaker_session,
390391
vpc_config=self.get_vpc_config(vpc_config_override),
392+
**kwargs
391393
)
392394

393395
def _prepare_for_training(self, records, mini_batch_size=None, job_name=None):

src/sagemaker/chainer/estimator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ def create_model(
162162
entry_point=None,
163163
source_dir=None,
164164
dependencies=None,
165+
**kwargs
165166
):
166167
"""Create a SageMaker ``ChainerModel`` object that can be deployed to an
167168
``Endpoint``.
@@ -186,6 +187,7 @@ def create_model(
186187
dependencies (list[str]): A list of paths to directories (absolute or relative) with
187188
any additional libraries that will be exported to the container.
188189
If not specified, the dependencies from training are used.
190+
**kwargs: Additional kwargs passed to the ChainerModel constructor.
189191
190192
Returns:
191193
sagemaker.chainer.model.ChainerModel: A SageMaker ``ChainerModel``

src/sagemaker/estimator.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -547,6 +547,7 @@ def deploy(
547547
)
548548
model = self._compiled_models[family]
549549
else:
550+
kwargs["model_kms_key"] = self.output_kms_key
550551
model = self.create_model(**kwargs)
551552
model.name = model_name
552553
return model.deploy(
@@ -734,7 +735,9 @@ def transformer(
734735
model_name = self._current_job_name
735736
else:
736737
model_name = self.latest_training_job.name
737-
model = self.create_model(vpc_config_override=vpc_config_override)
738+
model = self.create_model(
739+
vpc_config_override=vpc_config_override, model_kms_key=self.output_kms_key
740+
)
738741

739742
# not all create_model() implementations have the same kwargs
740743
model.name = model_name
@@ -1716,6 +1719,7 @@ def transformer(
17161719
model_server_workers=model_server_workers,
17171720
entry_point=entry_point,
17181721
vpc_config_override=vpc_config_override,
1722+
model_kms_key=self.output_kms_key,
17191723
)
17201724
model._create_sagemaker_model(instance_type, tags=tags)
17211725

src/sagemaker/model.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def __init__(
7979
vpc_config=None,
8080
sagemaker_session=None,
8181
enable_network_isolation=False,
82+
model_kms_key=None,
8283
):
8384
"""Initialize an SageMaker ``Model``.
8485
@@ -114,6 +115,8 @@ def __init__(
114115
network isolation in the endpoint, isolating the model
115116
container. No inbound or outbound network calls can be made to
116117
or from the model container.
118+
model_kms_key (str): KMS key ARN used to encrypt the repacked
119+
model archive file if the model is repacked
117120
"""
118121
self.model_data = model_data
119122
self.image = image
@@ -127,6 +130,7 @@ def __init__(
127130
self.endpoint_name = None
128131
self._is_compiled_model = False
129132
self._enable_network_isolation = enable_network_isolation
133+
self.model_kms_key = model_kms_key
130134

131135
def prepare_container_def(
132136
self, instance_type, accelerator_type=None
@@ -799,6 +803,7 @@ def _upload_code(self, key_prefix, repack=False):
799803
model_uri=self.model_data,
800804
repacked_model_uri=repacked_model_data,
801805
sagemaker_session=self.sagemaker_session,
806+
kms_key=self.model_kms_key,
802807
)
803808

804809
self.repacked_model_data = repacked_model_data

src/sagemaker/mxnet/estimator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ def create_model(
141141
source_dir=None,
142142
dependencies=None,
143143
image_name=None,
144+
**kwargs
144145
):
145146
"""Create a SageMaker ``MXNetModel`` object that can be deployed to an
146147
``Endpoint``.
@@ -171,6 +172,7 @@ def create_model(
171172
Examples:
172173
123.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0
173174
custom-image:latest.
175+
**kwargs: Additional kwargs passed to the MXNetModel constructor.
174176
175177
Returns:
176178
sagemaker.mxnet.model.MXNetModel: A SageMaker ``MXNetModel`` object.

src/sagemaker/pytorch/estimator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ def create_model(
115115
entry_point=None,
116116
source_dir=None,
117117
dependencies=None,
118+
**kwargs
118119
):
119120
"""Create a SageMaker ``PyTorchModel`` object that can be deployed to an
120121
``Endpoint``.
@@ -139,6 +140,7 @@ def create_model(
139140
dependencies (list[str]): A list of paths to directories (absolute or relative) with
140141
any additional libraries that will be exported to the container.
141142
If not specified, the dependencies from training are used.
143+
**kwargs: Additional kwargs passed to the PyTorchModel constructor.
142144
143145
Returns:
144146
sagemaker.pytorch.model.PyTorchModel: A SageMaker ``PyTorchModel``

src/sagemaker/rl/estimator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,7 @@ def create_model(
163163
entry_point=None,
164164
source_dir=None,
165165
dependencies=None,
166+
**kwargs
166167
):
167168
"""Create a SageMaker ``RLEstimatorModel`` object that can be deployed
168169
to an Endpoint.
@@ -189,6 +190,7 @@ def create_model(
189190
folders will be copied to SageMaker in the same folder where the
190191
entry_point is copied. If the ```source_dir``` points to S3,
191192
code will be uploaded and the S3 location will be used instead.
193+
**kwargs: Additional kwargs passed to the FrameworkModel constructor.
192194
193195
Returns:
194196
sagemaker.model.FrameworkModel: Depending on input parameters returns

src/sagemaker/tensorflow/estimator.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -504,6 +504,7 @@ def create_model(
504504
entry_point=None,
505505
source_dir=None,
506506
dependencies=None,
507+
**kwargs
507508
):
508509
"""Create a ``Model`` object that can be used for creating SageMaker model entities,
509510
deploying to a SageMaker endpoint, or starting SageMaker Batch Transform jobs.
@@ -537,6 +538,8 @@ def create_model(
537538
If not specified and ``endpoint_type`` is 'tensorflow-serving', ``dependencies`` is
538539
set to ``None``.
539540
If ``endpoint_type`` is also ``None``, then the dependencies from training are used.
541+
**kwargs: Additional kwargs passed to ``sagemaker.tensorflow.serving.Model`` constructor
542+
and ``sagemaker.tensorflow.model.TensorFlowModel`` constructor.
540543
541544
Returns:
542545
sagemaker.tensorflow.model.TensorFlowModel or sagemaker.tensorflow.serving.Model: A
@@ -552,6 +555,7 @@ def create_model(
552555
entry_point=entry_point,
553556
source_dir=source_dir,
554557
dependencies=dependencies,
558+
**kwargs
555559
)
556560

557561
return self._create_default_model(
@@ -561,6 +565,7 @@ def create_model(
561565
entry_point=entry_point,
562566
source_dir=source_dir,
563567
dependencies=dependencies,
568+
**kwargs
564569
)
565570

566571
def _create_tfs_model(
@@ -570,6 +575,7 @@ def _create_tfs_model(
570575
entry_point=None,
571576
source_dir=None,
572577
dependencies=None,
578+
**kwargs
573579
):
574580
"""Placeholder docstring"""
575581
return Model(
@@ -585,6 +591,7 @@ def _create_tfs_model(
585591
source_dir=source_dir,
586592
dependencies=dependencies,
587593
enable_network_isolation=self.enable_network_isolation(),
594+
**kwargs
588595
)
589596

590597
def _create_default_model(
@@ -595,6 +602,7 @@ def _create_default_model(
595602
entry_point=None,
596603
source_dir=None,
597604
dependencies=None,
605+
**kwargs
598606
):
599607
"""Placeholder docstring"""
600608
return TensorFlowModel(
@@ -615,6 +623,7 @@ def _create_default_model(
615623
vpc_config=self.get_vpc_config(vpc_config_override),
616624
dependencies=dependencies or self.dependencies,
617625
enable_network_isolation=self.enable_network_isolation(),
626+
**kwargs
618627
)
619628

620629
def hyperparameters(self):

src/sagemaker/tensorflow/serving.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,7 @@ def prepare_container_def(self, instance_type, accelerator_type=None):
230230
self.model_data,
231231
model_data,
232232
self.sagemaker_session,
233+
kms_key=self.model_kms_key,
233234
)
234235
else:
235236
model_data = self.model_data

src/sagemaker/utils.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,7 @@ def repack_model(
365365
model_uri,
366366
repacked_model_uri,
367367
sagemaker_session,
368+
kms_key=None,
368369
):
369370
"""Unpack model tarball and creates a new model tarball with the provided
370371
code script.
@@ -400,6 +401,7 @@ def repack_model(
400401
model will be saved
401402
sagemaker_session (sagemaker.session.Session): a sagemaker session to
402403
interact with S3.
404+
kms_key (str): KMS key ARN for encrypting the repacked model file
403405
404406
Returns:
405407
str: path to the new packed model
@@ -417,10 +419,10 @@ def repack_model(
417419
with tarfile.open(tmp_model_path, mode="w:gz") as t:
418420
t.add(model_dir, arcname=os.path.sep)
419421

420-
_save_model(repacked_model_uri, tmp_model_path, sagemaker_session)
422+
_save_model(repacked_model_uri, tmp_model_path, sagemaker_session, kms_key=kms_key)
421423

422424

423-
def _save_model(repacked_model_uri, tmp_model_path, sagemaker_session):
425+
def _save_model(repacked_model_uri, tmp_model_path, sagemaker_session, kms_key):
424426
"""
425427
Args:
426428
repacked_model_uri:
@@ -432,8 +434,12 @@ def _save_model(repacked_model_uri, tmp_model_path, sagemaker_session):
432434
bucket, key = url.netloc, url.path.lstrip("/")
433435
new_key = key.replace(os.path.basename(key), os.path.basename(repacked_model_uri))
434436

437+
if kms_key:
438+
extra_args = {"ServerSideEncryption": "aws:kms", "SSEKMSKeyId": kms_key}
439+
else:
440+
extra_args = None
435441
sagemaker_session.boto_session.resource("s3").Object(bucket, new_key).upload_file(
436-
tmp_model_path
442+
tmp_model_path, ExtraArgs=extra_args
437443
)
438444
else:
439445
shutil.move(tmp_model_path, repacked_model_uri.replace("file://", ""))

tests/integ/test_tf_script_mode.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
import tests.integ
2525
from tests.integ import timeout
26+
from tests.integ import kms_utils
2627
from tests.integ.retry import retries
2728
from tests.integ.s3_utils import assert_s3_files_exist
2829

@@ -67,16 +68,14 @@ def test_mnist(sagemaker_session, instance_type):
6768

6869
def test_server_side_encryption(sagemaker_session):
6970
boto_session = sagemaker_session.boto_session
70-
with tests.integ.kms_utils.bucket_with_encryption(boto_session, ROLE) as (
71-
bucket_with_kms,
72-
kms_key,
73-
):
71+
with kms_utils.bucket_with_encryption(boto_session, ROLE) as (bucket_with_kms, kms_key):
7472
output_path = os.path.join(
7573
bucket_with_kms, "test-server-side-encryption", time.strftime("%y%m%d-%H%M")
7674
)
7775

7876
estimator = TensorFlow(
79-
entry_point=SCRIPT,
77+
entry_point="training.py",
78+
source_dir=TFS_RESOURCE_PATH,
8079
role=ROLE,
8180
train_instance_count=1,
8281
train_instance_type="ml.c5.xlarge",
@@ -99,6 +98,15 @@ def test_server_side_encryption(sagemaker_session):
9998
inputs=inputs, job_name=unique_name_from_base("test-server-side-encryption")
10099
)
101100

101+
endpoint_name = unique_name_from_base("test-server-side-encryption")
102+
with timeout.timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
103+
estimator.deploy(
104+
initial_instance_count=1,
105+
instance_type="ml.c5.xlarge",
106+
endpoint_name=endpoint_name,
107+
entry_point=os.path.join(TFS_RESOURCE_PATH, "inference.py"),
108+
)
109+
102110

103111
@pytest.mark.canary_quick
104112
def test_mnist_distributed(sagemaker_session, instance_type):

tests/unit/test_estimator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,13 +119,15 @@ def create_model(
119119
model_server_workers=None,
120120
entry_point=None,
121121
vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT,
122+
**kwargs
122123
):
123124
return DummyFrameworkModel(
124125
self.sagemaker_session,
125126
vpc_config=self.get_vpc_config(vpc_config_override),
126127
entry_point=entry_point,
127128
enable_network_isolation=self.enable_network_isolation(),
128129
role=role,
130+
**kwargs
129131
)
130132

131133
@classmethod

tests/unit/test_mxnet.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,13 +416,15 @@ def test_model(sagemaker_session):
416416

417417
@patch("sagemaker.utils.repack_model")
418418
def test_model_mms_version(repack_model, sagemaker_session):
419+
model_kms_key = "kms-key"
419420
model = MXNetModel(
420421
MODEL_DATA,
421422
role=ROLE,
422423
entry_point=SCRIPT_PATH,
423424
framework_version=MXNetModel._LOWEST_MMS_VERSION,
424425
sagemaker_session=sagemaker_session,
425426
name="test-mxnet-model",
427+
model_kms_key=model_kms_key,
426428
)
427429
predictor = model.deploy(1, GPU)
428430

@@ -433,6 +435,7 @@ def test_model_mms_version(repack_model, sagemaker_session):
433435
model_uri=MODEL_DATA,
434436
repacked_model_uri="s3://mybucket/test-mxnet-model/model.tar.gz",
435437
sagemaker_session=sagemaker_session,
438+
kms_key=model_kms_key,
436439
)
437440

438441
assert model.model_data == MODEL_DATA

0 commit comments

Comments
 (0)