35
35
EDGE_PACKAGING_KMS_KEY_ID_PATH ,
36
36
EDGE_PACKAGING_ROLE_ARN_PATH ,
37
37
MODEL_CONTAINERS_PATH ,
38
+ EDGE_PACKAGING_RESOURCE_KEY_PATH ,
38
39
MODEL_VPC_CONFIG_PATH ,
39
40
MODEL_ENABLE_NETWORK_ISOLATION_PATH ,
40
41
MODEL_EXECUTION_ROLE_ARN_PATH ,
50
51
from sagemaker .predictor import PredictorBase
51
52
from sagemaker .serverless import ServerlessInferenceConfig
52
53
from sagemaker .transformer import Transformer
53
- from sagemaker .jumpstart .utils import add_jumpstart_tags , get_jumpstart_base_name_if_jumpstart_model
54
+ from sagemaker .jumpstart .utils import (
55
+ add_jumpstart_tags ,
56
+ get_jumpstart_base_name_if_jumpstart_model ,
57
+ )
54
58
from sagemaker .utils import (
55
59
unique_name_from_base ,
56
60
update_container_with_inference_params ,
63
67
from sagemaker .workflow import is_pipeline_variable
64
68
from sagemaker .workflow .entities import PipelineVariable
65
69
from sagemaker .workflow .pipeline_context import runnable_by_pipeline , PipelineSession
66
- from sagemaker .inference_recommender .inference_recommender_mixin import InferenceRecommenderMixin
70
+ from sagemaker .inference_recommender .inference_recommender_mixin import (
71
+ InferenceRecommenderMixin ,
72
+ )
67
73
68
74
LOGGER = logging .getLogger ("sagemaker" )
69
75
70
76
NEO_ALLOWED_FRAMEWORKS = set (
71
77
["mxnet" , "tensorflow" , "keras" , "pytorch" , "onnx" , "xgboost" , "tflite" ]
72
78
)
73
79
74
- NEO_IOC_TARGET_DEVICES = ["ml_c4" , "ml_c5" , "ml_m4" , "ml_m5" , "ml_p2" , "ml_p3" , "ml_g4dn" ]
80
+ NEO_IOC_TARGET_DEVICES = [
81
+ "ml_c4" ,
82
+ "ml_c5" ,
83
+ "ml_m4" ,
84
+ "ml_m5" ,
85
+ "ml_p2" ,
86
+ "ml_p3" ,
87
+ "ml_g4dn" ,
88
+ ]
75
89
76
90
NEO_MULTIVERSION_UNSUPPORTED = [
77
91
"imx8mplus" ,
@@ -300,7 +314,9 @@ def __init__(
300
314
self ._base_name = None
301
315
self .sagemaker_session = sagemaker_session
302
316
self .role = resolve_value_from_config (
303
- role , MODEL_EXECUTION_ROLE_ARN_PATH , sagemaker_session = self .sagemaker_session
317
+ role ,
318
+ MODEL_EXECUTION_ROLE_ARN_PATH ,
319
+ sagemaker_session = self .sagemaker_session ,
304
320
)
305
321
self .vpc_config = resolve_value_from_config (
306
322
vpc_config , MODEL_VPC_CONFIG_PATH , sagemaker_session = self .sagemaker_session
@@ -585,7 +601,9 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None:
585
601
local_code = utils .get_config_value ("local.local_code" , self .sagemaker_session .config )
586
602
587
603
bucket , key_prefix = s3 .determine_bucket_and_prefix (
588
- bucket = self .bucket , key_prefix = key_prefix , sagemaker_session = self .sagemaker_session
604
+ bucket = self .bucket ,
605
+ key_prefix = key_prefix ,
606
+ sagemaker_session = self .sagemaker_session ,
589
607
)
590
608
591
609
if (self .sagemaker_session .local_mode and local_code ) or self .entry_point is None :
@@ -633,7 +651,8 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None:
633
651
else :
634
652
repacked_model_data = "s3://" + "/" .join ([bucket , key_prefix , "model.tar.gz" ])
635
653
self .uploaded_code = fw_utils .UploadedCode (
636
- s3_prefix = repacked_model_data , script_name = os .path .basename (self .entry_point )
654
+ s3_prefix = repacked_model_data ,
655
+ script_name = os .path .basename (self .entry_point ),
637
656
)
638
657
639
658
LOGGER .info (
@@ -693,7 +712,11 @@ def enable_network_isolation(self):
693
712
return False if not self ._enable_network_isolation else self ._enable_network_isolation
694
713
695
714
def _create_sagemaker_model (
696
- self , instance_type = None , accelerator_type = None , tags = None , serverless_inference_config = None
715
+ self ,
716
+ instance_type = None ,
717
+ accelerator_type = None ,
718
+ tags = None ,
719
+ serverless_inference_config = None ,
697
720
):
698
721
"""Create a SageMaker Model Entity
699
722
@@ -734,10 +757,14 @@ def _create_sagemaker_model(
734
757
self ._init_sagemaker_session_if_does_not_exist (instance_type )
735
758
# Depending on the instance type, a local session (or) a session is initialized.
736
759
self .role = resolve_value_from_config (
737
- self .role , MODEL_EXECUTION_ROLE_ARN_PATH , sagemaker_session = self .sagemaker_session
760
+ self .role ,
761
+ MODEL_EXECUTION_ROLE_ARN_PATH ,
762
+ sagemaker_session = self .sagemaker_session ,
738
763
)
739
764
self .vpc_config = resolve_value_from_config (
740
- self .vpc_config , MODEL_VPC_CONFIG_PATH , sagemaker_session = self .sagemaker_session
765
+ self .vpc_config ,
766
+ MODEL_VPC_CONFIG_PATH ,
767
+ sagemaker_session = self .sagemaker_session ,
741
768
)
742
769
self ._enable_network_isolation = resolve_value_from_config (
743
770
self ._enable_network_isolation ,
@@ -955,11 +982,16 @@ def package_for_edge(
955
982
job_name = f"packaging{ self ._compilation_job_name [11 :]} "
956
983
self ._init_sagemaker_session_if_does_not_exist (None )
957
984
s3_kms_key = resolve_value_from_config (
958
- s3_kms_key , EDGE_PACKAGING_KMS_KEY_ID_PATH , sagemaker_session = self .sagemaker_session
985
+ s3_kms_key ,
986
+ EDGE_PACKAGING_KMS_KEY_ID_PATH ,
987
+ sagemaker_session = self .sagemaker_session ,
959
988
)
960
989
role = resolve_value_from_config (
961
990
role , EDGE_PACKAGING_ROLE_ARN_PATH , sagemaker_session = self .sagemaker_session
962
991
)
992
+ resource_key = resolve_value_from_config (
993
+ resource_key , EDGE_PACKAGING_RESOURCE_KEY_PATH , sagemaker_session = self .sagemaker_session
994
+ )
963
995
if role is not None :
964
996
role = self .sagemaker_session .expand_role (role )
965
997
config = self ._edge_packaging_job_config (
@@ -1065,7 +1097,9 @@ def compile(
1065
1097
1066
1098
self ._init_sagemaker_session_if_does_not_exist (target_instance_family )
1067
1099
role = resolve_value_from_config (
1068
- role , COMPILATION_JOB_ROLE_ARN_PATH , sagemaker_session = self .sagemaker_session
1100
+ role ,
1101
+ COMPILATION_JOB_ROLE_ARN_PATH ,
1102
+ sagemaker_session = self .sagemaker_session ,
1069
1103
)
1070
1104
if not role :
1071
1105
# Originally IAM role was a required parameter.
@@ -1232,10 +1266,14 @@ def deploy(
1232
1266
self ._init_sagemaker_session_if_does_not_exist (instance_type )
1233
1267
# Depending on the instance type, a local session (or) a session is initialized.
1234
1268
self .role = resolve_value_from_config (
1235
- self .role , MODEL_EXECUTION_ROLE_ARN_PATH , sagemaker_session = self .sagemaker_session
1269
+ self .role ,
1270
+ MODEL_EXECUTION_ROLE_ARN_PATH ,
1271
+ sagemaker_session = self .sagemaker_session ,
1236
1272
)
1237
1273
self .vpc_config = resolve_value_from_config (
1238
- self .vpc_config , MODEL_VPC_CONFIG_PATH , sagemaker_session = self .sagemaker_session
1274
+ self .vpc_config ,
1275
+ MODEL_VPC_CONFIG_PATH ,
1276
+ sagemaker_session = self .sagemaker_session ,
1239
1277
)
1240
1278
self ._enable_network_isolation = resolve_value_from_config (
1241
1279
self ._enable_network_isolation ,
@@ -1244,7 +1282,9 @@ def deploy(
1244
1282
)
1245
1283
1246
1284
tags = add_jumpstart_tags (
1247
- tags = tags , inference_model_uri = self .model_data , inference_script_uri = self .source_dir
1285
+ tags = tags ,
1286
+ inference_model_uri = self .model_data ,
1287
+ inference_script_uri = self .source_dir ,
1248
1288
)
1249
1289
1250
1290
if self .role is None :
@@ -1292,7 +1332,9 @@ def deploy(
1292
1332
compiled_model_suffix = None if is_serverless else "-" .join (instance_type .split ("." )[:- 1 ])
1293
1333
if self ._is_compiled_model and not is_serverless :
1294
1334
self ._ensure_base_name_if_needed (
1295
- image_uri = self .image_uri , script_uri = self .source_dir , model_uri = self .model_data
1335
+ image_uri = self .image_uri ,
1336
+ script_uri = self .source_dir ,
1337
+ model_uri = self .model_data ,
1296
1338
)
1297
1339
if self ._base_name is not None :
1298
1340
self ._base_name = "-" .join ((self ._base_name , compiled_model_suffix ))
@@ -1673,7 +1715,12 @@ class ModelPackage(Model):
1673
1715
"""A SageMaker ``Model`` that can be deployed to an ``Endpoint``."""
1674
1716
1675
1717
def __init__ (
1676
- self , role = None , model_data = None , algorithm_arn = None , model_package_arn = None , ** kwargs
1718
+ self ,
1719
+ role = None ,
1720
+ model_data = None ,
1721
+ algorithm_arn = None ,
1722
+ model_package_arn = None ,
1723
+ ** kwargs ,
1677
1724
):
1678
1725
"""Initialize a SageMaker ModelPackage.
1679
1726
0 commit comments