Skip to content

Commit bcc71db

Browse files
committed
Merge from master
2 parents 7a2374f + 02f6f44 commit bcc71db

32 files changed

+1680
-27
lines changed

CHANGELOG.md

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,36 @@
11
# Changelog
22

3+
## v2.49.2 (2021-07-21)
4+
5+
### Bug Fixes and Other Changes
6+
7+
* order of populating container list
8+
* upgrade Adobe Analytics cookie to 3.0
9+
10+
## v2.49.1 (2021-07-19)
11+
12+
### Bug Fixes and Other Changes
13+
14+
* Set flag when debugger is disabled
15+
* KMS Key fix for kwargs
16+
* Update BiasConfig to accept multiple facet params
17+
18+
### Documentation Changes
19+
20+
* Update huggingface estimator documentation
21+
22+
## v2.49.0 (2021-07-15)
23+
24+
### Features
25+
26+
* Adding serial inference pipeline support to RegisterModel Step
27+
28+
### Documentation Changes
29+
30+
* add tuning step get_top_model_s3_uri and callback step to doc
31+
* links for HF in sdk
32+
* Add Clarify module to Model Monitoring API docs
33+
334
## v2.48.2 (2021-07-12)
435

536
### Bug Fixes and Other Changes

VERSION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
2.48.3.dev0
1+
2.49.3.dev0

doc/_static/aws-ux-shortbread/index.js

Lines changed: 3 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

doc/_static/aws-ux-shortbread/init.js

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
(function (w) {
2+
w.URLSearchParams = w.URLSearchParams || function (searchString) {
3+
var self = this;
4+
self.searchString = searchString;
5+
self.get = function (name) {
6+
var results = new RegExp('[\?&]' + name + '=([^&#]*)').exec(self.searchString);
7+
if (results === null) {
8+
return null;
9+
}
10+
else {
11+
return decodeURI(results[1]) || 0;
12+
}
13+
};
14+
}
15+
})(window);
16+
17+
const queryString = window.location.search;
18+
const urlParams = new URLSearchParams(queryString);
19+
const lang = urlParams.get('lang')
20+
window.onload = function () {
21+
var domainName = window.location.hostname;
22+
23+
// remove an instance of shortbread if already exists
24+
var existingShortbreadEl = document.getElementById("awsccc-sb-ux-c");
25+
existingShortbreadEl && existingShortbreadEl.remove();
26+
27+
var shortbread = AWSCShortbread({
28+
domain: domainName,
29+
language: lang,
30+
//queryGeolocation: function (geolocatedIn) { geolocatedIn("EU") },
31+
});
32+
33+
shortbread.checkForCookieConsent();
34+
}

doc/conf.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,12 @@
6868

6969
htmlhelp_basename = "%sdoc" % project
7070

71-
html_js_files = ["https://a0.awsstatic.com/s_code/js/1.0/awshome_s_code.js", "js/analytics.js"]
71+
# For Adobe Analytics
72+
html_js_files = [
73+
"https://a0.awsstatic.com/s_code/js/3.0/awshome_s_code.js",
74+
"aws-ux-shortbread/index.js",
75+
"aws-ux-shortbread/init.js",
76+
]
7277

7378
html_context = {"css_files": ["_static/theme_overrides.css"]}
7479

