Skip to content

Commit bceb17b

Browse files
committed
rename to proprietary model and fix unittests
1 parent 07fa93e commit bceb17b

File tree

11 files changed

+46
-28
lines changed

11 files changed

+46
-28
lines changed

doc/doc_utils/jumpstart_doc_utils.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -206,10 +206,10 @@ def get_model_source(url):
206206
return "Source"
207207

208208

209-
def create_marketplace_model_table():
209+
def create_proprietary_model_table():
210210
marketpkace_content_intro = []
211211
marketpkace_content_intro.append("\n")
212-
marketpkace_content_intro.append(".. list-table:: Available Models\n")
212+
marketpkace_content_intro.append(".. list-table:: Available Proprietary Models\n")
213213
marketpkace_content_intro.append(" :widths: 50 20 20 20 20\n")
214214
marketpkace_content_intro.append(" :header-rows: 1\n")
215215
marketpkace_content_intro.append(" :class: datatable\n")
@@ -232,17 +232,17 @@ def create_marketplace_model_table():
232232
):
233233
sdk_manifest_top_versions_for_models[model["model_id"]] = model
234234

235-
marketplace_content_entries = []
235+
proprietary_content_entries = []
236236
for model in sdk_manifest_top_versions_for_models.values():
237237
model_spec = get_jumpstart_sdk_spec(model["spec_key"])
238-
marketplace_content_entries.append(" * - {}\n".format(model_spec["model_id"]))
239-
marketplace_content_entries.append(" - {}\n".format(False)) # TODO: support training
240-
marketplace_content_entries.append(" - {}\n".format(model["version"]))
241-
marketplace_content_entries.append(" - {}\n".format(model["min_version"]))
242-
marketplace_content_entries.append(
238+
proprietary_content_entries.append(" * - {}\n".format(model_spec["model_id"]))
239+
proprietary_content_entries.append(" - {}\n".format(False)) # TODO: support training
240+
proprietary_content_entries.append(" - {}\n".format(model["version"]))
241+
proprietary_content_entries.append(" - {}\n".format(model["min_version"]))
242+
proprietary_content_entries.append(
243243
" - `{} <{}>`__ |external-link|\n".format("Source", model_spec.get("url"))
244244
)
245-
return marketpkace_content_intro + marketplace_content_entries + ["\n"]
245+
return marketpkace_content_intro + proprietary_content_entries + ["\n"]
246246

247247

248248
def create_jumpstart_model_table():
@@ -348,10 +348,10 @@ def create_jumpstart_model_table():
348348
f.writelines(file_content_single_entry)
349349
f.close()
350350

351-
marketplace_content_entries = create_marketplace_model_table()
351+
proprietary_content_entries = create_proprietary_model_table()
352352

353353
f = open("doc_utils/pretrainedmodels.rst", "a")
354354
f.writelines(file_content_intro)
355355
f.writelines(open_weight_content_entries)
356-
f.writelines(marketplace_content_entries)
356+
f.writelines(proprietary_content_entries)
357357
f.close()

src/sagemaker/jumpstart/enums.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,6 @@ class JumpStartTag(str, Enum):
9191
MODEL_VERSION = "sagemaker-sdk:jumpstart-model-version"
9292
MODEL_TYPE = "sagemaker-sdk:jumpstart-model-type"
9393

94-
MARKETPLACE_MODEL_TYPE_VALUE = "SageMakerJumpStartMarketplace"
95-
9694

9795
class SerializerType(str, Enum):
9896
"""Enum class for serializers associated with JumpStart models."""

src/sagemaker/jumpstart/exceptions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,10 @@ def get_proprietary_model_subscription_msg(
6262
model_id: str,
6363
subscription_link: str,
6464
) -> str:
65-
"""Returns customer-facing message for using a Marketplace model."""
65+
"""Returns customer-facing message for using a proprietary model."""
6666

6767
return (
68-
f"INFO: Using Marketplace model '{model_id}'. "
68+
f"INFO: Using proprietary model '{model_id}'. "
6969
f"Please make sure to subscribe to the model from {subscription_link}"
7070
)
7171

@@ -75,7 +75,7 @@ def get_wildcard_proprietary_model_version_msg(
7575
) -> str:
7676
"""Returns customer-facing message for passing wildcard version to proprietary models."""
7777
msg = (
78-
f"Marketplace model '{model_id}' does not support "
78+
f"Proprietary model '{model_id}' does not support "
7979
f"wildcard version identifier '{wildcard_model_version}'. "
8080
)
8181
if len(available_versions) > 0:

src/sagemaker/jumpstart/factory/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ def _log_model_type(kwargs: JumpStartModelInitKwargs) -> None:
169169
"""Log the model type being used"""
170170
if kwargs.model_type == JumpStartModelType.PROPRIETARY:
171171
JUMPSTART_LOGGER.info(
172-
"Marketplace model %s of version %s is being used.",
172+
"Proprietary model %s of version %s is being used.",
173173
kwargs.model_id,
174174
kwargs.model_version,
175175
)

src/sagemaker/jumpstart/filters.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -430,6 +430,22 @@ def __init__(self, key: str, value: str, operator: str):
430430
self.value = value
431431
self.operator = operator
432432

433+
def set_key(self, key: str) -> None:
434+
"""Sets the key for the model filter.
435+
436+
Args:
437+
key (str): The key to be set.
438+
"""
439+
self.key = key
440+
441+
def set_value(self, value: str) -> None:
442+
"""Sets the value for the model filter.
443+
444+
Args:
445+
value (str): The value to be set.
446+
"""
447+
self.value = value
448+
433449

434450
def parse_filter_string(filter_string: str) -> ModelFilter:
435451
"""Parse filter string and return a serialized ``ModelFilter`` object.

src/sagemaker/jumpstart/notebook_utils.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
2626
JUMPSTART_DEFAULT_REGION_NAME,
2727
PROPRIETARY_MODEL_SPEC_PREFIX,
28-
PROPRIETARY_MODEL_FILTER_NAME,
2928
)
3029
from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType
3130
from sagemaker.jumpstart.filters import (
@@ -130,7 +129,7 @@ def extract_framework_task_model(model_id: str) -> Tuple[str, str, str]:
130129
"""
131130
_id_parts = model_id.split("-")
132131

133-
if len(_id_parts) != 3:
132+
if len(_id_parts) < 3:
134133
return "", "", ""
135134

136135
framework = _id_parts[0]
@@ -149,7 +148,7 @@ def extract_model_type(spec_key: str) -> str:
149148
model_spec_prefix = spec_key.split("/")[0]
150149

151150
if model_spec_prefix == PROPRIETARY_MODEL_SPEC_PREFIX:
152-
return PROPRIETARY_MODEL_FILTER_NAME
151+
return JumpStartModelType.PROPRIETARY.value
153152

154153
return JumpStartModelType.OPEN_WEIGHT.value
155154

@@ -228,6 +227,7 @@ def list_jumpstart_scripts( # pylint: disable=redefined-builtin
228227
sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to
229228
use to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
230229
"""
230+
231231
if (isinstance(filter, Constant) and filter.resolved_value == BooleanValues.TRUE) or (
232232
isinstance(filter, str) and filter.lower() == BooleanValues.TRUE.lower()
233233
):
@@ -279,8 +279,6 @@ def list_jumpstart_models( # pylint: disable=redefined-builtin
279279
versions should be included in the returned result. (Default: False).
280280
list_versions (bool): Optional. True if versions for models should be returned in addition
281281
to the id of the model. (Default: False).
282-
marketplace_models (bool): Optional. True if only listing JumpStart Marketplace models.
283-
(Default: False).
284282
sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to use
285283
to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
286284
"""
@@ -361,6 +359,11 @@ def _generate_jumpstart_model_versions( # pylint: disable=redefined-builtin
361359
model_filter = operator.unresolved_value
362360
key = model_filter.key
363361
all_keys.add(key)
362+
if model_filter.key == SpecialSupportedFilterKeys.MODEL_TYPE and model_filter.value in [
363+
"marketplace",
364+
"proprietary",
365+
]:
366+
model_filter.set_value(JumpStartModelType.PROPRIETARY.value)
364367
model_filters.add(model_filter)
365368

366369
for key in all_keys:

src/sagemaker/jumpstart/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,7 @@ def add_jumpstart_model_id_version_tags(
369369
)
370370
if model_type == enums.JumpStartModelType.PROPRIETARY:
371371
tags = add_single_jumpstart_tag(
372-
enums.JumpStartTag.MARKETPLACE_MODEL_TYPE_VALUE,
372+
enums.JumpStartModelType.PROPRIETARY.value,
373373
enums.JumpStartTag.MODEL_TYPE,
374374
tags,
375375
is_uri=False,

src/sagemaker/payloads.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,8 +137,8 @@ def retrieve_example(
137137
the model payload.
138138
model_version (str): The version of the JumpStart model for which to retrieve
139139
the model payload.
140-
model_type (str): The model type of the JumpStart model, either is open source
141-
or marketplace (proprietary).
140+
model_type (str): The model type of the JumpStart model, either is open weight
141+
or proprietary.
142142
serialize (bool): Whether to serialize byte-stream valued payloads by downloading
143143
binary files from s3 and applying encoding, or to keep payload in pre-serialized
144144
state. Set this option to False if you want to avoid s3 downloads or if you

tests/unit/sagemaker/jumpstart/model/test_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -492,7 +492,7 @@ def test_proprietary_model_endpoint(
492492
tags=[
493493
{"Key": JumpStartTag.MODEL_ID, "Value": "ai21-summarization"},
494494
{"Key": JumpStartTag.MODEL_VERSION, "Value": "2.0.004"},
495-
{"Key": JumpStartTag.MODEL_TYPE, "Value": "SageMakerJumpStartMarketplace"},
495+
{"Key": JumpStartTag.MODEL_TYPE, "Value": "proprietary"},
496496
],
497497
endpoint_logging=False,
498498
model_data_download_timeout=3600,

tests/unit/sagemaker/jumpstart/test_cache.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def test_jumpstart_cache_get_header():
181181
model_type=JumpStartModelType.PROPRIETARY,
182182
)
183183
assert (
184-
"Marketplace model 'ai21-summarization' does not support wildcard version identifier '3.*'. "
184+
"Proprietary model 'ai21-summarization' does not support wildcard version identifier '3.*'. "
185185
"You can pin to version '1.1.003'. "
186186
"https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html "
187187
"for list of supported model IDs. " in str(e.value)
@@ -941,7 +941,7 @@ def test_jumpstart_cache_get_specs():
941941
model_type=JumpStartModelType.PROPRIETARY,
942942
)
943943
assert (
944-
"Marketplace model 'ai21-summarization' does not support wildcard version identifier '3.*'. "
944+
"Proprietary model 'ai21-summarization' does not support wildcard version identifier '3.*'. "
945945
"You can pin to version '1.1.003'. "
946946
"https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html "
947947
"for list of supported model IDs. " in str(e.value)

tests/unit/sagemaker/jumpstart/test_notebook_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -640,6 +640,7 @@ def test_list_jumpstart_proprietary_models(
640640
]
641641

642642
assert list_jumpstart_models("model_type == proprietary") == all_prop_model_ids
643+
assert list_jumpstart_models("model_type == marketplace") == all_prop_model_ids
643644
assert list_jumpstart_models("model_type == open_weight") == all_open_weight_model_ids
644645

645646
assert list_jumpstart_models(list_versions=False) == sorted(

0 commit comments

Comments
 (0)