Skip to content

Commit 5715a33

Browse files
authored
Merge branch 'master' into master
2 parents b9c6c82 + 93006f3 commit 5715a33

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

61 files changed

+7903
-241
lines changed

.flake8

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@ application_import_names = sagemaker, tests
33
import-order-style = google
44
per-file-ignores =
55
tests/unit/test_tuner.py: F405
6+
src/sagemaker/config/config_schema.py: E501

CHANGELOG.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,16 @@
11
# Changelog
22

3+
## v2.148.0 (2023-04-20)
4+
5+
### Features
6+
7+
* [huggingface] Add `torch.distributed` support for Trainium and `torchrun`
8+
* Add PyTorch 2.0 to SDK
9+
10+
### Bug Fixes and Other Changes
11+
12+
* updating batch transform job in monitoring schedule
13+
314
## v2.147.0 (2023-04-18)
415

516
### Features

README.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,8 @@ To run the integration tests, the following prerequisites must be met
133133
1. AWS account credentials are available in the environment for the boto3 client to use.
134134
2. The AWS account has an IAM role named :code:`SageMakerRole`.
135135
It should have the AmazonSageMakerFullAccess policy attached as well as a policy with `the necessary permissions to use Elastic Inference <https://docs.aws.amazon.com/sagemaker/latest/dg/ei-setup.html>`__.
136+
3. To run remote_function tests, a dummy ECR repository must be created. It can be created by running:
137+
:code:`aws ecr create-repository --repository-name remote-function-dummy-container`
136138

137139
We recommend selectively running just those integration tests you'd like to run. You can filter by individual test function names with:
138140

VERSION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
2.147.1.dev0
1+
2.148.1.dev0

requirements/extras/test_requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,4 @@ sagemaker-experiments==0.1.35
2121
Jinja2==3.0.3
2222
pandas>=1.3.5,<1.5
2323
scikit-learn==1.0.2
24+
cloudpickle==2.2.1

setup.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def read_requirements(filename):
4949
required_packages = [
5050
"attrs>=20.3.0,<23",
5151
"boto3>=1.26.28,<2.0",
52+
"cloudpickle==2.2.1",
5253
"google-pasta",
5354
"numpy>=1.9.0,<2.0",
5455
"protobuf>=3.1,<4.0",
@@ -62,6 +63,7 @@ def read_requirements(filename):
6263
"PyYAML==5.4.1",
6364
"jsonschema",
6465
"platformdirs",
66+
"tblib==1.7.0",
6567
]
6668

6769
# Specific use case dependencies

src/sagemaker/clarify.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
import tempfile
2727
from abc import ABC, abstractmethod
2828
from typing import List, Union, Dict, Optional, Any
29-
29+
from enum import Enum
3030
from schema import Schema, And, Use, Or, Optional as SchemaOptional, Regex
3131

3232
from sagemaker import image_uris, s3, utils
@@ -304,6 +304,16 @@
304304
)
305305

306306

307+
class DatasetType(Enum):
308+
"""Enum to store different dataset types supported in the Analysis config file"""
309+
310+
TEXTCSV = "text/csv"
311+
JSONLINES = "application/jsonlines"
312+
JSON = "application/json"
313+
PARQUET = "application/x-parquet"
314+
IMAGE = "application/x-image"
315+
316+
307317
class DataConfig:
308318
"""Config object related to configurations of the input and output dataset."""
309319

@@ -1451,7 +1461,7 @@ def _run(
14511461
source=self._CLARIFY_OUTPUT,
14521462
destination=data_config.s3_output_path,
14531463
output_name="analysis_result",
1454-
s3_upload_mode="EndOfJob",
1464+
s3_upload_mode=ProcessingOutputHandler.get_s3_upload_mode(analysis_config),
14551465
)
14561466

