Skip to content

Commit 5403218

Browse files
authored
Merge branch 'master' into master
2 parents 35eda59 + 5af52b8 commit 5403218

File tree

8 files changed

+170
-540
lines changed

8 files changed

+170
-540
lines changed

doc/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
sphinx==3.1.1
22
sphinx-rtd-theme==0.5.0
3+
docutils==0.15.2

src/sagemaker/clarify.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ def __init__(
123123
content_type=None,
124124
content_template=None,
125125
custom_attributes=None,
126+
accelerator_type=None,
126127
):
127128
"""Initializes a configuration of a model and the endpoint to be created for it.
128129
@@ -151,6 +152,9 @@ def __init__(
151152
Section 3.3.6. Field Value Components (
152153
https://tools.ietf.org/html/rfc7230#section-3.2.6) of the Hypertext Transfer
153154
Protocol (HTTP/1.1).
155+
accelerator_type (str): The Elastic Inference accelerator type to deploy to the model
156+
endpoint instance for making inferences to the model, see
157+
https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html.
154158
"""
155159
self.predictor_config = {
156160
"model_name": model_name,
@@ -178,9 +182,8 @@ def __init__(
178182
f" Please include a placeholder $features."
179183
)
180184
self.predictor_config["content_template"] = content_template
181-
182-
if custom_attributes is not None:
183-
self.predictor_config["custom_attributes"] = custom_attributes
185+
_set(custom_attributes, "custom_attributes", self.predictor_config)
186+
_set(accelerator_type, "accelerator_type", self.predictor_config)
184187

185188
def get_predictor_config(self):
186189
"""Returns part of the predictor dictionary of the analysis config."""

tests/conftest.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ def pytorch_inference_py_version(pytorch_inference_version, request):
190190
return "py3"
191191

192192

193-
def _huggingface_pytorch_version(huggingface_vesion):
193+
def _huggingface_base_fm_version(huggingface_vesion, base_fw):
194194
config = image_uris.config_for_framework("huggingface")
195195
training_config = config.get("training")
196196
original_version = huggingface_vesion
@@ -200,21 +200,26 @@ def _huggingface_pytorch_version(huggingface_vesion):
200200
)
201201
version_config = training_config.get("versions").get(huggingface_vesion)
202202
for key in list(version_config.keys()):
203-
if key.startswith("pytorch"):
204-
pt_version = key[7:]
203+
if key.startswith(base_fw):
204+
base_fw_version = key[len(base_fw) :]
205205
if len(original_version.split(".")) == 2:
206-
pt_version = ".".join(pt_version.split(".")[:-1])
207-
return pt_version
206+
base_fw_version = ".".join(base_fw_version.split(".")[:-1])
207+
return base_fw_version
208208

209209

210210
@pytest.fixture(scope="module")
211211
def huggingface_pytorch_version(huggingface_training_version):
212-
return _huggingface_pytorch_version(huggingface_training_version)
212+
return _huggingface_base_fm_version(huggingface_training_version, "pytorch")
213213

214214

215215
@pytest.fixture(scope="module")
216216
def huggingface_pytorch_latest_version(huggingface_training_latest_version):
217-
return _huggingface_pytorch_version(huggingface_training_latest_version)
217+
return _huggingface_base_fm_version(huggingface_training_latest_version, "pytorch")
218+
219+
220+
@pytest.fixture(scope="module")
221+
def huggingface_tensorflow_latest_version(huggingface_training_latest_version):
222+
return _huggingface_base_fm_version(huggingface_training_latest_version, "tensorflow")
218223

219224

220225
@pytest.fixture(scope="module")

0 commit comments

Comments
 (0)