Skip to content

Commit 6fb885e

Browse files
committed
address comments
1 parent bceb17b commit 6fb885e

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

51 files changed

+323
-274
lines changed

doc/doc_utils/jumpstart_doc_utils.py

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -155,25 +155,26 @@ class Frameworks(str, Enum):
155155
}
156156

157157

158-
def get_jumpstart_sdk_manifest():
159-
url = "{}/{}".format(JUMPSTART_BUCKET_BASE_URL, SDK_MANIFEST_FILE)
158+
def get_public_s3_json_object(url):
160159
with request.urlopen(url) as f:
161160
models_manifest = f.read().decode("utf-8")
162161
return json.loads(models_manifest)
163162

164163

164+
def get_jumpstart_sdk_manifest():
165+
return get_public_s3_json_object(f"{JUMPSTART_BUCKET_BASE_URL}/{SDK_MANIFEST_FILE}")
166+
167+
165168
def get_proprietary_sdk_manifest():
166-
url = "{}/{}".format(PROPRIETARY_DOC_BUCKET, PROPRIETARY_SDK_MANIFEST_FILE)
167-
with request.urlopen(url) as f:
168-
models_manifest = f.read().decode("utf-8")
169-
return json.loads(models_manifest)
169+
return get_public_s3_json_object(f"{PROPRIETARY_DOC_BUCKET}/{PROPRIETARY_SDK_MANIFEST_FILE}")
170170

171171

172-
def get_jumpstart_sdk_spec(key):
173-
url = "{}/{}".format(PROPRIETARY_DOC_BUCKET, key)
174-
with request.urlopen(url) as f:
175-
model_spec = f.read().decode("utf-8")
176-
return json.loads(model_spec)
172+
def get_jumpstart_sdk_spec(s3_key: str):
173+
return get_public_s3_json_object(f"{JUMPSTART_BUCKET_BASE_URL}/{s3_key}")
174+
175+
176+
def get_proprietary_sdk_spec(s3_key: str):
177+
return get_public_s3_json_object(f"{PROPRIETARY_DOC_BUCKET}/{s3_key}")
177178

178179

179180
def get_model_task(id):
@@ -207,18 +208,19 @@ def get_model_source(url):
207208

208209

209210
def create_proprietary_model_table():
210-
marketpkace_content_intro = []
211-
marketpkace_content_intro.append("\n")
212-
marketpkace_content_intro.append(".. list-table:: Available Proprietary Models\n")
213-
marketpkace_content_intro.append(" :widths: 50 20 20 20 20\n")
214-
marketpkace_content_intro.append(" :header-rows: 1\n")
215-
marketpkace_content_intro.append(" :class: datatable\n")
216-
marketpkace_content_intro.append("\n")
217-
marketpkace_content_intro.append(" * - Model ID\n")
218-
marketpkace_content_intro.append(" - Fine Tunable?\n")
219-
marketpkace_content_intro.append(" - Supported Version\n")
220-
marketpkace_content_intro.append(" - Min SDK Version\n")
221-
marketpkace_content_intro.append(" - Source\n")
211+
marketpkace_content_intro = f"""
212+
.. list-table:: Available Proprietary Models
213+
:widths: 50 20 20 20 20
214+
:header-rows: 1
215+
:class: datatable
216+
217+
* - Model ID
218+
- Fine Tunable?
219+
- Supported Version
220+
- Min SDK Version
221+
- Source
222+
223+
"""
222224

223225
sdk_manifest = get_proprietary_sdk_manifest()
224226
sdk_manifest_top_versions_for_models = {}
@@ -234,7 +236,7 @@ def create_proprietary_model_table():
234236

