-
Notifications
You must be signed in to change notification settings - Fork 1.2k
feature: adding resourcekey and tags for api in config for SDK defaults #3915
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
c9d25f6
ebd7872
72b00b6
dacac74
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -35,6 +35,7 @@ | |
EDGE_PACKAGING_KMS_KEY_ID_PATH, | ||
EDGE_PACKAGING_ROLE_ARN_PATH, | ||
MODEL_CONTAINERS_PATH, | ||
EDGE_PACKAGING_RESOURCE_KEY_PATH, | ||
MODEL_VPC_CONFIG_PATH, | ||
MODEL_ENABLE_NETWORK_ISOLATION_PATH, | ||
MODEL_EXECUTION_ROLE_ARN_PATH, | ||
|
@@ -50,7 +51,10 @@ | |
from sagemaker.predictor import PredictorBase | ||
from sagemaker.serverless import ServerlessInferenceConfig | ||
from sagemaker.transformer import Transformer | ||
from sagemaker.jumpstart.utils import add_jumpstart_tags, get_jumpstart_base_name_if_jumpstart_model | ||
from sagemaker.jumpstart.utils import ( | ||
add_jumpstart_tags, | ||
get_jumpstart_base_name_if_jumpstart_model, | ||
) | ||
from sagemaker.utils import ( | ||
unique_name_from_base, | ||
update_container_with_inference_params, | ||
|
@@ -63,15 +67,25 @@ | |
from sagemaker.workflow import is_pipeline_variable | ||
from sagemaker.workflow.entities import PipelineVariable | ||
from sagemaker.workflow.pipeline_context import runnable_by_pipeline, PipelineSession | ||
from sagemaker.inference_recommender.inference_recommender_mixin import InferenceRecommenderMixin | ||
from sagemaker.inference_recommender.inference_recommender_mixin import ( | ||
InferenceRecommenderMixin, | ||
) | ||
|
||
LOGGER = logging.getLogger("sagemaker") | ||
|
||
NEO_ALLOWED_FRAMEWORKS = set( | ||
["mxnet", "tensorflow", "keras", "pytorch", "onnx", "xgboost", "tflite"] | ||
) | ||
|
||
NEO_IOC_TARGET_DEVICES = ["ml_c4", "ml_c5", "ml_m4", "ml_m5", "ml_p2", "ml_p3", "ml_g4dn"] | ||
NEO_IOC_TARGET_DEVICES = [ | ||
"ml_c4", | ||
"ml_c5", | ||
"ml_m4", | ||
"ml_m5", | ||
"ml_p2", | ||
"ml_p3", | ||
"ml_g4dn", | ||
] | ||
|
||
NEO_MULTIVERSION_UNSUPPORTED = [ | ||
"imx8mplus", | ||
|
@@ -300,7 +314,9 @@ def __init__( | |
self._base_name = None | ||
self.sagemaker_session = sagemaker_session | ||
self.role = resolve_value_from_config( | ||
role, MODEL_EXECUTION_ROLE_ARN_PATH, sagemaker_session=self.sagemaker_session | ||
role, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. qq: did some configuration change recently for how we do spacing? Odd that non-touched things are changing There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. True - I think this is a result of |
||
MODEL_EXECUTION_ROLE_ARN_PATH, | ||
sagemaker_session=self.sagemaker_session, | ||
) | ||
self.vpc_config = resolve_value_from_config( | ||
vpc_config, MODEL_VPC_CONFIG_PATH, sagemaker_session=self.sagemaker_session | ||
|
@@ -585,7 +601,9 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None: | |
local_code = utils.get_config_value("local.local_code", self.sagemaker_session.config) | ||
|
||
bucket, key_prefix = s3.determine_bucket_and_prefix( | ||
bucket=self.bucket, key_prefix=key_prefix, sagemaker_session=self.sagemaker_session | ||
bucket=self.bucket, | ||
key_prefix=key_prefix, | ||
sagemaker_session=self.sagemaker_session, | ||
) | ||
|
||
if (self.sagemaker_session.local_mode and local_code) or self.entry_point is None: | ||
|
@@ -633,7 +651,8 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None: | |
else: | ||
repacked_model_data = "s3://" + "/".join([bucket, key_prefix, "model.tar.gz"]) | ||
self.uploaded_code = fw_utils.UploadedCode( | ||
s3_prefix=repacked_model_data, script_name=os.path.basename(self.entry_point) | ||
s3_prefix=repacked_model_data, | ||
script_name=os.path.basename(self.entry_point), | ||
) | ||
|
||
LOGGER.info( | ||
|
@@ -693,7 +712,11 @@ def enable_network_isolation(self): | |
return False if not self._enable_network_isolation else self._enable_network_isolation | ||
|
||
def _create_sagemaker_model( | ||
self, instance_type=None, accelerator_type=None, tags=None, serverless_inference_config=None | ||
self, | ||
instance_type=None, | ||
accelerator_type=None, | ||
tags=None, | ||
serverless_inference_config=None, | ||
): | ||
"""Create a SageMaker Model Entity | ||
|
||
|
@@ -734,10 +757,14 @@ def _create_sagemaker_model( | |
self._init_sagemaker_session_if_does_not_exist(instance_type) | ||
# Depending on the instance type, a local session (or) a session is initialized. | ||
self.role = resolve_value_from_config( | ||
self.role, MODEL_EXECUTION_ROLE_ARN_PATH, sagemaker_session=self.sagemaker_session | ||
self.role, | ||
MODEL_EXECUTION_ROLE_ARN_PATH, | ||
sagemaker_session=self.sagemaker_session, | ||
) | ||
self.vpc_config = resolve_value_from_config( | ||
self.vpc_config, MODEL_VPC_CONFIG_PATH, sagemaker_session=self.sagemaker_session | ||
self.vpc_config, | ||
MODEL_VPC_CONFIG_PATH, | ||
sagemaker_session=self.sagemaker_session, | ||
) | ||
self._enable_network_isolation = resolve_value_from_config( | ||
self._enable_network_isolation, | ||
|
@@ -955,11 +982,16 @@ def package_for_edge( | |
job_name = f"packaging{self._compilation_job_name[11:]}" | ||
self._init_sagemaker_session_if_does_not_exist(None) | ||
s3_kms_key = resolve_value_from_config( | ||
s3_kms_key, EDGE_PACKAGING_KMS_KEY_ID_PATH, sagemaker_session=self.sagemaker_session | ||
s3_kms_key, | ||
EDGE_PACKAGING_KMS_KEY_ID_PATH, | ||
sagemaker_session=self.sagemaker_session, | ||
) | ||
role = resolve_value_from_config( | ||
role, EDGE_PACKAGING_ROLE_ARN_PATH, sagemaker_session=self.sagemaker_session | ||
) | ||
resource_key = resolve_value_from_config( | ||
resource_key, EDGE_PACKAGING_RESOURCE_KEY_PATH, sagemaker_session=self.sagemaker_session | ||
) | ||
rubanh marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if role is not None: | ||
role = self.sagemaker_session.expand_role(role) | ||
config = self._edge_packaging_job_config( | ||
|
@@ -1065,7 +1097,9 @@ def compile( | |
|
||
self._init_sagemaker_session_if_does_not_exist(target_instance_family) | ||
role = resolve_value_from_config( | ||
role, COMPILATION_JOB_ROLE_ARN_PATH, sagemaker_session=self.sagemaker_session | ||
role, | ||
COMPILATION_JOB_ROLE_ARN_PATH, | ||
sagemaker_session=self.sagemaker_session, | ||
) | ||
if not role: | ||
# Originally IAM role was a required parameter. | ||
|
@@ -1232,10 +1266,14 @@ def deploy( | |
self._init_sagemaker_session_if_does_not_exist(instance_type) | ||
# Depending on the instance type, a local session (or) a session is initialized. | ||
self.role = resolve_value_from_config( | ||
self.role, MODEL_EXECUTION_ROLE_ARN_PATH, sagemaker_session=self.sagemaker_session | ||
self.role, | ||
MODEL_EXECUTION_ROLE_ARN_PATH, | ||
sagemaker_session=self.sagemaker_session, | ||
) | ||
self.vpc_config = resolve_value_from_config( | ||
self.vpc_config, MODEL_VPC_CONFIG_PATH, sagemaker_session=self.sagemaker_session | ||
self.vpc_config, | ||
MODEL_VPC_CONFIG_PATH, | ||
sagemaker_session=self.sagemaker_session, | ||
) | ||
self._enable_network_isolation = resolve_value_from_config( | ||
self._enable_network_isolation, | ||
|
@@ -1244,7 +1282,9 @@ def deploy( | |
) | ||
|
||
tags = add_jumpstart_tags( | ||
tags=tags, inference_model_uri=self.model_data, inference_script_uri=self.source_dir | ||
tags=tags, | ||
inference_model_uri=self.model_data, | ||
inference_script_uri=self.source_dir, | ||
) | ||
|
||
if self.role is None: | ||
|
@@ -1292,7 +1332,9 @@ def deploy( | |
compiled_model_suffix = None if is_serverless else "-".join(instance_type.split(".")[:-1]) | ||
if self._is_compiled_model and not is_serverless: | ||
self._ensure_base_name_if_needed( | ||
image_uri=self.image_uri, script_uri=self.source_dir, model_uri=self.model_data | ||
image_uri=self.image_uri, | ||
script_uri=self.source_dir, | ||
model_uri=self.model_data, | ||
) | ||
if self._base_name is not None: | ||
self._base_name = "-".join((self._base_name, compiled_model_suffix)) | ||
|
@@ -1673,7 +1715,12 @@ class ModelPackage(Model): | |
"""A SageMaker ``Model`` that can be deployed to an ``Endpoint``.""" | ||
|
||
def __init__( | ||
self, role=None, model_data=None, algorithm_arn=None, model_package_arn=None, **kwargs | ||
self, | ||
role=None, | ||
model_data=None, | ||
algorithm_arn=None, | ||
model_package_arn=None, | ||
**kwargs, | ||
): | ||
"""Initialize a SageMaker ModelPackage. | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
/bot run pr