Skip to content

Commit c2cea54

Browse files
committed
feature: infer framework and version for huggingface class
1 parent d246c52 commit c2cea54

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

src/sagemaker/huggingface/model.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,13 @@ def _validate_pt_tf_versions(pytorch_version, tensorflow_version, image_uri):
8585
)
8686

8787

88+
def fetch_framework_and_framework_version(tensorflow_version, pytorch_version):
89+
if tensorflow_version is not None: # pylint: disable=no-member
90+
return ("tensorflow", tensorflow_version) # pylint: disable=no-member
91+
else:
92+
return ("pytorch", pytorch_version) # pylint: disable=no-member
93+
94+
8895
class HuggingFaceModel(FrameworkModel):
8996
"""A Hugging Face SageMaker ``Model`` that can be deployed to a SageMaker ``Endpoint``."""
9097

@@ -289,6 +296,7 @@ def deploy(
289296
serverless_inference_config,
290297
)
291298

299+
292300
def register(
293301
self,
294302
content_types,
@@ -387,8 +395,8 @@ def register(
387395
domain=domain,
388396
sample_payload_url=sample_payload_url,
389397
task=task,
390-
framework=framework,
391-
framework_version=framework_version,
398+
framework=framework or fetch_framework_and_framework_version()[0],
399+
framework_version=framework_version or fetch_framework_and_framework_version()[1],
392400
nearest_model_name=nearest_model_name,
393401
data_input_configuration=data_input_configuration,
394402
)

0 commit comments

Comments
 (0)