14571467
return super().run(
@@ -2171,6 +2181,33 @@ def _upload_analysis_config(analysis_config_file, s3_output_path, sagemaker_sess
21712181
)
21722182

21732183

2184+
class ProcessingOutputHandler:
2185+
"""Class to handle the parameters for SagemakerProcessor.ProcessingOutput"""
2186+
2187+
class S3UploadMode(Enum):
2188+
"""Enum values for different upload modes to s3 bucket"""
2189+
2190+
CONTINUOUS = "Continuous"
2191+
ENDOFJOB = "EndOfJob"
2192+
2193+
@classmethod
2194+
def get_s3_upload_mode(cls, analysis_config: Dict[str, Any]) -> str:
2195+
"""Fetches the s3_upload_mode based on the dataset_type value of the analysis config
2196+
2197+
Args:
2198+
analysis_config (dict): Config following the analysis_config.json format
2199+
2200+
Returns:
2201+
The s3_upload_mode type for the processing output.
2202+
"""
2203+
dataset_type = analysis_config["dataset_type"]
2204+
return (
2205+
ProcessingOutputHandler.S3UploadMode.CONTINUOUS.value
2206+
if dataset_type == DatasetType.IMAGE.value
2207+
else ProcessingOutputHandler.S3UploadMode.ENDOFJOB.value
2208+
)
2209+
2210+
21742211
def _set(value, key, dictionary):
21752212
"""Sets dictionary[key] = value if value is not None."""
21762213
if value is not None:

src/sagemaker/config/config_schema.py

Lines changed: 101 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,17 @@
4444
SAGEMAKER = "SageMaker"
4545
PYTHON_SDK = "PythonSDK"
4646
MODULES = "Modules"
47+
REMOTE_FUNCTION = "RemoteFunction"
48+
DEPENDENCIES = "Dependencies"
49+
PRE_EXECUTION_SCRIPT = "PreExecutionScript"
50+
PRE_EXECUTION_COMMANDS = "PreExecutionCommands"
51+
ENVIRONMENT_VARIABLES = "EnvironmentVariables"
52+
IMAGE_URI = "ImageUri"
53+
INCLUDE_LOCAL_WORKDIR = "IncludeLocalWorkDir"
54+
INSTANCE_TYPE = "InstanceType"
55+
S3_KMS_KEY_ID = "S3KmsKeyId"
56+
S3_ROOT_URI = "S3RootUri"
57+
JOB_CONDA_ENV = "JobCondaEnvironment"
4758
OFFLINE_STORE_CONFIG = "OfflineStoreConfig"
4859
ONLINE_STORE_CONFIG = "OnlineStoreConfig"
4960
S3_STORAGE_CONFIG = "S3StorageConfig"
@@ -221,6 +232,49 @@ def _simple_path(*args: str):
221232
SAGEMAKER, MODEL_PACKAGE, VALIDATION_SPECIFICATION, VALIDATION_PROFILES
222233
)
223234

235+
REMOTE_FUNCTION_DEPENDENCIES = _simple_path(
236+
SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, DEPENDENCIES
237+
)
238+
REMOTE_FUNCTION_PRE_EXECUTION_COMMANDS = _simple_path(
239+
SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, PRE_EXECUTION_COMMANDS
240+
)
241+
REMOTE_FUNCTION_PRE_EXECUTION_SCRIPT = _simple_path(
242+
SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, PRE_EXECUTION_SCRIPT
243+
)
244+
REMOTE_FUNCTION_ENVIRONMENT_VARIABLES = _simple_path(
245+
SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, ENVIRONMENT_VARIABLES
246+
)
247+
REMOTE_FUNCTION_IMAGE_URI = _simple_path(SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, IMAGE_URI)
248+
REMOTE_FUNCTION_INCLUDE_LOCAL_WORKDIR = _simple_path(
249+
SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, INCLUDE_LOCAL_WORKDIR
250+
)
251+
REMOTE_FUNCTION_INSTANCE_TYPE = _simple_path(
252+
SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, INSTANCE_TYPE
253+
)
254+
REMOTE_FUNCTION_JOB_CONDA_ENV = _simple_path(
255+
SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, JOB_CONDA_ENV
256+
)
257+
REMOTE_FUNCTION_ROLE_ARN = _simple_path(SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, ROLE_ARN)
258+
REMOTE_FUNCTION_S3_KMS_KEY_ID = _simple_path(
259+
SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, S3_KMS_KEY_ID
260+
)
261+
REMOTE_FUNCTION_S3_ROOT_URI = _simple_path(
262+
SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, S3_ROOT_URI
263+
)
264+
REMOTE_FUNCTION_TAGS = _simple_path(SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, TAGS)
265+
REMOTE_FUNCTION_VOLUME_KMS_KEY_ID = _simple_path(
266+
SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, VOLUME_KMS_KEY_ID
267+
)
268+
REMOTE_FUNCTION_VPC_CONFIG_SUBNETS = _simple_path(
269+
SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, VPC_CONFIG, SUBNETS
270+
)
271+
REMOTE_FUNCTION_VPC_CONFIG_SECURITY_GROUP_IDS = _simple_path(
272+
SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, VPC_CONFIG, SECURITY_GROUP_IDS
273+
)
274+
REMOTE_FUNCTION_ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION = _simple_path(
275+
SAGEMAKER, PYTHON_SDK, MODULES, REMOTE_FUNCTION, ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION
276+
)
277+
224278
# Paths for reference elsewhere in the SDK.
225279
# Names include the schema version since the paths could change with other schema versions
226280
MONITORING_SCHEDULE_INTER_CONTAINER_ENCRYPTION_PATH = _simple_path(
@@ -245,7 +299,6 @@ def _simple_path(*args: str):
245299
SAGEMAKER, TRAINING_JOB, ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION
246300
)
247301