src/sagemaker/clarify.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -88,21 +88,34 @@ def __init__(
8888
Args:
8989
label_values_or_threshold (Any): List of label values or threshold to indicate positive
9090
outcome used for bias metrics.
91-
facet_name (str): Sensitive attribute in the input data for which we like to compare
92-
metrics.
91+
facet_name (str or [str]): String or List of strings of sensitive attribute(s) in the
92+
input data for which we like to compare metrics.
9393
facet_values_or_threshold (list): Optional list of values to form a sensitive group or
9494
threshold for a numeric facet column that defines the lower bound of a sensitive
9595
group. Defaults to considering each possible value as sensitive group and
9696
computing metrics vs all the other examples.
97+
If facet_name is a list, this needs to be None or a List consisting of lists or None
98+
with the same length as facet_name list.
9799
group_name (str): Optional column name or index to indicate a group column to be used
98100
for the bias metric 'Conditional Demographic Disparity in Labels - CDDL' or
99101
'Conditional Demographic Disparity in Predicted Labels - CDDPL'.
100102
"""
101-
facet = {"name_or_index": facet_name}
102-
_set(facet_values_or_threshold, "value_or_threshold", facet)
103+
if isinstance(facet_name, str):
104+
facet = {"name_or_index": facet_name}
105+
_set(facet_values_or_threshold, "value_or_threshold", facet)
106+
facet_list = [facet]
107+
elif facet_values_or_threshold is None or len(facet_name) == len(facet_values_or_threshold):
108+
facet_list = []
109+
for i, single_facet_name in enumerate(facet_name):
110+
facet = {"name_or_index": single_facet_name}
111+
if facet_values_or_threshold is not None:
112+
_set(facet_values_or_threshold[i], "value_or_threshold", facet)
113+
facet_list.append(facet)
114+
else:
115+
raise ValueError("Wrong combination of argument values passed")
103116
self.analysis_config = {
104117
"label_values_or_threshold": label_values_or_threshold,
105-
"facet": [facet],
118+
"facet": facet_list,
106119
}
107120
_set(group_name, "group_variable", self.analysis_config)
108121

src/sagemaker/debugger/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from sagemaker.debugger.debugger import ( # noqa: F401
1717
CollectionConfig,
18+
DEBUGGER_FLAG,
1819
DebuggerHookConfig,
1920
framework_name,
2021
get_default_profiler_rule,

src/sagemaker/debugger/debugger.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from sagemaker.utils import build_dict
3333

3434
framework_name = "debugger"
35+
DEBUGGER_FLAG = "USE_SMDEBUG"
3536

3637

3738
def get_rule_container_image_uri(region):

src/sagemaker/estimator.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from sagemaker.analytics import TrainingJobAnalytics
3030
from sagemaker.debugger import TensorBoardOutputConfig # noqa: F401 # pylint: disable=unused-import
3131
from sagemaker.debugger import (
32+
DEBUGGER_FLAG,
3233
DebuggerHookConfig,
3334
FrameworkProfile,
3435
get_default_profiler_rule,
@@ -2269,6 +2270,11 @@ def _validate_and_set_debugger_configs(self):
22692270
)
22702271
self.debugger_hook_config = False
22712272

2273+
if self.debugger_hook_config is False:
2274+
if self.environment is None:
2275+
self.environment = {}
2276+
self.environment[DEBUGGER_FLAG] = "0"
2277+
22722278
def _stage_user_code_in_s3(self):
22732279
"""Upload the user training script to s3 and return the location.
22742280

src/sagemaker/huggingface/estimator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,10 @@ def __init__(
7070
``image_uri`` is provided. The current supported version is ``4.6.1``.
7171
tensorflow_version (str): TensorFlow version you want to use for
7272
executing your model training code. Defaults to ``None``. Required unless
73-
``pytorch_version`` is provided. The current supported version is ``1.6.0``.
73+
``pytorch_version`` is provided. The current supported version is ``2.4.1``.
7474
pytorch_version (str): PyTorch version you want to use for
7575
executing your model training code. Defaults to ``None``. Required unless
76-
``tensorflow_version`` is provided. The current supported version is ``2.4.1``.
76+
``tensorflow_version`` is provided. The current supported versions are ``1.7.1`` and ``1.6.0``.
7777
source_dir (str): Path (absolute, relative or an S3 URI) to a directory
7878
with any other training source code dependencies aside from the entry
7979
point file (default: None). If ``source_dir`` is an S3 URI, it must

src/sagemaker/model.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"""Placeholder docstring"""
1414
from __future__ import absolute_import
1515

16+
import abc
1617
import json
1718
import logging
1819
import os
@@ -29,6 +30,7 @@
2930
git_utils,
3031
)
3132
from sagemaker.deprecations import removed_kwargs
33+
from sagemaker.predictor import PredictorBase
3234
from sagemaker.transformer import Transformer
3335

3436
LOGGER = logging.getLogger("sagemaker")
@@ -38,7 +40,23 @@
3840
)
3941

4042

41-
class Model(object):
43+
class ModelBase(abc.ABC):
44+
"""An object that encapsulates a trained model.
45+
46+
Models can be deployed to compute services like a SageMaker ``Endpoint``
47+
or Lambda. Deployed models can be used to perform real-time inference.
48+
"""
49+
50+
@abc.abstractmethod
51+
def deploy(self, *args, **kwargs) -> PredictorBase:
52+
"""Deploy this model to a compute service."""
53+
54+
@abc.abstractmethod
55+
def delete_model(self, *args, **kwargs) -> None:
56+
"""Destroy resources associated with this model."""
57+
58+
59+
class Model(ModelBase):
4260
"""A SageMaker ``Model`` that can be deployed to an ``Endpoint``."""
4361

