Skip to content

Commit c6d44da

Browse files
committed
fix: add conditionals to include container variables
1 parent e6fe39c commit c6d44da

File tree

5 files changed

+67
-75
lines changed

5 files changed

+67
-75
lines changed

src/sagemaker/model.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -374,16 +374,22 @@ def register(
374374
"Image": self.image_uri,
375375
"ModelDataUrl": self.model_data,
376376
}
377-
container_def.update(
378-
{
379-
"Framework": framework,
380-
"FrameworkVersion": framework_version,
381-
"NearestModelName": nearest_model_name,
382-
"ModelInput": {
383-
"DataInputConfig": data_input_configuration,
384-
},
385-
}
386-
)
377+
if (
378+
framework is not None
379+
and framework_version is not None
380+
and nearest_model_name is not None
381+
and data_input_configuration is not None
382+
):
383+
container_def.update(
384+
{
385+
"Framework": framework,
386+
"FrameworkVersion": framework_version,
387+
"NearestModelName": nearest_model_name,
388+
"ModelInput": {
389+
"DataInputConfig": data_input_configuration,
390+
},
391+
}
392+
)
387393
model_pkg_args = sagemaker.get_model_package_args(
388394
content_types,
389395
response_types,

src/sagemaker/pipeline.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -337,31 +337,31 @@ def register(
337337
container_def = self.pipeline_container_def(
338338
inference_instances[0] if inference_instances else None
339339
)
340-
container_def[0].update(
341-
{
342-
"Framework": framework,
343-
"FrameworkVersion": framework_version,
344-
"NearestModelName": nearest_model_name,
345-
"ModelInput": {
346-
"DataInputConfig": data_input_configuration,
347-
},
348-
}
349-
)
350340
else:
351341
container_def = [
352342
{
353343
"Image": image_uri or model.image_uri,
354344
"ModelDataUrl": model.model_data,
355-
"Framework": framework or model.framework,
356-
"FrameworkVersion": framework_version or model.framework_version,
357-
"NearestModelName": nearest_model_name or model.nearest_model_name,
358-
"ModelInput": {
359-
"DataInputConfig": data_input_configuration
360-
or model.data_input_configuration
361-
},
362345
}
363346
for model in self.models
364347
]
348+
if (
349+
framework is not None
350+
and framework_version is not None
351+
and nearest_model_name is not None
352+
and data_input_configuration is not None
353+
):
354+
for container_obj in container_def:
355+
container_obj.update(
356+
{
357+
"Framework": framework,
358+
"FrameworkVersion": framework_version,
359+
"NearestModelName": nearest_model_name,
360+
"ModelInput": {
361+
"DataInputConfig": data_input_configuration,
362+
},
363+
}
364+
)
365365

366366
model_pkg_args = sagemaker.get_model_package_args(
367367
content_types,

src/sagemaker/session.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4299,27 +4299,28 @@ def get_model_package_args(
42994299
dict: A dictionary of method argument names and values.
43004300
"""
43014301
if container_def_list is not None:
4302-
container_def_list[0].update(
4303-
{
4304-
"Framework": container_def_list[0]["Framework"],
4305-
"FrameworkVersion": container_def_list[0]["FrameworkVersion"],
4306-
"NearestModelName": container_def_list[0]["NearestModelName"],
4307-
"ModelInput": {
4308-
"DataInputConfig": container_def_list[0]["ModelInput"]["DataInputConfig"],
4309-
},
4310-
}
4311-
)
4302+
container_fields = container_def_list[0]
4303+
if (
4304+
container_fields.get("Framework") is not None
4305+
and container_fields.get("FrameworkVersion") is not None
4306+
and container_fields.get("NearestModelName") is not None
4307+
and container_fields.get("ModelInput").get("DataInputConfig") is not None
4308+
):
4309+
container_def_list[0].update(
4310+
{
4311+
"Framework": container_fields["Framework"],
4312+
"FrameworkVersion": container_fields["FrameworkVersion"],
4313+
"NearestModelName": container_fields["NearestModelName"],
4314+
"ModelInput": {
4315+
"DataInputConfig": container_fields["ModelInput"]["DataInputConfig"],
4316+
},
4317+
}
4318+
)
43124319
containers = container_def_list
43134320
else:
43144321
container = {
43154322
"Image": image_uri,
43164323
"ModelDataUrl": model_data,
4317-
"Framework": None,
4318-
"FrameworkVersion": None,
4319-
"NearestModelName": None,
4320-
"ModelInput": {
4321-
"DataInputConfig": None,
4322-
},
43234324
}
43244325
containers = [container]
43254326

src/sagemaker/workflow/step_collections.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -246,17 +246,23 @@ def __init__(
246246
inference_instances[0] if inference_instances else None
247247
)
248248
]
249-
for container_obj in self.container_def_list:
250-
container_obj.update(
251-
{
252-
"Framework": framework,
253-
"FrameworkVersion": framework_version,
254-
"NearestModelName": nearest_model_name,
255-
"ModelInput": {
256-
"DataInputConfig": data_input_configuration,
257-
},
258-
}
259-
)
249+
if (
250+
framework is not None
251+
and framework_version is not None
252+
and nearest_model_name is not None
253+
and data_input_configuration is not None
254+
):
255+
for container_obj in self.container_def_list:
256+
container_obj.update(
257+
{
258+
"Framework": framework,
259+
"FrameworkVersion": framework_version,
260+
"NearestModelName": nearest_model_name,
261+
"ModelInput": {
262+
"DataInputConfig": data_input_configuration,
263+
},
264+
}
265+
)
260266
register_model_step = _RegisterModelStep(
261267
name=name,
262268
estimator=estimator,

tests/unit/sagemaker/workflow/test_step_collections.py

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -389,12 +389,6 @@ def test_register_model(estimator, model_metrics, drift_check_baselines):
389389
{
390390
"Image": "012345678901.dkr.ecr.us-west-2.amazonaws.com/my-custom-image-uri",
391391
"ModelDataUrl": f"s3://{BUCKET}/model.tar.gz",
392-
"Framework": None,
393-
"FrameworkVersion": None,
394-
"NearestModelName": None,
395-
"ModelInput": {
396-
"DataInputConfig": None,
397-
},
398392
}
399393
],
400394
"SupportedContentTypes": ["content_type"],
@@ -466,12 +460,6 @@ def test_register_model_tf(estimator_tf, model_metrics, drift_check_baselines):
466460
{
467461
"Image": "763104351884.dkr.ecr.us-west-2.amazonaws.com/tensorflow-inference:1.15.2-cpu",
468462
"ModelDataUrl": f"s3://{BUCKET}/model.tar.gz",
469-
"Framework": None,
470-
"FrameworkVersion": None,
471-
"NearestModelName": None,
472-
"ModelInput": {
473-
"DataInputConfig": None,
474-
},
475463
}
476464
],
477465
"SupportedContentTypes": ["content_type"],
@@ -703,15 +691,6 @@ def test_register_model_with_model_repack_with_estimator(
703691
assert isinstance(
704692
arguments["InferenceSpecification"]["Containers"][0]["ModelDataUrl"], Properties
705693
)
706-
assert arguments["InferenceSpecification"]["Containers"][0]["Framework"] is None
707-
assert arguments["InferenceSpecification"]["Containers"][0]["FrameworkVersion"] is None
708-
assert arguments["InferenceSpecification"]["Containers"][0]["NearestModelName"] is None
709-
assert (
710-
arguments["InferenceSpecification"]["Containers"][0]["ModelInput"][
711-
"DataInputConfig"
712-
]
713-
is None
714-
)
715694
del arguments["InferenceSpecification"]["Containers"]
716695
assert ordered(arguments) == ordered(
717696
{

0 commit comments

Comments
 (0)