235237
proprietary_content_entries = []
236238
for model in sdk_manifest_top_versions_for_models.values():
237-
model_spec = get_jumpstart_sdk_spec(model["spec_key"])
239+
model_spec = get_proprietary_sdk_spec(model["spec_key"])
238240
proprietary_content_entries.append(" * - {}\n".format(model_spec["model_id"]))
239241
proprietary_content_entries.append(" - {}\n".format(False)) # TODO: support training
240242
proprietary_content_entries.append(" - {}\n".format(model["version"]))

src/sagemaker/accept_types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def retrieve_default(
7676
tolerate_vulnerable_model: bool = False,
7777
tolerate_deprecated_model: bool = False,
7878
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
79-
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT,
79+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
8080
) -> str:
8181
"""Retrieves the default accept type for the model matching the given arguments.
8282

src/sagemaker/content_types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def retrieve_default(
7676
tolerate_vulnerable_model: bool = False,
7777
tolerate_deprecated_model: bool = False,
7878
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
79-
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT,
79+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
8080
) -> str:
8181
"""Retrieves the default content type for the model matching the given arguments.
8282

src/sagemaker/deserializers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def retrieve_default(
9696
tolerate_vulnerable_model: bool = False,
9797
tolerate_deprecated_model: bool = False,
9898
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
99-
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT,
99+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
100100
) -> BaseDeserializer:
101101
"""Retrieves the default deserializer for the model matching the given arguments.
102102

src/sagemaker/instance_types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def retrieve_default(
3535
tolerate_deprecated_model: bool = False,
3636
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
3737
training_instance_type: Optional[str] = None,
38-
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT,
38+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
3939
) -> str:
4040
"""Retrieves the default instance type for the model matching the given arguments.
4141

src/sagemaker/jumpstart/accessors.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ def _set_cache_and_region(region: str, cache_kwargs: dict) -> None:
200200
def _get_manifest(
201201
region: str = JUMPSTART_DEFAULT_REGION_NAME,
202202
s3_client: Optional[boto3.client] = None,
203-
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT,
203+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
204204
) -> List[JumpStartModelHeader]:
205205
"""Return entire JumpStart models manifest.
206206
@@ -229,7 +229,7 @@ def get_model_header(
229229
region: str,
230230
model_id: str,
231231
version: str,
232-
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT,
232+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
233233
) -> JumpStartModelHeader:
234234
"""Returns model header from JumpStart models cache.
235235
@@ -254,7 +254,7 @@ def get_model_specs(
254254
model_id: str,
255255
version: str,
256256
s3_client: Optional[boto3.client] = None,
257-
model_type=JumpStartModelType.OPEN_WEIGHT,
257+
model_type=JumpStartModelType.OPEN_WEIGHTS,
258258
) -> JumpStartModelSpecs:
259259
"""Returns model specs from JumpStart models cache.
260260

src/sagemaker/jumpstart/artifacts/instance_types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def _retrieve_default_instance_type(
3939
tolerate_deprecated_model: bool = False,
4040
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
4141
training_instance_type: Optional[str] = None,
42-
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT,
42+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
4343
) -> str:
4444
"""Retrieves the default instance type for the model.
4545

src/sagemaker/jumpstart/artifacts/kwargs.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def _retrieve_model_init_kwargs(
3636
tolerate_vulnerable_model: bool = False,
3737
tolerate_deprecated_model: bool = False,
3838
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
39-
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT,
39+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
4040
) -> dict:
4141
"""Retrieves kwargs for `Model`.
4242
@@ -81,9 +81,6 @@ def _retrieve_model_init_kwargs(
8181
if model_specs.inference_enable_network_isolation is not None:
8282
kwargs.update({"enable_network_isolation": model_specs.inference_enable_network_isolation})
8383

84-
if model_type == JumpStartModelType.PROPRIETARY:
85-
kwargs.update({"enable_network_isolation": True})
86-
8784
return kwargs
8885

8986

@@ -95,7 +92,7 @@ def _retrieve_model_deploy_kwargs(
9592
tolerate_vulnerable_model: bool = False,
9693
tolerate_deprecated_model: bool = False,
9794
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
98-
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT,
95+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
9996
) -> dict:
10097
"""Retrieves kwargs for `Model.deploy`.
10198