248-
249302
SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA = {
250303
"$schema": "https://json-schema.org/draft/2020-12/schema",
251304
TYPE: OBJECT,
@@ -377,6 +430,23 @@ def _simple_path(*args: str):
377430
"minItems": 0,
378431
"maxItems": 50,
379432
},
433+
# Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateTrainingJob.html#sagemaker-CreateTrainingJob-request-Environment
434+
"environmentVariables": {
435+
TYPE: OBJECT,
436+
ADDITIONAL_PROPERTIES: False,
437+
PATTERN_PROPERTIES: {
438+
r"([a-zA-Z_][a-zA-Z0-9_]*){1,512}": {
439+
TYPE: "string",
440+
"pattern": r"[\S\s]*",
441+
"maxLength": 512,
442+
}
443+
},
444+
"maxProperties": 48,
445+
},
446+
# Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_S3DataSource.html#sagemaker-Type-S3DataSource-S3Uri
447+
"s3Uri": {TYPE: "string", "pattern": "^(https|s3)://([^/]+)/?(.*)$", "maxLength": 1024},
448+
# Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_AlgorithmSpecification.html#sagemaker-Type-AlgorithmSpecification-ContainerEntrypoint
449+
"preExecutionCommand": {TYPE: "string", "pattern": r".*"},
380450
},
381451
PROPERTIES: {
382452
SCHEMA_VERSION: {
@@ -406,6 +476,36 @@ def _simple_path(*args: str):
406476
# Any SageMaker Python SDK specific configuration will be added here.
407477
TYPE: OBJECT,
408478
ADDITIONAL_PROPERTIES: False,
479+
PROPERTIES: {
480+
REMOTE_FUNCTION: {
481+
TYPE: OBJECT,
482+
ADDITIONAL_PROPERTIES: False,
483+
PROPERTIES: {
484+
DEPENDENCIES: {TYPE: "string"},
485+
PRE_EXECUTION_COMMANDS: {
486+
TYPE: "array",
487+
"items": {"$ref": "#/definitions/preExecutionCommand"},
488+
},
489+
PRE_EXECUTION_SCRIPT: {TYPE: "string"},
490+
ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION: {
491+
TYPE: "boolean"
492+
},
493+
ENVIRONMENT_VARIABLES: {
494+
"$ref": "#/definitions/environmentVariables"
495+
},
496+
IMAGE_URI: {TYPE: "string"},
497+
INCLUDE_LOCAL_WORKDIR: {TYPE: "boolean"},
498+
INSTANCE_TYPE: {TYPE: "string"},
499+
JOB_CONDA_ENV: {TYPE: "string"},
500+
ROLE_ARN: {"$ref": "#/definitions/roleArn"},
501+
S3_KMS_KEY_ID: {"$ref": "#/definitions/kmsKeyId"},
502+
S3_ROOT_URI: {"$ref": "#/definitions/s3Uri"},
503+
TAGS: {"$ref": "#/definitions/tags"},
504+
VOLUME_KMS_KEY_ID: {"$ref": "#/definitions/kmsKeyId"},
505+
VPC_CONFIG: {"$ref": "#/definitions/vpcConfig"},
506+
},
507+
}
508+
},
409509
}
410510
},
411511
},

src/sagemaker/experiments/run.py

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -715,6 +715,14 @@ def __exit__(self, exc_type, exc_value, exc_traceback):
715715

