Skip to content

Commit 9b44c54

Browse files
committed
add hub_arn support for accept_types, content_types, serializers, deserializers, and predictor (aws#4463)
1 parent 65ef0d3 commit 9b44c54

File tree

2 files changed

+6
-0
lines changed

2 files changed

+6
-0
lines changed

src/sagemaker/jumpstart/factory/model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -725,6 +725,7 @@ def get_init_kwargs(
725725
model_version: Optional[str] = None,
726726
hub_arn: Optional[str] = None,
727727
model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS,
728+
hub_arn: Optional[str] = None,
728729
tolerate_vulnerable_model: Optional[bool] = None,
729730
tolerate_deprecated_model: Optional[bool] = None,
730731
instance_type: Optional[str] = None,
@@ -758,6 +759,7 @@ def get_init_kwargs(
758759
model_version=model_version,
759760
hub_arn=hub_arn,
760761
model_type=model_type,
762+
hub_arn=hub_arn,
761763
instance_type=instance_type,
762764
region=region,
763765
image_uri=image_uri,

src/sagemaker/jumpstart/types.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1310,6 +1310,7 @@ class JumpStartModelInitKwargs(JumpStartKwargs):
13101310
"model_version",
13111311
"hub_arn",
13121312
"model_type",
1313+
"hub_arn",
13131314
"instance_type",
13141315
"tolerate_vulnerable_model",
13151316
"tolerate_deprecated_model",
@@ -1342,6 +1343,7 @@ class JumpStartModelInitKwargs(JumpStartKwargs):
13421343
"model_version",
13431344
"hub_arn",
13441345
"model_type",
1346+
"hub_arn",
13451347
"tolerate_vulnerable_model",
13461348
"tolerate_deprecated_model",
13471349
"region",
@@ -1355,6 +1357,7 @@ def __init__(
13551357
model_version: Optional[str] = None,
13561358
hub_arn: Optional[str] = None,
13571359
model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS,
1360+
hub_arn: Optional[str] = None,
13581361
region: Optional[str] = None,
13591362
instance_type: Optional[str] = None,
13601363
image_uri: Optional[Union[str, Any]] = None,
@@ -1386,6 +1389,7 @@ def __init__(
13861389
self.model_version = model_version
13871390
self.hub_arn = hub_arn
13881391
self.model_type = model_type
1392+
self.hub_arn = hub_arn
13891393
self.instance_type = instance_type
13901394
self.region = region
13911395
self.image_uri = image_uri

0 commit comments

Comments
 (0)