Skip to content

feature: adding resourcekey and tags for api in config for SDK defaults #3915

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/sagemaker/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
COMPILATION_JOB_VPC_CONFIG_PATH,
COMPILATION_JOB,
EDGE_PACKAGING_ROLE_ARN_PATH,
EDGE_PACKAGING_RESOURCE_KEY_PATH,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

/bot run pr

EDGE_PACKAGING_OUTPUT_CONFIG_PATH,
EDGE_PACKAGING_JOB,
TRANSFORM_JOB,
Expand All @@ -69,10 +70,13 @@
MODEL_PRIMARY_CONTAINER_ENVIRONMENT_PATH,
ENDPOINT_CONFIG_PRODUCTION_VARIANTS_PATH,
KMS_KEY_ID,
RESOURCE_KEY,
ENDPOINT_CONFIG_KMS_KEY_ID_PATH,
ENDPOINT_CONFIG,
ENDPOINT_CONFIG_DATA_CAPTURE_PATH,
ENDPOINT_CONFIG_ASYNC_INFERENCE_PATH,
ENDPOINT,
ENDPOINT_TAGS_PATH,
SAGEMAKER,
FEATURE_GROUP,
TAGS,
Expand Down
34 changes: 30 additions & 4 deletions src/sagemaker/config/config_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
ENABLE_NETWORK_ISOLATION = "EnableNetworkIsolation"
VOLUME_KMS_KEY_ID = "VolumeKmsKeyId"
KMS_KEY_ID = "KmsKeyId"
RESOURCE_KEY = "ResourceKey"
ROLE_ARN = "RoleArn"
TAGS = "Tags"
KEY = "Key"
Expand Down Expand Up @@ -78,6 +79,7 @@
MODEL = "Model"
MONITORING_SCHEDULE = "MonitoringSchedule"
ENDPOINT_CONFIG = "EndpointConfig"
ENDPOINT = "Endpoint"
AUTO_ML_JOB = "AutoMLJob"
COMPILATION_JOB = "CompilationJob"
CUSTOM_PARAMETERS = "CustomParameters"
Expand Down Expand Up @@ -131,6 +133,7 @@ def _simple_path(*args: str):
)
EDGE_PACKAGING_OUTPUT_CONFIG_PATH = _simple_path(SAGEMAKER, EDGE_PACKAGING_JOB, OUTPUT_CONFIG)
EDGE_PACKAGING_ROLE_ARN_PATH = _simple_path(SAGEMAKER, EDGE_PACKAGING_JOB, ROLE_ARN)
EDGE_PACKAGING_RESOURCE_KEY_PATH = _simple_path(SAGEMAKER, EDGE_PACKAGING_JOB, RESOURCE_KEY)
ENDPOINT_CONFIG_DATA_CAPTURE_KMS_KEY_ID_PATH = _simple_path(
SAGEMAKER, ENDPOINT_CONFIG, DATA_CAPTURE_CONFIG, KMS_KEY_ID
)
Expand All @@ -145,6 +148,7 @@ def _simple_path(*args: str):
SAGEMAKER, ENDPOINT_CONFIG, ASYNC_INFERENCE_CONFIG, OUTPUT_CONFIG, KMS_KEY_ID
)
ENDPOINT_CONFIG_KMS_KEY_ID_PATH = _simple_path(SAGEMAKER, ENDPOINT_CONFIG, KMS_KEY_ID)
ENDPOINT_TAGS_PATH = _simple_path(SAGEMAKER, ENDPOINT, TAGS)
FEATURE_GROUP_ONLINE_STORE_CONFIG_PATH = _simple_path(SAGEMAKER, FEATURE_GROUP, ONLINE_STORE_CONFIG)
FEATURE_GROUP_OFFLINE_STORE_CONFIG_PATH = _simple_path(
SAGEMAKER, FEATURE_GROUP, OFFLINE_STORE_CONFIG
Expand All @@ -167,14 +171,20 @@ def _simple_path(*args: str):
)
AUTO_ML_JOB_CONFIG_PATH = _simple_path(SAGEMAKER, AUTO_ML_JOB, AUTO_ML_JOB_CONFIG)
MONITORING_JOB_DEFINITION_PREFIX = _simple_path(
SAGEMAKER, MONITORING_SCHEDULE, MONITORING_SCHEDULE_CONFIG, MONITORING_JOB_DEFINITION
SAGEMAKER,
MONITORING_SCHEDULE,
MONITORING_SCHEDULE_CONFIG,
MONITORING_JOB_DEFINITION,
)
MONITORING_JOB_ENVIRONMENT_PATH = _simple_path(MONITORING_JOB_DEFINITION_PREFIX, ENVIRONMENT)
MONITORING_JOB_OUTPUT_KMS_KEY_ID_PATH = _simple_path(
MONITORING_JOB_DEFINITION_PREFIX, MONITORING_OUTPUT_CONFIG, KMS_KEY_ID
)
MONITORING_JOB_VOLUME_KMS_KEY_ID_PATH = _simple_path(
MONITORING_JOB_DEFINITION_PREFIX, MONITORING_RESOURCES, CLUSTER_CONFIG, VOLUME_KMS_KEY_ID
MONITORING_JOB_DEFINITION_PREFIX,
MONITORING_RESOURCES,
CLUSTER_CONFIG,
VOLUME_KMS_KEY_ID,
)
MONITORING_JOB_NETWORK_CONFIG_PATH = _simple_path(MONITORING_JOB_DEFINITION_PREFIX, NETWORK_CONFIG)
MONITORING_JOB_ENABLE_NETWORK_ISOLATION_PATH = _simple_path(
Expand Down Expand Up @@ -288,7 +298,11 @@ def _simple_path(*args: str):
SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, VPC_CONFIG, SECURITY_GROUP_IDS
)
REMOTE_FUNCTION_ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION = _simple_path(
SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION
SAGEMAKER,
PYTHON_SDK,
MODULES,
REMOTE_FUNCTION,
ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION,
)
MONITORING_SCHEDULE_INTER_CONTAINER_ENCRYPTION_PATH = _simple_path(
SAGEMAKER,
Expand Down Expand Up @@ -468,7 +482,11 @@ def _simple_path(*args: str):
"maxProperties": 48,
},
# Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_S3DataSource.html#sagemaker-Type-S3DataSource-S3Uri
"s3Uri": {TYPE: "string", "pattern": "^(https|s3)://([^/]+)/?(.*)$", "maxLength": 1024},
"s3Uri": {
TYPE: "string",
"pattern": "^(https|s3)://([^/]+)/?(.*)$",
"maxLength": 1024,
},
# Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_AlgorithmSpecification.html#sagemaker-Type-AlgorithmSpecification-ContainerEntrypoint
"preExecutionCommand": {TYPE: "string", "pattern": r".*"},
# Regex based on https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_PipelineDefinitionS3Location.html
Expand Down Expand Up @@ -746,6 +764,13 @@ def _simple_path(*args: str):
TAGS: {"$ref": "#/definitions/tags"},
},
},
# Endpoint
# https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateEndpoint.html
ENDPOINT: {
TYPE: OBJECT,
ADDITIONAL_PROPERTIES: False,
PROPERTIES: {TAGS: {"$ref": "#/definitions/tags"}},
},
# Endpoint Config
# https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateEndpointConfig.html
# Note: there is a separate API for creating Endpoints.
Expand Down Expand Up @@ -992,6 +1017,7 @@ def _simple_path(*args: str):
ADDITIONAL_PROPERTIES: False,
PROPERTIES: {KMS_KEY_ID: {"$ref": "#/definitions/kmsKeyId"}},
},
RESOURCE_KEY: {"$ref": "#/definitions/kmsKeyId"},
ROLE_ARN: {"$ref": "#/definitions/roleArn"},
TAGS: {"$ref": "#/definitions/tags"},
},
Expand Down
79 changes: 63 additions & 16 deletions src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
EDGE_PACKAGING_KMS_KEY_ID_PATH,
EDGE_PACKAGING_ROLE_ARN_PATH,
MODEL_CONTAINERS_PATH,
EDGE_PACKAGING_RESOURCE_KEY_PATH,
MODEL_VPC_CONFIG_PATH,
MODEL_ENABLE_NETWORK_ISOLATION_PATH,
MODEL_EXECUTION_ROLE_ARN_PATH,
Expand All @@ -50,7 +51,10 @@
from sagemaker.predictor import PredictorBase
from sagemaker.serverless import ServerlessInferenceConfig
from sagemaker.transformer import Transformer
from sagemaker.jumpstart.utils import add_jumpstart_tags, get_jumpstart_base_name_if_jumpstart_model
from sagemaker.jumpstart.utils import (
add_jumpstart_tags,
get_jumpstart_base_name_if_jumpstart_model,
)
from sagemaker.utils import (
unique_name_from_base,
update_container_with_inference_params,
Expand All @@ -63,15 +67,25 @@
from sagemaker.workflow import is_pipeline_variable
from sagemaker.workflow.entities import PipelineVariable
from sagemaker.workflow.pipeline_context import runnable_by_pipeline, PipelineSession
from sagemaker.inference_recommender.inference_recommender_mixin import InferenceRecommenderMixin
from sagemaker.inference_recommender.inference_recommender_mixin import (
InferenceRecommenderMixin,
)

LOGGER = logging.getLogger("sagemaker")

NEO_ALLOWED_FRAMEWORKS = set(
["mxnet", "tensorflow", "keras", "pytorch", "onnx", "xgboost", "tflite"]
)

NEO_IOC_TARGET_DEVICES = ["ml_c4", "ml_c5", "ml_m4", "ml_m5", "ml_p2", "ml_p3", "ml_g4dn"]
NEO_IOC_TARGET_DEVICES = [
"ml_c4",
"ml_c5",
"ml_m4",
"ml_m5",
"ml_p2",
"ml_p3",
"ml_g4dn",
]

NEO_MULTIVERSION_UNSUPPORTED = [
"imx8mplus",
Expand Down Expand Up @@ -300,7 +314,9 @@ def __init__(
self._base_name = None
self.sagemaker_session = sagemaker_session
self.role = resolve_value_from_config(
role, MODEL_EXECUTION_ROLE_ARN_PATH, sagemaker_session=self.sagemaker_session
role,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

qq: did some configuration change recently for how we do spacing? Odd that non-touched things are changing

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

True - I think this is a result of tox -e black-format

MODEL_EXECUTION_ROLE_ARN_PATH,
sagemaker_session=self.sagemaker_session,
)
self.vpc_config = resolve_value_from_config(
vpc_config, MODEL_VPC_CONFIG_PATH, sagemaker_session=self.sagemaker_session
Expand Down Expand Up @@ -585,7 +601,9 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None:
local_code = utils.get_config_value("local.local_code", self.sagemaker_session.config)

bucket, key_prefix = s3.determine_bucket_and_prefix(
bucket=self.bucket, key_prefix=key_prefix, sagemaker_session=self.sagemaker_session
bucket=self.bucket,
key_prefix=key_prefix,
sagemaker_session=self.sagemaker_session,
)

if (self.sagemaker_session.local_mode and local_code) or self.entry_point is None:
Expand Down Expand Up @@ -633,7 +651,8 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None:
else:
repacked_model_data = "s3://" + "/".join([bucket, key_prefix, "model.tar.gz"])
self.uploaded_code = fw_utils.UploadedCode(
s3_prefix=repacked_model_data, script_name=os.path.basename(self.entry_point)
s3_prefix=repacked_model_data,
script_name=os.path.basename(self.entry_point),
)

LOGGER.info(
Expand Down Expand Up @@ -693,7 +712,11 @@ def enable_network_isolation(self):
return False if not self._enable_network_isolation else self._enable_network_isolation

def _create_sagemaker_model(
self, instance_type=None, accelerator_type=None, tags=None, serverless_inference_config=None
self,
instance_type=None,
accelerator_type=None,
tags=None,
serverless_inference_config=None,
):
"""Create a SageMaker Model Entity

Expand Down Expand Up @@ -734,10 +757,14 @@ def _create_sagemaker_model(
self._init_sagemaker_session_if_does_not_exist(instance_type)
# Depending on the instance type, a local session (or) a session is initialized.
self.role = resolve_value_from_config(
self.role, MODEL_EXECUTION_ROLE_ARN_PATH, sagemaker_session=self.sagemaker_session
self.role,
MODEL_EXECUTION_ROLE_ARN_PATH,
sagemaker_session=self.sagemaker_session,
)
self.vpc_config = resolve_value_from_config(
self.vpc_config, MODEL_VPC_CONFIG_PATH, sagemaker_session=self.sagemaker_session
self.vpc_config,
MODEL_VPC_CONFIG_PATH,
sagemaker_session=self.sagemaker_session,
)
self._enable_network_isolation = resolve_value_from_config(
self._enable_network_isolation,
Expand Down Expand Up @@ -955,11 +982,16 @@ def package_for_edge(
job_name = f"packaging{self._compilation_job_name[11:]}"
self._init_sagemaker_session_if_does_not_exist(None)
s3_kms_key = resolve_value_from_config(
s3_kms_key, EDGE_PACKAGING_KMS_KEY_ID_PATH, sagemaker_session=self.sagemaker_session
s3_kms_key,
EDGE_PACKAGING_KMS_KEY_ID_PATH,
sagemaker_session=self.sagemaker_session,
)
role = resolve_value_from_config(
role, EDGE_PACKAGING_ROLE_ARN_PATH, sagemaker_session=self.sagemaker_session
)
resource_key = resolve_value_from_config(
resource_key, EDGE_PACKAGING_RESOURCE_KEY_PATH, sagemaker_session=self.sagemaker_session
)
if role is not None:
role = self.sagemaker_session.expand_role(role)
config = self._edge_packaging_job_config(
Expand Down Expand Up @@ -1065,7 +1097,9 @@ def compile(

self._init_sagemaker_session_if_does_not_exist(target_instance_family)
role = resolve_value_from_config(
role, COMPILATION_JOB_ROLE_ARN_PATH, sagemaker_session=self.sagemaker_session
role,
COMPILATION_JOB_ROLE_ARN_PATH,
sagemaker_session=self.sagemaker_session,
)
if not role:
# Originally IAM role was a required parameter.
Expand Down Expand Up @@ -1232,10 +1266,14 @@ def deploy(
self._init_sagemaker_session_if_does_not_exist(instance_type)
# Depending on the instance type, a local session (or) a session is initialized.
self.role = resolve_value_from_config(
self.role, MODEL_EXECUTION_ROLE_ARN_PATH, sagemaker_session=self.sagemaker_session
self.role,
MODEL_EXECUTION_ROLE_ARN_PATH,
sagemaker_session=self.sagemaker_session,
)
self.vpc_config = resolve_value_from_config(
self.vpc_config, MODEL_VPC_CONFIG_PATH, sagemaker_session=self.sagemaker_session
self.vpc_config,
MODEL_VPC_CONFIG_PATH,
sagemaker_session=self.sagemaker_session,
)
self._enable_network_isolation = resolve_value_from_config(
self._enable_network_isolation,
Expand All @@ -1244,7 +1282,9 @@ def deploy(
)

tags = add_jumpstart_tags(
tags=tags, inference_model_uri=self.model_data, inference_script_uri=self.source_dir
tags=tags,
inference_model_uri=self.model_data,
inference_script_uri=self.source_dir,
)

if self.role is None:
Expand Down Expand Up @@ -1292,7 +1332,9 @@ def deploy(
compiled_model_suffix = None if is_serverless else "-".join(instance_type.split(".")[:-1])
if self._is_compiled_model and not is_serverless:
self._ensure_base_name_if_needed(
image_uri=self.image_uri, script_uri=self.source_dir, model_uri=self.model_data
image_uri=self.image_uri,
script_uri=self.source_dir,
model_uri=self.model_data,
)
if self._base_name is not None:
self._base_name = "-".join((self._base_name, compiled_model_suffix))
Expand Down Expand Up @@ -1673,7 +1715,12 @@ class ModelPackage(Model):
"""A SageMaker ``Model`` that can be deployed to an ``Endpoint``."""

def __init__(
self, role=None, model_data=None, algorithm_arn=None, model_package_arn=None, **kwargs
self,
role=None,
model_data=None,
algorithm_arn=None,
model_package_arn=None,
**kwargs,
):
"""Initialize a SageMaker ModelPackage.

Expand Down
Loading