src/sagemaker/jumpstart/artifacts/model_packages.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def _retrieve_model_package_arn(
3636
tolerate_vulnerable_model: bool = False,
3737
tolerate_deprecated_model: bool = False,
3838
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
39-
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT,
39+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
4040
) -> Optional[str]:
4141
"""Retrieves associated model pacakge arn for the model.
4242

src/sagemaker/jumpstart/artifacts/payloads.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def _retrieve_example_payloads(
3636
tolerate_vulnerable_model: bool = False,
3737
tolerate_deprecated_model: bool = False,
3838
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
39-
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT,
39+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
4040
) -> Optional[Dict[str, JumpStartSerializablePayload]]:
4141
"""Returns example payloads.
4242

src/sagemaker/jumpstart/artifacts/predictors.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def _retrieve_default_deserializer(
7777
tolerate_vulnerable_model: bool = False,
7878
tolerate_deprecated_model: bool = False,
7979
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
80-
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT,
80+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
8181
) -> BaseDeserializer:
8282
"""Retrieves the default deserializer for the model.
8383
@@ -123,7 +123,7 @@ def _retrieve_default_serializer(
123123
tolerate_vulnerable_model: bool = False,
124124
tolerate_deprecated_model: bool = False,
125125
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
126-
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT,
126+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
127127
) -> BaseSerializer:
128128
"""Retrieves the default serializer for the model.
129129
@@ -168,7 +168,7 @@ def _retrieve_deserializer_options(
168168
tolerate_vulnerable_model: bool = False,
169169
tolerate_deprecated_model: bool = False,
170170
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
171-
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT,
171+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
172172
) -> List[BaseDeserializer]:
173173
"""Retrieves the supported deserializers for the model.
174174
@@ -283,7 +283,7 @@ def _retrieve_default_content_type(
283283
region: Optional[str],
284284
tolerate_vulnerable_model: bool = False,
285285
tolerate_deprecated_model: bool = False,
286-
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT,
286+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
287287
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
288288
) -> str:
289289
"""Retrieves the default content type for the model.
@@ -334,7 +334,7 @@ def _retrieve_default_accept_type(
334334
tolerate_vulnerable_model: bool = False,
335335
tolerate_deprecated_model: bool = False,
336336
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
337-
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT,
337+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
338338
) -> str:
339339
"""Retrieves the default accept type for the model.
340340
@@ -385,7 +385,7 @@ def _retrieve_supported_accept_types(
385385
tolerate_vulnerable_model: bool = False,
386386
tolerate_deprecated_model: bool = False,
387387
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
388-
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT,
388+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
389389
) -> List[str]:
390390
"""Retrieves the supported accept types for the model.
391391
@@ -436,7 +436,7 @@ def _retrieve_supported_content_types(
436436
tolerate_vulnerable_model: bool = False,
437437
tolerate_deprecated_model: bool = False,
438438
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
439-
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT,
439+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
440440
) -> List[str]:
441441
"""Retrieves the supported content types for the model.
442442

src/sagemaker/jumpstart/artifacts/resource_names.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def _retrieve_resource_name_base(
3333
region: Optional[str],
3434
tolerate_vulnerable_model: bool = False,
3535
tolerate_deprecated_model: bool = False,
36-
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT,
36+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
3737
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
3838
) -> bool:
3939
"""Returns default resource name.

src/sagemaker/jumpstart/artifacts/resource_requirements.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def _retrieve_default_resources(
5151
region: Optional[str] = None,
5252
tolerate_vulnerable_model: bool = False,
5353
tolerate_deprecated_model: bool = False,
54-
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHT,
54+
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
5555
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
5656
instance_type: Optional[str] = None,
5757
) -> ResourceRequirements:

0 commit comments

Comments
 (0)