|
29 | 29 | )
|
30 | 30 | from sagemaker.jumpstart.artifacts.resource_names import _retrieve_resource_name_base
|
31 | 31 | from sagemaker.jumpstart.constants import (
|
| 32 | + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, |
32 | 33 | INFERENCE_ENTRY_POINT_SCRIPT_NAME,
|
33 | 34 | JUMPSTART_DEFAULT_REGION_NAME,
|
34 | 35 | JUMPSTART_LOGGER,
|
|
54 | 55 | add_jumpstart_model_info_tags,
|
55 | 56 | get_default_jumpstart_session_with_user_agent_suffix,
|
56 | 57 | get_neo_content_bucket,
|
| 58 | + get_top_ranked_config_name, |
57 | 59 | update_dict_if_key_not_present,
|
58 | 60 | resolve_model_sagemaker_config_field,
|
59 | 61 | verify_model_region_and_return_specs,
|
@@ -155,15 +157,18 @@ def _add_region_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModelIni
|
155 | 157 | return kwargs
|
156 | 158 |
|
157 | 159 |
|
158 |
| -def _add_sagemaker_session_to_kwargs( |
| 160 | +def _add_sagemaker_session_with_custom_user_agent_to_kwargs( |
159 | 161 | kwargs: Union[JumpStartModelInitKwargs, JumpStartModelDeployKwargs]
|
160 | 162 | ) -> JumpStartModelInitKwargs:
|
161 | 163 | """Sets session in kwargs based on default or override, returns full kwargs."""
|
162 | 164 |
|
163 | 165 | kwargs.sagemaker_session = (
|
164 | 166 | kwargs.sagemaker_session
|
165 | 167 | or get_default_jumpstart_session_with_user_agent_suffix(
|
166 |
| - kwargs.model_id, kwargs.model_version, kwargs.hub_arn |
| 168 | + model_id=kwargs.model_id, |
| 169 | + model_version=kwargs.model_version, |
| 170 | + config_name=kwargs.config_name, |
| 171 | + is_hub_content=kwargs.hub_arn is not None, |
167 | 172 | )
|
168 | 173 | )
|
169 | 174 |
|
@@ -244,6 +249,32 @@ def _add_instance_type_to_kwargs(
|
244 | 249 | kwargs.instance_type,
|
245 | 250 | )
|
246 | 251 |
|
| 252 | + specs = verify_model_region_and_return_specs( |
| 253 | + model_id=kwargs.model_id, |
| 254 | + version=kwargs.model_version, |
| 255 | + scope=JumpStartScriptScope.INFERENCE, |
| 256 | + region=kwargs.region, |
| 257 | + tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, |
| 258 | + tolerate_deprecated_model=kwargs.tolerate_deprecated_model, |
| 259 | + sagemaker_session=kwargs.sagemaker_session, |
| 260 | + model_type=kwargs.model_type, |
| 261 | + config_name=kwargs.config_name, |
| 262 | + ) |
| 263 | + |
| 264 | + if specs.inference_configs and kwargs.config_name not in specs.inference_configs.configs: |
| 265 | + return kwargs |
| 266 | + |
| 267 | + resolved_config = ( |
| 268 | + specs.inference_configs.configs[kwargs.config_name].resolved_config |
| 269 | + if specs.inference_configs |
| 270 | + else None |
| 271 | + ) |
| 272 | + if resolved_config is None: |
| 273 | + return kwargs |
| 274 | + supported_instance_types = resolved_config.get("supported_inference_instance_types", []) |
| 275 | + if kwargs.instance_type not in supported_instance_types: |
| 276 | + JUMPSTART_LOGGER.warning("Overriding instance type to %s", kwargs.instance_type) |
| 277 | + |
247 | 278 | return kwargs
|
248 | 279 |
|
249 | 280 |
|
@@ -662,38 +693,25 @@ def _add_config_name_to_init_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpSta
|
662 | 693 | ValueError: If the instance_type is not supported with the current config.
|
663 | 694 | """
|
664 | 695 |
|
665 |
| - specs = verify_model_region_and_return_specs( |
| 696 | + # we need to create a default JS session (without custom user agent) |
| 697 | + # in order to retrieve config name info |
| 698 | + temp_session = kwargs.sagemaker_session or DEFAULT_JUMPSTART_SAGEMAKER_SESSION |
| 699 | + |
| 700 | + kwargs.config_name = kwargs.config_name or get_top_ranked_config_name( |
| 701 | + region=kwargs.region, |
666 | 702 | model_id=kwargs.model_id,
|
667 |
| - version=kwargs.model_version, |
| 703 | + model_version=kwargs.model_version, |
| 704 | + sagemaker_session=temp_session, |
668 | 705 | scope=JumpStartScriptScope.INFERENCE,
|
669 |
| - region=kwargs.region, |
670 |
| - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, |
671 |
| - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, |
672 |
| - sagemaker_session=kwargs.sagemaker_session, |
673 | 706 | model_type=kwargs.model_type,
|
674 |
| - config_name=kwargs.config_name, |
| 707 | + tolerate_deprecated_model=kwargs.tolerate_deprecated_model, |
| 708 | + tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, |
675 | 709 | hub_arn=kwargs.hub_arn,
|
676 | 710 | )
|
677 |
| - if specs.inference_configs: |
678 |
| - default_config_name = ( |
679 |
| - specs.inference_configs.get_top_config_from_ranking().config_name |
680 |
| - if specs.inference_configs.get_top_config_from_ranking() |
681 |
| - else None |
682 |
| - ) |
683 |
| - kwargs.config_name = kwargs.config_name or default_config_name |
684 |
| - |
685 |
| - if not kwargs.config_name: |
686 |
| - return kwargs |
687 | 711 |
|
688 |
| - if kwargs.config_name not in set(specs.inference_configs.configs.keys()): |
689 |
| - raise ValueError( |
690 |
| - f"Config {kwargs.config_name} is not supported for model {kwargs.model_id}." |
691 |
| - ) |
| 712 | + if kwargs.config_name is None: |
| 713 | + return kwargs |
692 | 714 |
|
693 |
| - resolved_config = specs.inference_configs.configs[kwargs.config_name].resolved_config |
694 |
| - supported_instance_types = resolved_config.get("supported_inference_instance_types", []) |
695 |
| - if kwargs.instance_type not in supported_instance_types: |
696 |
| - JUMPSTART_LOGGER.warning("Overriding instance type to %s", kwargs.instance_type) |
697 | 715 | return kwargs
|
698 | 716 |
|
699 | 717 |
|
@@ -746,32 +764,41 @@ def _add_config_name_to_deploy_kwargs(
|
746 | 764 | ValueError: If the instance_type is not supported with the current config.
|
747 | 765 | """
|
748 | 766 |
|
749 |
| - specs = verify_model_region_and_return_specs( |
750 |
| - model_id=kwargs.model_id, |
751 |
| - version=kwargs.model_version, |
752 |
| - scope=JumpStartScriptScope.INFERENCE, |
753 |
| - region=kwargs.region, |
754 |
| - tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, |
755 |
| - tolerate_deprecated_model=kwargs.tolerate_deprecated_model, |
756 |
| - sagemaker_session=kwargs.sagemaker_session, |
757 |
| - model_type=kwargs.model_type, |
758 |
| - config_name=kwargs.config_name, |
759 |
| - hub_arn=kwargs.hub_arn, |
760 |
| - ) |
| 767 | + # we need to create a default JS session (without custom user agent) |
| 768 | + # in order to retrieve config name info |
| 769 | + temp_session = kwargs.sagemaker_session or DEFAULT_JUMPSTART_SAGEMAKER_SESSION |
761 | 770 |
|
762 | 771 | if training_config_name:
|
763 |
| - kwargs.config_name = _select_inference_config_from_training_config( |
| 772 | + |
| 773 | + specs = verify_model_region_and_return_specs( |
| 774 | + model_id=kwargs.model_id, |
| 775 | + version=kwargs.model_version, |
| 776 | + scope=JumpStartScriptScope.INFERENCE, |
| 777 | + region=kwargs.region, |
| 778 | + tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, |
| 779 | + tolerate_deprecated_model=kwargs.tolerate_deprecated_model, |
| 780 | + sagemaker_session=temp_session, |
| 781 | + model_type=kwargs.model_type, |
| 782 | + config_name=kwargs.config_name, |
| 783 | + ) |
| 784 | + default_config_name = _select_inference_config_from_training_config( |
764 | 785 | specs=specs, training_config_name=training_config_name
|
765 | 786 | )
|
766 |
| - return kwargs |
767 | 787 |
|
768 |
| - if specs.inference_configs: |
769 |
| - default_config_name = ( |
770 |
| - specs.inference_configs.get_top_config_from_ranking().config_name |
771 |
| - if specs.inference_configs.get_top_config_from_ranking() |
772 |
| - else None |
| 788 | + else: |
| 789 | + default_config_name = get_top_ranked_config_name( |
| 790 | + region=kwargs.region, |
| 791 | + model_id=kwargs.model_id, |
| 792 | + model_version=kwargs.model_version, |
| 793 | + sagemaker_session=temp_session, |
| 794 | + scope=JumpStartScriptScope.INFERENCE, |
| 795 | + model_type=kwargs.model_type, |
| 796 | + tolerate_deprecated_model=kwargs.tolerate_deprecated_model, |
| 797 | + tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model, |
| 798 | + hub_arn=kwargs.hub_arn, |
773 | 799 | )
|
774 |
| - kwargs.config_name = kwargs.config_name or default_config_name |
| 800 | + |
| 801 | + kwargs.config_name = kwargs.config_name or default_config_name |
775 | 802 |
|
776 | 803 | return kwargs
|
777 | 804 |
|
@@ -850,15 +877,15 @@ def get_deploy_kwargs(
|
850 | 877 | routing_config=routing_config,
|
851 | 878 | )
|
852 | 879 |
|
853 |
| - deploy_kwargs = _add_sagemaker_session_to_kwargs(kwargs=deploy_kwargs) |
| 880 | + deploy_kwargs = _add_config_name_to_deploy_kwargs( |
| 881 | + kwargs=deploy_kwargs, training_config_name=training_config_name |
| 882 | + ) |
854 | 883 |
|
855 | 884 | deploy_kwargs = _add_model_version_to_kwargs(kwargs=deploy_kwargs)
|
856 | 885 |
|
857 |
| - deploy_kwargs = _add_endpoint_name_to_kwargs(kwargs=deploy_kwargs) |
| 886 | + deploy_kwargs = _add_sagemaker_session_with_custom_user_agent_to_kwargs(kwargs=deploy_kwargs) |
858 | 887 |
|
859 |
| - deploy_kwargs = _add_config_name_to_deploy_kwargs( |
860 |
| - kwargs=deploy_kwargs, training_config_name=training_config_name |
861 |
| - ) |
| 888 | + deploy_kwargs = _add_endpoint_name_to_kwargs(kwargs=deploy_kwargs) |
862 | 889 |
|
863 | 890 | deploy_kwargs = _add_instance_type_to_kwargs(kwargs=deploy_kwargs)
|
864 | 891 |
|
@@ -1041,11 +1068,14 @@ def get_init_kwargs(
|
1041 | 1068 | )
|
1042 | 1069 |
|
1043 | 1070 | model_init_kwargs = _add_vulnerable_and_deprecated_status_to_kwargs(kwargs=model_init_kwargs)
|
| 1071 | + model_init_kwargs = _add_model_version_to_kwargs(kwargs=model_init_kwargs) |
| 1072 | + model_init_kwargs = _add_config_name_to_init_kwargs(kwargs=model_init_kwargs) |
1044 | 1073 |
|
1045 |
| - model_init_kwargs = _add_sagemaker_session_to_kwargs(kwargs=model_init_kwargs) |
| 1074 | + model_init_kwargs = _add_sagemaker_session_with_custom_user_agent_to_kwargs( |
| 1075 | + kwargs=model_init_kwargs |
| 1076 | + ) |
1046 | 1077 | model_init_kwargs = _add_region_to_kwargs(kwargs=model_init_kwargs)
|
1047 | 1078 |
|
1048 |
| - model_init_kwargs = _add_model_version_to_kwargs(kwargs=model_init_kwargs) |
1049 | 1079 | model_init_kwargs = _add_model_name_to_kwargs(kwargs=model_init_kwargs)
|
1050 | 1080 |
|
1051 | 1081 | model_init_kwargs = _add_instance_type_to_kwargs(
|
@@ -1073,8 +1103,6 @@ def get_init_kwargs(
|
1073 | 1103 |
|
1074 | 1104 | model_init_kwargs = _add_resources_to_kwargs(kwargs=model_init_kwargs)
|
1075 | 1105 |
|
1076 |
| - model_init_kwargs = _add_config_name_to_init_kwargs(kwargs=model_init_kwargs) |
1077 |
| - |
1078 | 1106 | model_init_kwargs = _add_additional_model_data_sources_to_kwargs(kwargs=model_init_kwargs)
|
1079 | 1107 |
|
1080 | 1108 | return model_init_kwargs
|
0 commit comments