4462
def __init__(

src/sagemaker/predictor.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
"""Placeholder docstring"""
1414
from __future__ import print_function, absolute_import
1515

16+
import abc
17+
from typing import Any, Tuple
18+
1619
from sagemaker.deprecations import (
1720
deprecated_class,
1821
deprecated_deserialize,
@@ -51,7 +54,29 @@
5154
from sagemaker.lineage.context import EndpointContext
5255

5356

54-
class Predictor(object):
57+
class PredictorBase(abc.ABC):
58+
"""An object that encapsulates a deployed model."""
59+
60+
@abc.abstractmethod
61+
def predict(self, *args, **kwargs) -> Any:
62+
"""Perform inference on the provided data and return a prediction."""
63+
64+
@abc.abstractmethod
65+
def delete_endpoint(self, *args, **kwargs) -> None:
66+
"""Destroy resources associated with this predictor."""
67+
68+
@property
69+
@abc.abstractmethod
70+
def content_type(self) -> str:
71+
"""The MIME type of the data sent to the inference server."""
72+
73+
@property
74+
@abc.abstractmethod
75+
def accept(self) -> Tuple[str]:
76+
"""The content type(s) that are expected from the inference server."""
77+
78+
79+
class Predictor(PredictorBase):
5580
"""Make prediction requests to an Amazon SageMaker endpoint."""
5681

5782
def __init__(

src/sagemaker/serverless/__init__.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Classes for performing machine learning on serverless compute."""
14+
from sagemaker.serverless.model import LambdaModel # noqa: F401
15+
from sagemaker.serverless.predictor import LambdaPredictor # noqa: F401

src/sagemaker/serverless/model.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
# Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Models that can be deployed to serverless compute."""
14+
from __future__ import absolute_import
15+
16+
import time
17+
from typing import Optional
18+
19+
import boto3
20+
import botocore
21+
22+
from sagemaker.model import ModelBase
23+
24+
from .predictor import LambdaPredictor
25+
26+
27+
class LambdaModel(ModelBase):
28+
"""A model that can be deployed to Lambda."""
29+
30+
def __init__(
31+
self, image_uri: str, role: str, client: Optional[botocore.client.BaseClient] = None
32+
) -> None:
33+
"""Initialize instance attributes.
34+
35+
Arguments:
36+
image_uri: URI of a container image in the Amazon ECR registry. The image
37+
should contain a handler that performs inference.
38+
role: The Amazon Resource Name (ARN) of the IAM role that Lambda will assume
39+
when it performs inference
40+
client: The Lambda client used to interact with Lambda.
41+
"""
42+
self._client = client or boto3.client("lambda")
43+
self._image_uri = image_uri
44+
self._role = role
45+
46+
def deploy(
47+
self, function_name: str, timeout: int, memory_size: int, wait: bool = True
48+
) -> LambdaPredictor:
49+
"""Create a Lambda function using the image specified in the constructor.
50+
51+
Arguments:
52+
function_name: The name of the function.
53+
timeout: The number of seconds that the function can run for before being terminated.
54+
memory_size: The amount of memory in MB that the function has access to.
55+
wait: If true, wait until the deployment completes (default: True).
56+
57+
Returns:
58+
A LambdaPredictor instance that performs inference using the specified image.
59+
"""
60+
response = self._client.create_function(
61+
FunctionName=function_name,
62+
PackageType="Image",
63+
Role=self._role,
64+
Code={
65+
"ImageUri": self._image_uri,
66+
},
67+
Timeout=timeout,
68+
MemorySize=memory_size,
69+
)
70+
71+
if not wait:
72+
return LambdaPredictor(function_name, client=self._client)
73+
74+
# Poll function state.
75+
polling_interval = 5
76+
while response["State"] == "Pending":
77+
time.sleep(polling_interval)
78+
response = self._client.get_function_configuration(FunctionName=function_name)
79+
80+
if response["State"] != "Active":
81+
raise RuntimeError("Failed to deploy model to Lambda: %s" % response["StateReason"])
82+
83+
return LambdaPredictor(function_name, client=self._client)
84+
85+
def delete_model(self) -> None:
86+
"""Destroy resources associated with this model.
87+
88+
This method does not delete the image specified in the constructor. As
89+
a result, this method is a no-op.
90+
"""

0 commit comments

Comments
 (0)