Skip to content

Commit ba7e6a9

Browse files
authored
chore: telemetry for deployment configs (#4806)
* chore: telemetry for deployment configs
* chore: minor fixes
* chore: address minor issues
* fix: flake8
* fix: model type for estimator
* chore: add ranking name argument to get_top_ranked_config_name
* chore: use named args
* fix: remove tuple from model type
* chore: add comment explaining test
1 parent 8f9a34d commit ba7e6a9

File tree

7 files changed

+196
-79
lines changed

7 files changed

+196
-79
lines changed

src/sagemaker/jumpstart/factory/estimator.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@
6969
add_jumpstart_model_info_tags,
7070
get_eula_message,
7171
get_default_jumpstart_session_with_user_agent_suffix,
72+
get_top_ranked_config_name,
7273
update_dict_if_key_not_present,
7374
resolve_estimator_sagemaker_config_field,
7475
verify_model_region_and_return_specs,
@@ -204,7 +205,9 @@ def get_init_kwargs(
204205

205206
estimator_init_kwargs = _add_model_version_to_kwargs(estimator_init_kwargs)
206207
estimator_init_kwargs = _add_vulnerable_and_deprecated_status_to_kwargs(estimator_init_kwargs)
207-
estimator_init_kwargs = _add_sagemaker_session_to_kwargs(estimator_init_kwargs)
208+
estimator_init_kwargs = _add_sagemaker_session_with_custom_user_agent_to_kwargs(
209+
estimator_init_kwargs
210+
)
208211
estimator_init_kwargs = _add_region_to_kwargs(estimator_init_kwargs)
209212
estimator_init_kwargs = _add_instance_type_and_count_to_kwargs(estimator_init_kwargs)
210213
estimator_init_kwargs = _add_image_uri_to_kwargs(estimator_init_kwargs)
@@ -438,12 +441,17 @@ def _add_region_to_kwargs(kwargs: JumpStartKwargs) -> JumpStartKwargs:
438441
return kwargs
439442

440443

441-
def _add_sagemaker_session_to_kwargs(kwargs: JumpStartKwargs) -> JumpStartKwargs:
444+
def _add_sagemaker_session_with_custom_user_agent_to_kwargs(
445+
kwargs: JumpStartKwargs,
446+
) -> JumpStartKwargs:
442447
"""Sets session in kwargs based on default or override, returns full kwargs."""
443448
kwargs.sagemaker_session = (
444449
kwargs.sagemaker_session
445450
or get_default_jumpstart_session_with_user_agent_suffix(
446-
kwargs.model_id, kwargs.model_version, kwargs.hub_arn
451+
model_id=kwargs.model_id,
452+
model_version=kwargs.model_version,
453+
config_name=None,
454+
is_hub_content=kwargs.hub_arn is not None,
447455
)
448456
)
449457
return kwargs
@@ -903,21 +911,16 @@ def _add_config_name_to_kwargs(
903911
) -> JumpStartEstimatorInitKwargs:
904912
"""Sets tags in kwargs based on default or override, returns full kwargs."""
905913

906-
specs = verify_model_region_and_return_specs(
914+
kwargs.config_name = kwargs.config_name or get_top_ranked_config_name(
915+
region=kwargs.region,
907916
model_id=kwargs.model_id,
908-
version=kwargs.model_version,
917+
model_version=kwargs.model_version,
918+
sagemaker_session=kwargs.sagemaker_session,
909919
scope=JumpStartScriptScope.TRAINING,
910-
region=kwargs.region,
911-
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
920+
model_type=kwargs.model_type,
912921
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
913-
sagemaker_session=kwargs.sagemaker_session,
914-
config_name=kwargs.config_name,
922+
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
915923
hub_arn=kwargs.hub_arn,
916924
)
917925

918-
if specs.training_configs and specs.training_configs.get_top_config_from_ranking():
919-
kwargs.config_name = (
920-
kwargs.config_name or specs.training_configs.get_top_config_from_ranking().config_name
921-
)
922-
923926
return kwargs

src/sagemaker/jumpstart/factory/model.py

Lines changed: 84 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
)
3030
from sagemaker.jumpstart.artifacts.resource_names import _retrieve_resource_name_base
3131
from sagemaker.jumpstart.constants import (
32+
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
3233
INFERENCE_ENTRY_POINT_SCRIPT_NAME,
3334
JUMPSTART_DEFAULT_REGION_NAME,
3435
JUMPSTART_LOGGER,
@@ -54,6 +55,7 @@
5455
add_jumpstart_model_info_tags,
5556
get_default_jumpstart_session_with_user_agent_suffix,
5657
get_neo_content_bucket,
58+
get_top_ranked_config_name,
5759
update_dict_if_key_not_present,
5860
resolve_model_sagemaker_config_field,
5961
verify_model_region_and_return_specs,
@@ -155,15 +157,18 @@ def _add_region_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelIni
155157
return kwargs
156158

157159

158-
def _add_sagemaker_session_to_kwargs(
160+
def _add_sagemaker_session_with_custom_user_agent_to_kwargs(
159161
kwargs: Union[JumpStartModelInitKwargs, JumpStartModelDeployKwargs]
160162
) -> JumpStartModelInitKwargs:
161163
"""Sets session in kwargs based on default or override, returns full kwargs."""
162164

163165
kwargs.sagemaker_session = (
164166
kwargs.sagemaker_session
165167
or get_default_jumpstart_session_with_user_agent_suffix(
166-
kwargs.model_id, kwargs.model_version, kwargs.hub_arn
168+
model_id=kwargs.model_id,
169+
model_version=kwargs.model_version,
170+
config_name=kwargs.config_name,
171+
is_hub_content=kwargs.hub_arn is not None,
167172
)
168173
)
169174

@@ -244,6 +249,32 @@ def _add_instance_type_to_kwargs(
244249
kwargs.instance_type,
245250
)
246251

252+
specs = verify_model_region_and_return_specs(
253+
model_id=kwargs.model_id,
254+
version=kwargs.model_version,
255+
scope=JumpStartScriptScope.INFERENCE,
256+
region=kwargs.region,
257+
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
258+
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
259+
sagemaker_session=kwargs.sagemaker_session,
260+
model_type=kwargs.model_type,
261+
config_name=kwargs.config_name,
262+
)
263+
264+
if specs.inference_configs and kwargs.config_name not in specs.inference_configs.configs:
265+
return kwargs
266+
267+
resolved_config = (
268+
specs.inference_configs.configs[kwargs.config_name].resolved_config
269+
if specs.inference_configs
270+
else None
271+
)
272+
if resolved_config is None:
273+
return kwargs
274+
supported_instance_types = resolved_config.get("supported_inference_instance_types", [])
275+
if kwargs.instance_type not in supported_instance_types:
276+
JUMPSTART_LOGGER.warning("Overriding instance type to %s", kwargs.instance_type)
277+
247278
return kwargs
248279

249280

@@ -662,38 +693,25 @@ def _add_config_name_to_init_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSta
662693
ValueError: If the instance_type is not supported with the current config.
663694
"""
664695

665-
specs = verify_model_region_and_return_specs(
696+
# we need to create a default JS session (without custom user agent)
697+
# in order to retrieve config name info
698+
temp_session = kwargs.sagemaker_session or DEFAULT_JUMPSTART_SAGEMAKER_SESSION
699+
700+
kwargs.config_name = kwargs.config_name or get_top_ranked_config_name(
701+
region=kwargs.region,
666702
model_id=kwargs.model_id,
667-
version=kwargs.model_version,
703+
model_version=kwargs.model_version,
704+
sagemaker_session=temp_session,
668705
scope=JumpStartScriptScope.INFERENCE,
669-
region=kwargs.region,
670-
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
671-
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
672-
sagemaker_session=kwargs.sagemaker_session,
673706
model_type=kwargs.model_type,
674-
config_name=kwargs.config_name,
707+
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
708+
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
675709
hub_arn=kwargs.hub_arn,
676710
)
677-
if specs.inference_configs:
678-
default_config_name = (
679-
specs.inference_configs.get_top_config_from_ranking().config_name
680-
if specs.inference_configs.get_top_config_from_ranking()
681-
else None
682-
)
683-
kwargs.config_name = kwargs.config_name or default_config_name
684-
685-
if not kwargs.config_name:
686-
return kwargs
687711

688-
if kwargs.config_name not in set(specs.inference_configs.configs.keys()):
689-
raise ValueError(
690-
f"Config {kwargs.config_name} is not supported for model {kwargs.model_id}."
691-
)
712+
if kwargs.config_name is None:
713+
return kwargs
692714

693-
resolved_config = specs.inference_configs.configs[kwargs.config_name].resolved_config
694-
supported_instance_types = resolved_config.get("supported_inference_instance_types", [])
695-
if kwargs.instance_type not in supported_instance_types:
696-
JUMPSTART_LOGGER.warning("Overriding instance type to %s", kwargs.instance_type)
697715
return kwargs
698716

699717

@@ -746,32 +764,41 @@ def _add_config_name_to_deploy_kwargs(
746764
ValueError: If the instance_type is not supported with the current config.
747765
"""
748766

749-
specs = verify_model_region_and_return_specs(
750-
model_id=kwargs.model_id,
751-
version=kwargs.model_version,
752-
scope=JumpStartScriptScope.INFERENCE,
753-
region=kwargs.region,
754-
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
755-
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
756-
sagemaker_session=kwargs.sagemaker_session,
757-
model_type=kwargs.model_type,
758-
config_name=kwargs.config_name,
759-
hub_arn=kwargs.hub_arn,
760-
)
767+
# we need to create a default JS session (without custom user agent)
768+
# in order to retrieve config name info
769+
temp_session = kwargs.sagemaker_session or DEFAULT_JUMPSTART_SAGEMAKER_SESSION
761770

762771
if training_config_name:
763-
kwargs.config_name = _select_inference_config_from_training_config(
772+
773+
specs = verify_model_region_and_return_specs(
774+
model_id=kwargs.model_id,
775+
version=kwargs.model_version,
776+
scope=JumpStartScriptScope.INFERENCE,
777+
region=kwargs.region,
778+
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
779+
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
780+
sagemaker_session=temp_session,
781+
model_type=kwargs.model_type,
782+
config_name=kwargs.config_name,
783+
)
784+
default_config_name = _select_inference_config_from_training_config(
764785
specs=specs, training_config_name=training_config_name
765786
)
766-
return kwargs
767787

768-
if specs.inference_configs:
769-
default_config_name = (
770-
specs.inference_configs.get_top_config_from_ranking().config_name
771-
if specs.inference_configs.get_top_config_from_ranking()
772-
else None
788+
else:
789+
default_config_name = get_top_ranked_config_name(
790+
region=kwargs.region,
791+
model_id=kwargs.model_id,
792+
model_version=kwargs.model_version,
793+
sagemaker_session=temp_session,
794+
scope=JumpStartScriptScope.INFERENCE,
795+
model_type=kwargs.model_type,
796+
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
797+
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
798+
hub_arn=kwargs.hub_arn,
773799
)
774-
kwargs.config_name = kwargs.config_name or default_config_name
800+
801+
kwargs.config_name = kwargs.config_name or default_config_name
775802

776803
return kwargs
777804

@@ -850,15 +877,15 @@ def get_deploy_kwargs(
850877
routing_config=routing_config,
851878
)
852879

853-
deploy_kwargs = _add_sagemaker_session_to_kwargs(kwargs=deploy_kwargs)
880+
deploy_kwargs = _add_config_name_to_deploy_kwargs(
881+
kwargs=deploy_kwargs, training_config_name=training_config_name
882+
)
854883

855884
deploy_kwargs = _add_model_version_to_kwargs(kwargs=deploy_kwargs)
856885

857-
deploy_kwargs = _add_endpoint_name_to_kwargs(kwargs=deploy_kwargs)
886+
deploy_kwargs = _add_sagemaker_session_with_custom_user_agent_to_kwargs(kwargs=deploy_kwargs)
858887

859-
deploy_kwargs = _add_config_name_to_deploy_kwargs(
860-
kwargs=deploy_kwargs, training_config_name=training_config_name
861-
)
888+
deploy_kwargs = _add_endpoint_name_to_kwargs(kwargs=deploy_kwargs)
862889

863890
deploy_kwargs = _add_instance_type_to_kwargs(kwargs=deploy_kwargs)
864891

@@ -1041,11 +1068,14 @@ def get_init_kwargs(
10411068
)
10421069

10431070
model_init_kwargs = _add_vulnerable_and_deprecated_status_to_kwargs(kwargs=model_init_kwargs)
1071+
model_init_kwargs = _add_model_version_to_kwargs(kwargs=model_init_kwargs)
1072+
model_init_kwargs = _add_config_name_to_init_kwargs(kwargs=model_init_kwargs)
10441073

1045-
model_init_kwargs = _add_sagemaker_session_to_kwargs(kwargs=model_init_kwargs)
1074+
model_init_kwargs = _add_sagemaker_session_with_custom_user_agent_to_kwargs(
1075+
kwargs=model_init_kwargs
1076+
)
10461077
model_init_kwargs = _add_region_to_kwargs(kwargs=model_init_kwargs)
10471078

1048-
model_init_kwargs = _add_model_version_to_kwargs(kwargs=model_init_kwargs)
10491079
model_init_kwargs = _add_model_name_to_kwargs(kwargs=model_init_kwargs)
10501080

10511081
model_init_kwargs = _add_instance_type_to_kwargs(
@@ -1073,8 +1103,6 @@ def get_init_kwargs(
10731103

10741104
model_init_kwargs = _add_resources_to_kwargs(kwargs=model_init_kwargs)
10751105

1076-
model_init_kwargs = _add_config_name_to_init_kwargs(kwargs=model_init_kwargs)
1077-
10781106
model_init_kwargs = _add_additional_model_data_sources_to_kwargs(kwargs=model_init_kwargs)
10791107

10801108
return model_init_kwargs

src/sagemaker/jumpstart/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2458,7 +2458,7 @@ def __init__(
24582458
self.model_id = model_id
24592459
self.model_version = model_version
24602460
self.hub_arn = hub_arn
2461-
self.model_type = (model_type,)
2461+
self.model_type = model_type
24622462
self.instance_type = instance_type
24632463
self.instance_count = instance_count
24642464
self.region = region

src/sagemaker/jumpstart/utils.py

Lines changed: 61 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1111,11 +1111,15 @@ def get_jumpstart_configs(
11111111

11121112

11131113
def get_jumpstart_user_agent_extra_suffix(
1114-
model_id: Optional[str], model_version: Optional[str], is_hub_content: Optional[bool]
1114+
model_id: Optional[str],
1115+
model_version: Optional[str],
1116+
config_name: Optional[str],
1117+
is_hub_content: Optional[bool],
11151118
) -> str:
11161119
"""Returns the model-specific user agent string to be added to requests."""
11171120
sagemaker_python_sdk_headers = get_user_agent_extra_suffix()
11181121
jumpstart_specific_suffix = f"md/js_model_id#{model_id} md/js_model_ver#{model_version}"
1122+
config_specific_suffix = f"md/js_config#{config_name}"
11191123
hub_specific_suffix = f"md/js_is_hub_content#{is_hub_content}"
11201124

11211125
if os.getenv(constants.ENV_VARIABLE_DISABLE_JUMPSTART_TELEMETRY, None):
@@ -1130,19 +1134,74 @@ def get_jumpstart_user_agent_extra_suffix(
11301134
else:
11311135
headers = f"{sagemaker_python_sdk_headers} {jumpstart_specific_suffix}"
11321136

1137+
if config_name:
1138+
headers = f"{headers} {config_specific_suffix}"
1139+
11331140
return headers
11341141

11351142

1143+
def get_top_ranked_config_name(
    region: str,
    model_id: str,
    model_version: str,
    sagemaker_session: Optional[Session] = constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
    scope: enums.JumpStartScriptScope = enums.JumpStartScriptScope.INFERENCE,
    model_type: enums.JumpStartModelType = enums.JumpStartModelType.OPEN_WEIGHTS,
    tolerate_deprecated_model: bool = False,
    tolerate_vulnerable_model: bool = False,
    hub_arn: Optional[str] = None,
    ranking_name: enums.JumpStartConfigRankingName = enums.JumpStartConfigRankingName.DEFAULT,
) -> Optional[str]:
    """Returns the top ranked config name for the given model ID and region.

    The model specs are fetched with ``verify_model_region_and_return_specs``
    and the config collection matching ``scope`` (inference or training) is
    asked for its top entry under ``ranking_name``.

    Args:
        region: Region forwarded to the specs lookup.
        model_id: JumpStart model ID forwarded to the specs lookup.
        model_version: Model version forwarded to the specs lookup.
        sagemaker_session: Session forwarded to the specs lookup.
        scope: Selects whether inference or training configs are consulted.
        model_type: Model type forwarded to the specs lookup.
        tolerate_deprecated_model: Forwarded to the specs lookup.
        tolerate_vulnerable_model: Forwarded to the specs lookup.
        hub_arn: Hub ARN forwarded to the specs lookup.
        ranking_name: Ranking passed to ``get_top_config_from_ranking``.

    Returns:
        The ``config_name`` of the top ranked config, or ``None`` when the
        specs carry no configs for ``scope`` or the ranking yields no config.

    Raises:
        ValueError: If the script scope is not supported by JumpStart.
    """
    model_specs = verify_model_region_and_return_specs(
        model_id=model_id,
        version=model_version,
        scope=scope,
        region=region,
        hub_arn=hub_arn,
        tolerate_vulnerable_model=tolerate_vulnerable_model,
        tolerate_deprecated_model=tolerate_deprecated_model,
        sagemaker_session=sagemaker_session,
        model_type=model_type,
    )

    if scope == enums.JumpStartScriptScope.INFERENCE:
        configs = model_specs.inference_configs
    elif scope == enums.JumpStartScriptScope.TRAINING:
        configs = model_specs.training_configs
    else:
        raise ValueError(f"Unsupported script scope: {scope}.")

    if not configs:
        return None

    # The ranking lookup may come back empty; the call sites this helper
    # replaces all guarded against that, so return None instead of raising
    # AttributeError on `.config_name`.
    top_config = configs.get_top_config_from_ranking(ranking_name=ranking_name)
    return top_config.config_name if top_config else None
1189+
1190+
11361191
def get_default_jumpstart_session_with_user_agent_suffix(
11371192
model_id: Optional[str] = None,
11381193
model_version: Optional[str] = None,
1194+
config_name: Optional[str] = None,
11391195
is_hub_content: Optional[bool] = False,
11401196
) -> Session:
11411197
"""Returns default JumpStart SageMaker Session with model-specific user agent suffix."""
11421198
botocore_session = botocore.session.get_session()
11431199
botocore_config = botocore.config.Config(
11441200
user_agent_extra=get_jumpstart_user_agent_extra_suffix(
1145-
model_id, model_version, is_hub_content
1201+
model_id=model_id,
1202+
model_version=model_version,
1203+
config_name=config_name,
1204+
is_hub_content=is_hub_content,
11461205
),
11471206
)
11481207
botocore_session.set_default_client_config(botocore_config)

0 commit comments

Comments
 (0)