Skip to content

Commit 5af52b8

Browse files
milahajaykarpurahsan-z-khan
authored
feature: Add support for accelerator in Clarify (#2249)
Co-authored-by: Ajay Karpur <[email protected]> Co-authored-by: Ahsan Khan <[email protected]>
1 parent f27682f commit 5af52b8

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

src/sagemaker/clarify.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ def __init__(
123123
content_type=None,
124124
content_template=None,
125125
custom_attributes=None,
126+
accelerator_type=None,
126127
):
127128
"""Initializes a configuration of a model and the endpoint to be created for it.
128129
@@ -151,6 +152,9 @@ def __init__(
151152
Section 3.3.6. Field Value Components (
152153
https://tools.ietf.org/html/rfc7230#section-3.2.6) of the Hypertext Transfer
153154
Protocol (HTTP/1.1).
155+
accelerator_type (str): The Elastic Inference accelerator type to deploy to the model
156+
endpoint instance for making inferences to the model, see
157+
https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html.
154158
"""
155159
self.predictor_config = {
156160
"model_name": model_name,
@@ -178,9 +182,8 @@ def __init__(
178182
f" Please include a placeholder $features."
179183
)
180184
self.predictor_config["content_template"] = content_template
181-
182-
if custom_attributes is not None:
183-
self.predictor_config["custom_attributes"] = custom_attributes
185+
_set(custom_attributes, "custom_attributes", self.predictor_config)
186+
_set(accelerator_type, "accelerator_type", self.predictor_config)
184187

185188
def get_predictor_config(self):
186189
"""Returns part of the predictor dictionary of the analysis config."""

tests/unit/test_clarify.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,13 +92,15 @@ def test_model_config():
9292
accept_type = "text/csv"
9393
content_type = "application/jsonlines"
9494
custom_attributes = "c000b4f9-df62-4c85-a0bf-7c525f9104a4"
95+
accelerator_type = "ml.eia1.medium"
9596
model_config = ModelConfig(
9697
model_name=model_name,
9798
instance_type=instance_type,
9899
instance_count=instance_count,
99100
accept_type=accept_type,
100101
content_type=content_type,
101102
custom_attributes=custom_attributes,
103+
accelerator_type=accelerator_type,
102104
)
103105
expected_config = {
104106
"model_name": model_name,
@@ -107,6 +109,7 @@ def test_model_config():
107109
"accept_type": accept_type,
108110
"content_type": content_type,
109111
"custom_attributes": custom_attributes,
112+
"accelerator_type": accelerator_type,
110113
}
111114
assert expected_config == model_config.get_predictor_config()
112115

0 commit comments

Comments
 (0)