716716
self.close()
717717

718+
def __getstate__(self):
719+
"""Overriding this method to prevent instance of Run from being pickled.
720+
721+
Raises:
722+
NotImplementedError: If attempting to pickle this instance.
723+
"""
724+
raise NotImplementedError("Instance of Run type is not allowed to be pickled.")
725+
718726

719727
def load_run(
720728
run_name: Optional[str] = None,
@@ -787,36 +795,38 @@ def load_run(
787795
Returns:
788796
Run: The loaded Run object.
789797
"""
790-
sagemaker_session = sagemaker_session or _utils.default_session()
791798
environment = _RunEnvironment.load()
792799

793800
verify_load_input_names(run_name=run_name, experiment_name=experiment_name)
794801

795-
if run_name or environment:
796-
if run_name:
797-
logger.warning(
798-
"run_name is explicitly supplied in load_run, "
799-
"which will be prioritized to load the Run object. "
800-
"In other words, the run name in the experiment config, fetched from the "
801-
"job environment or the current run context, will be ignored."
802-
)
803-
else:
804-
exp_config = get_tc_and_exp_config_from_job_env(
805-
environment=environment, sagemaker_session=sagemaker_session
806-
)
807-
run_name = Run._extract_run_name_from_tc_name(
808-
trial_component_name=exp_config[RUN_NAME],
809-
experiment_name=exp_config[EXPERIMENT_NAME],
810-
)
811-
experiment_name = exp_config[EXPERIMENT_NAME]
812-
802+
if run_name:
803+
logger.warning(
804+
"run_name is explicitly supplied in load_run, "
805+
"which will be prioritized to load the Run object. "
806+
"In other words, the run name in the experiment config, fetched from the "
807+
"job environment or the current run context, will be ignored."
808+
)
813809
run_instance = Run(
814810
experiment_name=experiment_name,
815811
run_name=run_name,
816-
sagemaker_session=sagemaker_session,
812+
sagemaker_session=sagemaker_session or _utils.default_session(),
817813
)
818814
elif _RunContext.get_current_run():
819815
run_instance = _RunContext.get_current_run()
816+
elif environment:
817+
exp_config = get_tc_and_exp_config_from_job_env(
818+
environment=environment, sagemaker_session=sagemaker_session or _utils.default_session()
819+
)
820+
run_name = Run._extract_run_name_from_tc_name(
821+
trial_component_name=exp_config[RUN_NAME],
822+
experiment_name=exp_config[EXPERIMENT_NAME],
823+
)
824+
experiment_name = exp_config[EXPERIMENT_NAME]
825+
run_instance = Run(
826+
experiment_name=experiment_name,
827+
run_name=run_name,
828+
sagemaker_session=sagemaker_session or _utils.default_session(),
829+
)
820830
else:
821831
raise RuntimeError(
822832
"Failed to load a Run object. "
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
{
2+
"versions": {
3+
"1.0": {
4+
"registries": {
5+
"us-east-2": "429704687514",
6+
"me-south-1": "117516905037",
7+
"us-west-2": "236514542706",
8+
"ca-central-1": "310906938811",
9+
"ap-east-1": "493642496378",
10+
"us-east-1": "081325390199",
11+
"ap-northeast-2": "806072073708",
12+
"eu-west-2": "712779665605",
13+
"ap-southeast-2": "452832661640",
14+
"cn-northwest-1": "390780980154",
15+
"eu-north-1": "243637512696",
16+
"cn-north-1": "390048526115",
17+
"ap-south-1": "394103062818",
18+
"eu-west-3": "615547856133",
19+
"ap-southeast-3": "276181064229",
20+
"af-south-1": "559312083959",
21+
"eu-west-1": "470317259841",
22+
"eu-central-1": "936697816551",
23+
"sa-east-1": "782484402741",
24+
"ap-northeast-3": "792733760839",
25+
"eu-south-1": "592751261982",
26+
"ap-northeast-1": "102112518831",
27+
"us-west-1": "742091327244",
28+
"ap-southeast-1": "492261229750",
29+
"me-central-1": "103105715889",
30+
"us-gov-east-1": "107072934176",
31+
"us-gov-west-1": "107173498710"
32+
},
33+
"repository": "sagemaker-base-python"
34+
}
35+
}
36+
}

0 commit comments

Comments
 (0)