Skip to content

Commit 378c868

Browse files
committed
Merge branch 'master' of github.com:jerrypeng7773/sagemaker-python-sdk
2 parents 2918765 + fe9bd70 commit 378c868

File tree

14 files changed

+577
-64
lines changed

14 files changed

+577
-64
lines changed

CHANGELOG.md

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,27 @@
11
# Changelog
22

3+
## v2.48.1 (2021-07-08)
4+
5+
### Bug Fixes and Other Changes
6+
7+
* skip HF inference test
8+
* remove upsert from test_workflow
9+
10+
### Documentation Changes
11+
12+
* Add Hugging Face docs
13+
* add tuning step to doc
14+
15+
## v2.48.0 (2021-07-07)
16+
17+
### Features
18+
19+
* HuggingFace Inference
20+
21+
### Bug Fixes and Other Changes
22+
23+
* add support for SageMaker workflow tuning step
24+
325
## v2.47.2.post0 (2021-07-01)
426

527
### Documentation Changes

VERSION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
2.47.3.dev0
1+
2.48.2.dev0
Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,26 @@
1-
HuggingFace
2-
===========
1+
Hugging Face
2+
============
33

4-
HuggingFace Estimator
5-
---------------------
4+
Hugging Face Estimator
5+
----------------------
66

77
.. autoclass:: sagemaker.huggingface.estimator.HuggingFace
88
:members:
99
:undoc-members:
1010
:show-inheritance:
11+
12+
Hugging Face Model
13+
------------------
14+
15+
.. autoclass:: sagemaker.huggingface.model.HuggingFaceModel
16+
:members:
17+
:undoc-members:
18+
:show-inheritance:
19+
20+
HuggingFace Predictor
21+
---------------------
22+
23+
.. autoclass:: sagemaker.huggingface.model.HuggingFacePredictor
24+
:members:
25+
:undoc-members:
26+
:show-inheritance:

src/sagemaker/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -718,7 +718,7 @@ def compile_model(
718718
'onnx', 'xgboost'
719719
framework_version (str): The version of the framework
720720
compile_max_run (int): Timeout in seconds for compilation (default:
721-
3 * 60). After this amount of time Amazon SageMaker Neo
721+
15 * 60). After this amount of time Amazon SageMaker Neo
722722
terminates the compilation job regardless of its current status.
723723
tags (list[dict]): List of tags for labeling a compilation job. For
724724
more, see

src/sagemaker/huggingface/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@
1414
from __future__ import absolute_import
1515

1616
from sagemaker.huggingface.estimator import HuggingFace # noqa: F401
17+
from sagemaker.huggingface.model import HuggingFaceModel, HuggingFacePredictor # noqa: F401

src/sagemaker/huggingface/estimator.py

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
warn_if_parameter_server_with_multi_gpu,
2424
validate_smdistributed,
2525
)
26+
from sagemaker.huggingface.model import HuggingFaceModel
2627
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
2728

2829
logger = logging.getLogger("sagemaker")
@@ -233,8 +234,58 @@ def create_model(
233234
dependencies=None,
234235
**kwargs
235236
):
236-
"""Placeholder docstring"""
237-
raise NotImplementedError("Creating model with HuggingFace training job is not supported.")
237+
"""Create a SageMaker ``HuggingFaceModel`` object that can be deployed to an ``Endpoint``.
238+
239+
Args:
240+
model_server_workers (int): Optional. The number of worker processes
241+
used by the inference server. If None, server will use one
242+
worker per vCPU.
243+
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``,
244+
which is also used during transform jobs. If not specified, the
245+
role from the Estimator will be used.
246+
vpc_config_override (dict[str, list[str]]): Optional override for VpcConfig set on
247+
the model. Default: use subnets and security groups from this Estimator.
248+
* 'Subnets' (list[str]): List of subnet ids.
249+
* 'SecurityGroupIds' (list[str]): List of security group ids.
250+
entry_point (str): Path (absolute or relative) to the local Python source file which
251+
should be executed as the entry point to training. If ``source_dir`` is specified,
252+
then ``entry_point`` must point to a file located at the root of ``source_dir``.
253+
Defaults to `None`.
254+
source_dir (str): Path (absolute or relative) to a directory with any other serving
255+
source code dependencies aside from the entry point file.
256+
If not specified, the model source directory from training is used.
257+
dependencies (list[str]): A list of paths to directories (absolute or relative) with
258+
any additional libraries that will be exported to the container.
259+
If not specified, the dependencies from training are used.
260+
This is not supported with "local code" in Local Mode.
261+
**kwargs: Additional kwargs passed to the :class:`~sagemaker.huggingface.model.HuggingFaceModel`
262+
constructor.
263+
Returns:
264+
sagemaker.huggingface.model.HuggingFaceModel: A SageMaker ``HuggingFaceModel``
265+
object. See :func:`~sagemaker.huggingface.model.HuggingFaceModel` for full details.
266+
"""
267+
if "image_uri" not in kwargs:
268+
kwargs["image_uri"] = self.image_uri
269+
270+
kwargs["name"] = self._get_or_create_name(kwargs.get("name"))
271+
272+
return HuggingFaceModel(
273+
role or self.role,
274+
model_data=self.model_data,
275+
entry_point=entry_point,
276+
transformers_version=self.framework_version,
277+
tensorflow_version=self.tensorflow_version,
278+
pytorch_version=self.pytorch_version,
279+
py_version=self.py_version,
280+
source_dir=(source_dir or self._model_source_dir()),
281+
container_log_level=self.container_log_level,
282+
code_location=self.code_location,
283+
model_server_workers=model_server_workers,
284+
sagemaker_session=self.sagemaker_session,
285+
vpc_config=self.get_vpc_config(vpc_config_override),
286+
dependencies=(dependencies or self.dependencies),
287+
**kwargs
288+
)
238289

239290
@classmethod
240291
def _prepare_init_params_from_job_description(cls, job_details, model_channel_name=None):

0 commit comments

Comments
 (0)