Commit a42fff5

Add AutoMLV2 support

1 parent a4ef985 commit a42fff5

29 files changed: +16007 -4 lines changed

doc/api/training/automlv2.rst

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+AutoMLV2
+--------
+
+.. automodule:: sagemaker.automl.automlv2
+    :members:
+    :undoc-members:
+    :show-inheritance:

doc/api/training/index.rst

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@ Training APIs
     algorithm
     analytics
     automl
+    automlv2
     debugger
     estimators
     tuner

src/sagemaker/__init__.py

Lines changed: 11 additions & 0 deletions
@@ -61,6 +61,17 @@
 
 from sagemaker.automl.automl import AutoML, AutoMLJob, AutoMLInput  # noqa: F401
 from sagemaker.automl.candidate_estimator import CandidateEstimator, CandidateStep  # noqa: F401
+from sagemaker.automl.automlv2 import (  # noqa: F401
+    AutoMLV2,
+    AutoMLJobV2,
+    LocalAutoMLDataChannel,
+    AutoMLDataChannel,
+    AutoMLTimeSeriesForecastingConfig,
+    AutoMLImageClassificationConfig,
+    AutoMLTabularConfig,
+    AutoMLTextClassificationConfig,
+    AutoMLTextGenerationConfig,
+)
 
 from sagemaker.debugger import ProfilerConfig, Profiler  # noqa: F401
 

src/sagemaker/automl/automlv2.py

Lines changed: 1432 additions & 0 deletions
Large diffs are not rendered by default.
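
The automlv2.py module itself is not expanded in this view, so the constructors of the classes it defines are not shown here. As a rough usage sketch only, assuming an estimator-style interface behind the names re-exported in __init__.py above (every parameter name, S3 URI, role ARN, and column name below is an assumption or placeholder, not taken from this diff):

from sagemaker import AutoMLV2, AutoMLDataChannel, AutoMLTabularConfig

# Problem-type config for a tabular job; the target column name is a placeholder.
problem_config = AutoMLTabularConfig(target_attribute_name="label")

# Training data channel pointing at a CSV prefix in S3 (placeholder URI).
train_channel = AutoMLDataChannel(
    s3_data_type="S3Prefix",
    s3_uri="s3://example-bucket/automl/train/",
    channel_type="training",
)

automl = AutoMLV2(
    problem_config=problem_config,
    role="arn:aws:iam::111122223333:role/ExampleSageMakerRole",
    output_path="s3://example-bucket/automl/output/",
)

# fit() would create the AutoML V2 job and, if waiting with logs enabled, tail its logs.
automl.fit(inputs=[train_channel])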

src/sagemaker/config/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -50,6 +50,7 @@
     AUTO_ML_OUTPUT_CONFIG_PATH,
     AUTO_ML_JOB_CONFIG_PATH,
     AUTO_ML_JOB,
+    AUTO_ML_JOB_V2,
     COMPILATION_JOB_ROLE_ARN_PATH,
     COMPILATION_JOB_OUTPUT_CONFIG_PATH,
     COMPILATION_JOB_VPC_CONFIG_PATH,

src/sagemaker/config/config_schema.py

Lines changed: 1 addition & 0 deletions
@@ -83,6 +83,7 @@
 ENDPOINT = "Endpoint"
 INFERENCE_COMPONENT = "InferenceComponent"
 AUTO_ML_JOB = "AutoMLJob"
+AUTO_ML_JOB_V2 = "AutoMLJobV2"
 COMPILATION_JOB = "CompilationJob"
 CUSTOM_PARAMETERS = "CustomParameters"
 PIPELINE = "Pipeline"
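
The new constant feeds the defaults config schema, where per-resource keys follow the SageMaker.<Resource>.<Property> pattern used elsewhere in config_schema.py. A hedged sketch of a defaults dictionary carrying tags under the new key, assuming AutoMLJobV2 nests the same way the existing AutoMLJob key does and that the dictionary is passed through Session's sagemaker_config argument (neither detail is shown in the rendered hunks):

import sagemaker

# Illustrative defaults only; the AutoMLJobV2 nesting mirrors the existing AutoMLJob pattern,
# and whether the new session methods consult this key is not visible in this diff.
example_defaults = {
    "SchemaVersion": "1.0",
    "SageMaker": {
        "AutoMLJobV2": {
            "Tags": [{"Key": "team", "Value": "example"}],
        },
    },
}

session = sagemaker.Session(sagemaker_config=example_defaults)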

src/sagemaker/session.py

Lines changed: 265 additions & 4 deletions
@@ -2570,12 +2570,273 @@ def logs_for_auto_ml_job( # noqa: C901 - suppress complexity warning for this m
             exceptions.UnexpectedStatusException: If waiting and auto ml job fails.
         """
 
-        description = _wait_until(lambda: self.describe_auto_ml_job(job_name), poll)
+        description = _wait_until(lambda: self.describe_auto_ml_job_v2(job_name), poll)
 
-        instance_count, stream_names, positions, client, log_group, dot, color_wrap = _logs_init(
-            self.boto_session, description, job="AutoML"
+        (
+            instance_count,
+            stream_names,
+            positions,
+            client,
+            log_group,
+            dot,
+            color_wrap,
+        ) = _logs_init(self.boto_session, description, job="AutoML")
+
+        state = _get_initial_job_state(description, "AutoMLJobStatus", wait)
+
+        # The loop below implements a state machine that alternates between checking the job status
+        # and reading whatever is available in the logs at this point. Note, that if we were
+        # called with wait == False, we never check the job status.
+        #
+        # If wait == TRUE and job is not completed, the initial state is TAILING
+        # If wait == FALSE, the initial state is COMPLETE (doesn't matter if the job really is
+        # complete).
+        #
+        # The state table:
+        #
+        # STATE               ACTIONS                        CONDITION             NEW STATE
+        # ----------------    ----------------               -----------------     ----------------
+        # TAILING             Read logs, Pause, Get status   Job complete          JOB_COMPLETE
+        #                                                    Else                  TAILING
+        # JOB_COMPLETE        Read logs, Pause               Any                   COMPLETE
+        # COMPLETE            Read logs, Exit                N/A
+        #
+        # Notes:
+        # - The JOB_COMPLETE state forces us to do an extra pause and read any items that got to
+        #   Cloudwatch after the job was marked complete.
+        last_describe_job_call = time.time()
+        while True:
+            _flush_log_streams(
+                stream_names,
+                instance_count,
+                client,
+                log_group,
+                job_name,
+                positions,
+                dot,
+                color_wrap,
+            )
+            if state == LogState.COMPLETE:
+                break
+
+            time.sleep(poll)
+
+            if state == LogState.JOB_COMPLETE:
+                state = LogState.COMPLETE
+            elif time.time() - last_describe_job_call >= 30:
+                description = self.sagemaker_client.describe_auto_ml_job_v2(AutoMLJobName=job_name)
+                last_describe_job_call = time.time()
+
+                status = description["AutoMLJobStatus"]
+
+                if status in ("Completed", "Failed", "Stopped"):
+                    print()
+                    state = LogState.JOB_COMPLETE
+
+        if wait:
+            _check_job_status(job_name, description, "AutoMLJobStatus")
+            if dot:
+                print()
+
+    def create_auto_ml_v2(
+        self,
+        input_config,
+        job_name,
+        problem_config,
+        output_config,
+        job_objective=None,
+        model_deploy_config=None,
+        data_split_config=None,
+        role=None,
+        security_config=None,
+        tags=None,
+    ):
+        """Create an Amazon SageMaker AutoMLV2 job.
+
+        Args:
+            input_config (list[dict]): A list of AutoMLDataChannel objects.
+                Each channel contains "DataSource" and other optional fields.
+            job_name (str): A string that can be used to identify an AutoMLJob. Each AutoMLJob
+                should have a unique job name.
+            problem_config (object): A collection of settings specific
+                to the problem type used to configure an AutoML job V2.
+                There must be one and only one config of one of the following types.
+                Supported problem types are:
+
+                - Image Classification (sagemaker.automl.automlv2.ImageClassificationJobConfig),
+                - Tabular (sagemaker.automl.automlv2.TabularJobConfig),
+                - Text Classification (sagemaker.automl.automlv2.TextClassificationJobConfig),
+                - Text Generation (TextGenerationJobConfig),
+                - Time Series Forecasting (
+                    sagemaker.automl.automlv2.TimeSeriesForecastingJobConfig).
+
+            output_config (dict): The S3 URI where you want to store the training results and
+                optional KMS key ID.
+            job_objective (dict): AutoMLJob objective, contains "AutoMLJobObjectiveType" (optional),
+                "MetricName" and "Value".
+            model_deploy_config (dict): Specifies how to generate the endpoint name
+                for an automatic one-click Autopilot model deployment.
+                Contains "AutoGenerateEndpointName" and "EndpointName".
+            data_split_config (dict): This structure specifies how to split the data
+                into train and validation datasets.
+            role (str): The Amazon Resource Name (ARN) of an IAM role that
+                Amazon SageMaker can assume to perform tasks on your behalf.
+            security_config (dict): The security configuration for traffic encryption
+                or Amazon VPC settings.
+            tags (Optional[Tags]): A list of dictionaries containing key-value
+                pairs.
+        """
+
+        role = resolve_value_from_config(role, AUTO_ML_ROLE_ARN_PATH, sagemaker_session=self)
+        inferred_output_config = update_nested_dictionary_with_values_from_config(
+            output_config, AUTO_ML_OUTPUT_CONFIG_PATH, sagemaker_session=self
         )
 
+        auto_ml_job_v2_request = self._get_auto_ml_request_v2(
+            input_config=input_config,
+            job_name=job_name,
+            problem_config=problem_config,
+            output_config=inferred_output_config,
+            role=role,
+            job_objective=job_objective,
+            model_deploy_config=model_deploy_config,
+            data_split_config=data_split_config,
+            security_config=security_config,
+            tags=format_tags(tags),
+        )
+
+        def submit(request):
+            logger.info("Creating auto-ml-v2-job with name: %s", job_name)
+            logger.debug("auto ml v2 request: %s", json.dumps(request, indent=4))
+            self.sagemaker_client.create_auto_ml_job_v2(**request)
+
+        self._intercept_create_request(
+            auto_ml_job_v2_request, submit, self.create_auto_ml_v2.__name__
+        )
+
+    def _get_auto_ml_request_v2(
+        self,
+        input_config,
+        output_config,
+        job_name,
+        problem_config,
+        role,
+        job_objective=None,
+        model_deploy_config=None,
+        data_split_config=None,
+        security_config=None,
+        tags=None,
+    ):
+        """Constructs a request compatible with creating an Amazon SageMaker AutoML V2 job.
+
+        Args:
+            input_config (list[dict]): A list of Channel objects. Each channel contains
+                "DataSource"; "TargetAttributeName", "CompressionType" and
+                "SampleWeightAttributeName" are optional fields.
+            output_config (dict): The S3 URI where you want to store the training results and
+                optional KMS key ID.
+            job_name (str): A string that can be used to identify an AutoMLJob. Each AutoMLJob
+                should have a unique job name.
+            problem_config (object): A collection of settings specific
+                to the problem type used to configure an AutoML job V2.
+                There must be one and only one config of one of the following types.
+                Supported problem types are:
+
+                - Image Classification (sagemaker.automl.automlv2.ImageClassificationJobConfig),
+                - Tabular (sagemaker.automl.automlv2.TabularJobConfig),
+                - Text Classification (sagemaker.automl.automlv2.TextClassificationJobConfig),
+                - Text Generation (TextGenerationJobConfig),
+                - Time Series Forecasting (
+                    sagemaker.automl.automlv2.TimeSeriesForecastingJobConfig).
+
+            role (str): The Amazon Resource Name (ARN) of an IAM role that
+                Amazon SageMaker can assume to perform tasks on your behalf.
+            job_objective (dict): AutoMLJob objective, contains "AutoMLJobObjectiveType" (optional),
+                "MetricName" and "Value".
+            model_deploy_config (dict): Specifies how to generate the endpoint name
+                for an automatic one-click Autopilot model deployment.
+                Contains "AutoGenerateEndpointName" and "EndpointName".
+            data_split_config (dict): This structure specifies how to split the data
+                into train and validation datasets.
+            security_config (dict): The security configuration for traffic encryption
+                or Amazon VPC settings.
+            tags (Optional[Tags]): A list of dictionaries containing key-value
+                pairs.
+
+        Returns:
+            Dict: an AutoML V2 request dict.
+        """
+        auto_ml_job_v2_request = {
+            "AutoMLJobName": job_name,
+            "AutoMLJobInputDataConfig": input_config,
+            "OutputDataConfig": output_config,
+            "AutoMLProblemTypeConfig": problem_config,
+            "RoleArn": role,
+        }
+        if job_objective is not None:
+            auto_ml_job_v2_request["AutoMLJobObjective"] = job_objective
+        if model_deploy_config is not None:
+            auto_ml_job_v2_request["ModelDeployConfig"] = model_deploy_config
+        if data_split_config is not None:
+            auto_ml_job_v2_request["DataSplitConfig"] = data_split_config
+        if security_config is not None:
+            auto_ml_job_v2_request["SecurityConfig"] = security_config
+
+        tags = _append_project_tags(format_tags(tags))
+        tags = self._append_sagemaker_config_tags(
+            tags, "{}.{}.{}".format(SAGEMAKER, AUTO_ML_JOB, TAGS)
+        )
+        if tags is not None:
+            auto_ml_job_v2_request["Tags"] = tags
+
+        return auto_ml_job_v2_request
+
+    def describe_auto_ml_job_v2(self, job_name):
+        """Calls the DescribeAutoMLJobV2 API for the given job name and returns the response.
+
+        Args:
+            job_name (str): The name of the AutoML job to describe.
+
+        Returns:
+            dict: A dictionary response with the AutoMLV2 Job description.
+        """
+        return self.sagemaker_client.describe_auto_ml_job_v2(AutoMLJobName=job_name)
+
+    def logs_for_auto_ml_job_v2(  # noqa: C901 - suppress complexity warning for this method
+        self, job_name, wait=False, poll=10
+    ):
+        """Display logs for a given AutoML V2 job, optionally tailing them until job is complete.
+
+        If the output is a tty or a Jupyter cell, it will be color-coded
+        based on which instance the log entry is from.
+
+        Args:
+            job_name (str): Name of the Auto ML V2 job to display the logs for.
+            wait (bool): Whether to keep looking for new log entries until the job completes
+                (default: False).
+            poll (int): The interval in seconds between polling for new log entries and job
+                completion (default: 10).
+
+        Raises:
+            exceptions.CapacityError: If waiting and auto ml job fails with CapacityError.
+            exceptions.UnexpectedStatusException: If waiting and auto ml job fails.
+        """
+
+        description = _wait_until(lambda: self.describe_auto_ml_job_v2(job_name), poll)
+
+        (
+            instance_count,
+            stream_names,
+            positions,
+            client,
+            log_group,
+            dot,
+            color_wrap,
+        ) = _logs_init(self.boto_session, description, job="AutoML")
+
         state = _get_initial_job_state(description, "AutoMLJobStatus", wait)
 
         # The loop below implements a state machine that alternates between checking the job status
@@ -2618,7 +2879,7 @@ def logs_for_auto_ml_job( # noqa: C901 - suppress complexity warning for this m
             if state == LogState.JOB_COMPLETE:
                 state = LogState.COMPLETE
             elif time.time() - last_describe_job_call >= 30:
-                description = self.sagemaker_client.describe_auto_ml_job(AutoMLJobName=job_name)
+                description = self.sagemaker_client.describe_auto_ml_job_v2(AutoMLJobName=job_name)
                 last_describe_job_call = time.time()
 
                 status = description["AutoMLJobStatus"]
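
The new Session methods take request-shaped dictionaries and pass them through to the CreateAutoMLJobV2 / DescribeAutoMLJobV2 APIs, as _get_auto_ml_request_v2 above shows. A minimal end-to-end sketch, assuming the channel, problem-type, and output dictionaries follow the shapes of the underlying boto3 CreateAutoMLJobV2 API (the bucket, role ARN, job name, and target column are placeholders):

import sagemaker

sess = sagemaker.Session()

# One training channel; the shape mirrors AutoMLJobInputDataConfig in CreateAutoMLJobV2.
input_config = [
    {
        "ChannelType": "training",
        "ContentType": "text/csv;header=present",
        "CompressionType": "None",
        "DataSource": {
            "S3DataSource": {
                "S3DataType": "S3Prefix",
                "S3Uri": "s3://example-bucket/automl/train/",
            }
        },
    }
]

# Exactly one problem-type config, keyed by its AutoMLProblemTypeConfig member name.
problem_config = {"TabularJobConfig": {"TargetAttributeName": "label"}}

sess.create_auto_ml_v2(
    input_config=input_config,
    job_name="example-automl-v2-job",
    problem_config=problem_config,
    output_config={"S3OutputPath": "s3://example-bucket/automl/output/"},
    role="arn:aws:iam::111122223333:role/ExampleSageMakerRole",
)

# Check status, then tail CloudWatch logs until the job finishes.
print(sess.describe_auto_ml_job_v2("example-automl-v2-job")["AutoMLJobStatus"])
sess.logs_for_auto_ml_job_v2("example-automl-v2-job", wait=True)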