50
50
from sagemaker .predictor import PredictorBase
51
51
from sagemaker .serverless import ServerlessInferenceConfig
52
52
from sagemaker .transformer import Transformer
53
- from sagemaker .jumpstart .utils import add_jumpstart_tags , get_jumpstart_base_name_if_jumpstart_model
53
+ from sagemaker .jumpstart .utils import (
54
+ add_jumpstart_tags ,
55
+ get_jumpstart_base_name_if_jumpstart_model ,
56
+ )
54
57
from sagemaker .utils import (
55
58
unique_name_from_base ,
56
59
update_container_with_inference_params ,
63
66
from sagemaker .workflow import is_pipeline_variable
64
67
from sagemaker .workflow .entities import PipelineVariable
65
68
from sagemaker .workflow .pipeline_context import runnable_by_pipeline , PipelineSession
66
- from sagemaker .inference_recommender .inference_recommender_mixin import InferenceRecommenderMixin
69
+ from sagemaker .inference_recommender .inference_recommender_mixin import (
70
+ InferenceRecommenderMixin ,
71
+ )
67
72
68
73
LOGGER = logging .getLogger ("sagemaker" )
69
74
70
75
NEO_ALLOWED_FRAMEWORKS = set (
71
76
["mxnet" , "tensorflow" , "keras" , "pytorch" , "onnx" , "xgboost" , "tflite" ]
72
77
)
73
78
74
- NEO_IOC_TARGET_DEVICES = ["ml_c4" , "ml_c5" , "ml_m4" , "ml_m5" , "ml_p2" , "ml_p3" , "ml_g4dn" ]
79
+ NEO_IOC_TARGET_DEVICES = [
80
+ "ml_c4" ,
81
+ "ml_c5" ,
82
+ "ml_m4" ,
83
+ "ml_m5" ,
84
+ "ml_p2" ,
85
+ "ml_p3" ,
86
+ "ml_g4dn" ,
87
+ ]
75
88
76
89
NEO_MULTIVERSION_UNSUPPORTED = [
77
90
"imx8mplus" ,
@@ -300,7 +313,9 @@ def __init__(
300
313
self ._base_name = None
301
314
self .sagemaker_session = sagemaker_session
302
315
self .role = resolve_value_from_config (
303
- role , MODEL_EXECUTION_ROLE_ARN_PATH , sagemaker_session = self .sagemaker_session
316
+ role ,
317
+ MODEL_EXECUTION_ROLE_ARN_PATH ,
318
+ sagemaker_session = self .sagemaker_session ,
304
319
)
305
320
self .vpc_config = resolve_value_from_config (
306
321
vpc_config , MODEL_VPC_CONFIG_PATH , sagemaker_session = self .sagemaker_session
@@ -585,7 +600,9 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None:
585
600
local_code = utils .get_config_value ("local.local_code" , self .sagemaker_session .config )
586
601
587
602
bucket , key_prefix = s3 .determine_bucket_and_prefix (
588
- bucket = self .bucket , key_prefix = key_prefix , sagemaker_session = self .sagemaker_session
603
+ bucket = self .bucket ,
604
+ key_prefix = key_prefix ,
605
+ sagemaker_session = self .sagemaker_session ,
589
606
)
590
607
591
608
if (self .sagemaker_session .local_mode and local_code ) or self .entry_point is None :
@@ -633,7 +650,8 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None:
633
650
else :
634
651
repacked_model_data = "s3://" + "/" .join ([bucket , key_prefix , "model.tar.gz" ])
635
652
self .uploaded_code = fw_utils .UploadedCode (
636
- s3_prefix = repacked_model_data , script_name = os .path .basename (self .entry_point )
653
+ s3_prefix = repacked_model_data ,
654
+ script_name = os .path .basename (self .entry_point ),
637
655
)
638
656
639
657
LOGGER .info (
@@ -693,7 +711,11 @@ def enable_network_isolation(self):
693
711
return False if not self ._enable_network_isolation else self ._enable_network_isolation
694
712
695
713
def _create_sagemaker_model (
696
- self , instance_type = None , accelerator_type = None , tags = None , serverless_inference_config = None
714
+ self ,
715
+ instance_type = None ,
716
+ accelerator_type = None ,
717
+ tags = None ,
718
+ serverless_inference_config = None ,
697
719
):
698
720
"""Create a SageMaker Model Entity
699
721
@@ -734,10 +756,14 @@ def _create_sagemaker_model(
734
756
self ._init_sagemaker_session_if_does_not_exist (instance_type )
735
757
# Depending on the instance type, a local session (or) a session is initialized.
736
758
self .role = resolve_value_from_config (
737
- self .role , MODEL_EXECUTION_ROLE_ARN_PATH , sagemaker_session = self .sagemaker_session
759
+ self .role ,
760
+ MODEL_EXECUTION_ROLE_ARN_PATH ,
761
+ sagemaker_session = self .sagemaker_session ,
738
762
)
739
763
self .vpc_config = resolve_value_from_config (
740
- self .vpc_config , MODEL_VPC_CONFIG_PATH , sagemaker_session = self .sagemaker_session
764
+ self .vpc_config ,
765
+ MODEL_VPC_CONFIG_PATH ,
766
+ sagemaker_session = self .sagemaker_session ,
741
767
)
742
768
self ._enable_network_isolation = resolve_value_from_config (
743
769
self ._enable_network_isolation ,
@@ -955,12 +981,16 @@ def package_for_edge(
955
981
job_name = f"packaging{ self ._compilation_job_name [11 :]} "
956
982
self ._init_sagemaker_session_if_does_not_exist (None )
957
983
s3_kms_key = resolve_value_from_config (
958
- s3_kms_key , EDGE_PACKAGING_KMS_KEY_ID_PATH , sagemaker_session = self .sagemaker_session
984
+ s3_kms_key ,
985
+ EDGE_PACKAGING_KMS_KEY_ID_PATH ,
986
+ sagemaker_session = self .sagemaker_session ,
959
987
)
960
988
role = resolve_value_from_config (
961
989
role , EDGE_PACKAGING_ROLE_ARN_PATH , sagemaker_session = self .sagemaker_session
962
990
)
963
- resource_key = resolve_value_from_config (resource_key , EDGE_PACKAGING_RESOURCE_KEY_PATH , sagemaker_session = self )
991
+ resource_key = resolve_value_from_config (
992
+ resource_key , EDGE_PACKAGING_RESOURCE_KEY_PATH , sagemaker_session = self
993
+ )
964
994
if role is not None :
965
995
role = self .sagemaker_session .expand_role (role )
966
996
config = self ._edge_packaging_job_config (
@@ -1066,7 +1096,9 @@ def compile(
1066
1096
1067
1097
self ._init_sagemaker_session_if_does_not_exist (target_instance_family )
1068
1098
role = resolve_value_from_config (
1069
- role , COMPILATION_JOB_ROLE_ARN_PATH , sagemaker_session = self .sagemaker_session
1099
+ role ,
1100
+ COMPILATION_JOB_ROLE_ARN_PATH ,
1101
+ sagemaker_session = self .sagemaker_session ,
1070
1102
)
1071
1103
if not role :
1072
1104
# Originally IAM role was a required parameter.
@@ -1231,10 +1263,14 @@ def deploy(
1231
1263
self ._init_sagemaker_session_if_does_not_exist (instance_type )
1232
1264
# Depending on the instance type, a local session (or) a session is initialized.
1233
1265
self .role = resolve_value_from_config (
1234
- self .role , MODEL_EXECUTION_ROLE_ARN_PATH , sagemaker_session = self .sagemaker_session
1266
+ self .role ,
1267
+ MODEL_EXECUTION_ROLE_ARN_PATH ,
1268
+ sagemaker_session = self .sagemaker_session ,
1235
1269
)
1236
1270
self .vpc_config = resolve_value_from_config (
1237
- self .vpc_config , MODEL_VPC_CONFIG_PATH , sagemaker_session = self .sagemaker_session
1271
+ self .vpc_config ,
1272
+ MODEL_VPC_CONFIG_PATH ,
1273
+ sagemaker_session = self .sagemaker_session ,
1238
1274
)
1239
1275
self ._enable_network_isolation = resolve_value_from_config (
1240
1276
self ._enable_network_isolation ,
@@ -1243,7 +1279,9 @@ def deploy(
1243
1279
)
1244
1280
1245
1281
tags = add_jumpstart_tags (
1246
- tags = tags , inference_model_uri = self .model_data , inference_script_uri = self .source_dir
1282
+ tags = tags ,
1283
+ inference_model_uri = self .model_data ,
1284
+ inference_script_uri = self .source_dir ,
1247
1285
)
1248
1286
1249
1287
if self .role is None :
@@ -1291,7 +1329,9 @@ def deploy(
1291
1329
compiled_model_suffix = None if is_serverless else "-" .join (instance_type .split ("." )[:- 1 ])
1292
1330
if self ._is_compiled_model and not is_serverless :
1293
1331
self ._ensure_base_name_if_needed (
1294
- image_uri = self .image_uri , script_uri = self .source_dir , model_uri = self .model_data
1332
+ image_uri = self .image_uri ,
1333
+ script_uri = self .source_dir ,
1334
+ model_uri = self .model_data ,
1295
1335
)
1296
1336
if self ._base_name is not None :
1297
1337
self ._base_name = "-" .join ((self ._base_name , compiled_model_suffix ))
@@ -1668,7 +1708,12 @@ class ModelPackage(Model):
1668
1708
"""A SageMaker ``Model`` that can be deployed to an ``Endpoint``."""
1669
1709
1670
1710
def __init__ (
1671
- self , role = None , model_data = None , algorithm_arn = None , model_package_arn = None , ** kwargs
1711
+ self ,
1712
+ role = None ,
1713
+ model_data = None ,
1714
+ algorithm_arn = None ,
1715
+ model_package_arn = None ,
1716
+ ** kwargs ,
1672
1717
):
1673
1718
"""Initialize a SageMaker ModelPackage.
1674
1719
0 commit comments