Skip to content

Commit 27e14b9

Browse files
committed
update filter name
1 parent d5b9b76 commit 27e14b9

File tree

3 files changed

+9
-5
lines changed

3 files changed

+9
-5
lines changed

src/sagemaker/jumpstart/constants.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,9 @@
190190

191191
SAGEMAKER_GATED_MODEL_S3_URI_TRAINING_ENV_VAR_KEY = "SageMakerGatedModelS3Uri"
192192

193+
PROPRIETARY_MODEL_SPEC_PREFIX = "proprietary-models"
194+
PROPRIETARY_MODEL_FILTER_NAME = "marketplace"
195+
193196
CONTENT_TYPE_TO_SERIALIZER_TYPE_MAP: Dict[MIMEType, SerializerType] = {
194197
MIMEType.X_IMAGE: SerializerType.RAW_BYTES,
195198
MIMEType.LIST_TEXT: SerializerType.JSON,

src/sagemaker/jumpstart/notebook_utils.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
from sagemaker.jumpstart.constants import (
2525
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
2626
JUMPSTART_DEFAULT_REGION_NAME,
27+
PROPRIETARY_MODEL_SPEC_PREFIX,
28+
PROPRIETARY_MODEL_FILTER_NAME,
2729
)
2830
from sagemaker.jumpstart.enums import JumpStartScriptScope, JumpStartModelType
2931
from sagemaker.jumpstart.filters import (
@@ -128,7 +130,7 @@ def extract_framework_task_model(model_id: str) -> Tuple[str, str, str]:
128130
"""
129131
_id_parts = model_id.split("-")
130132

131-
if len(_id_parts) < 3:
133+
if len(_id_parts) != 3:
132134
return "", "", ""
133135

134136
framework = _id_parts[0]
@@ -144,10 +146,10 @@ def extract_model_type(spec_key: str) -> str:
144146
Args:
145147
spek_key (str): The model spec key for which to extract the model type.
146148
"""
147-
model_type = spec_key.split("/")[0]
149+
model_spec_prefix = spec_key.split("/")[0]
148150

149-
if model_type == "proprietary-models":
150-
return JumpStartModelType.PROPRIETARY.value
151+
if model_spec_prefix == PROPRIETARY_MODEL_SPEC_PREFIX:
152+
return PROPRIETARY_MODEL_FILTER_NAME
151153

152154
return JumpStartModelType.OPEN_WEIGHT.value
153155

tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
from sagemaker.predictor import retrieve_default
2222

2323
import tests.integ
24-
from sagemaker.jumpstart import notebook_utils
2524

2625
from sagemaker.jumpstart.model import JumpStartModel
2726
from tests.integ.sagemaker.jumpstart.constants import (

0 commit comments

Comments
 (0)