Skip to content

Commit 7878abd

Browse files
Jaeyeon AhnEC2 Default UsercbechirJaeyeon
authored
feature: support for Sample Weights for SageMaker Autopilot (#3835)
Co-authored-by: EC2 Default User <[email protected]> Co-authored-by: Choucri Bechir <[email protected]> Co-authored-by: Jaeyeon <[email protected]>
1 parent af26174 commit 7878abd

File tree

5 files changed

+46
-3
lines changed

5 files changed

+46
-3
lines changed

src/sagemaker/automl/automl.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def __init__(
4949
channel_type=None,
5050
content_type=None,
5151
s3_data_type=None,
52+
sample_weight_attribute_name=None,
5253
):
5354
"""Convert an S3 Uri or a list of S3 Uri to an AutoMLInput object.
5455
@@ -67,13 +68,16 @@ def __init__(
6768
The content type of the data from the input source.
6869
s3_data_type (str, PipelineVariable): The data type for S3 data source.
6970
Valid values: ManifestFile or S3Prefix.
71+
sample_weight_attribute_name (str, PipelineVariable):
72+
the name of the dataset column representing sample weights
7073
"""
7174
self.inputs = inputs
7275
self.target_attribute_name = target_attribute_name
7376
self.compression = compression
7477
self.channel_type = channel_type
7578
self.content_type = content_type
7679
self.s3_data_type = s3_data_type
80+
self.sample_weight_attribute_name = sample_weight_attribute_name
7781

7882
def to_request_dict(self):
7983
"""Generates a request dictionary using the parameters provided to the class."""
@@ -96,6 +100,8 @@ def to_request_dict(self):
96100
input_entry["ContentType"] = self.content_type
97101
if self.s3_data_type is not None:
98102
input_entry["DataSource"]["S3DataSource"]["S3DataType"] = self.s3_data_type
103+
if self.sample_weight_attribute_name is not None:
104+
input_entry["SampleWeightAttributeName"] = self.sample_weight_attribute_name
99105
auto_ml_input.append(input_entry)
100106
return auto_ml_input
101107

@@ -129,6 +135,7 @@ def __init__(
129135
mode: Optional[str] = None,
130136
auto_generate_endpoint_name: Optional[bool] = None,
131137
endpoint_name: Optional[str] = None,
138+
sample_weight_attribute_name: str = None,
132139
):
133140
"""Initialize the an AutoML object.
134141
@@ -179,6 +186,8 @@ def __init__(
179186
model deployment if the endpoint name is not generated automatically.
180187
Specify the endpoint_name if and only if
181188
auto_generate_endpoint_name is set to False
189+
sample_weight_attribute_name (str): The name of dataset column representing
190+
sample weights.
182191
183192
Returns:
184193
AutoML object.
@@ -234,6 +243,7 @@ def __init__(
234243
)
235244

236245
self._check_problem_type_and_job_objective(self.problem_type, self.job_objective)
246+
self.sample_weight_attribute_name = sample_weight_attribute_name
237247

238248
@runnable_by_pipeline
239249
def fit(self, inputs=None, wait=True, logs=True, job_name=None):
@@ -342,6 +352,9 @@ def attach(cls, auto_ml_job_name, sagemaker_session=None):
342352
"AutoGenerateEndpointName", False
343353
),
344354
endpoint_name=auto_ml_job_desc.get("ModelDeployConfig", {}).get("EndpointName"),
355+
sample_weight_attribute_name=auto_ml_job_desc["InputDataConfig"][0].get(
356+
"SampleWeightAttributeName", None
357+
),
345358
)
346359
amlj.current_job_name = auto_ml_job_name
347360
amlj.latest_auto_ml_job = auto_ml_job_name # pylint: disable=W0201
@@ -867,6 +880,7 @@ def _load_config(cls, inputs, auto_ml, expand_role=True, validate_uri=True):
867880
auto_ml.target_attribute_name,
868881
auto_ml.content_type,
869882
auto_ml.s3_data_type,
883+
auto_ml.sample_weight_attribute_name,
870884
)
871885
output_config = _Job._prepare_output_config(auto_ml.output_path, auto_ml.output_kms_key)
872886

@@ -932,6 +946,7 @@ def _format_inputs_to_input_config(
932946
target_attribute_name=None,
933947
content_type=None,
934948
s3_data_type=None,
949+
sample_weight_attribute_name=None,
935950
):
936951
"""Convert inputs to AutoML InputDataConfig.
937952
@@ -961,6 +976,8 @@ def _format_inputs_to_input_config(
961976
channel["ContentType"] = content_type
962977
if s3_data_type is not None:
963978
channel["DataSource"]["S3DataSource"]["S3DataType"] = s3_data_type
979+
if sample_weight_attribute_name is not None:
980+
channel["SampleWeightAttributeName"] = sample_weight_attribute_name
964981
channels.append(channel)
965982
elif isinstance(inputs, list):
966983
for input_entry in inputs:
@@ -974,6 +991,8 @@ def _format_inputs_to_input_config(
974991
channel["ContentType"] = content_type
975992
if s3_data_type is not None:
976993
channel["DataSource"]["S3DataSource"]["S3DataType"] = s3_data_type
994+
if sample_weight_attribute_name is not None:
995+
channel["SampleWeightAttributeName"] = sample_weight_attribute_name
977996
channels.append(channel)
978997
else:
979998
msg = (

src/sagemaker/session.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2100,7 +2100,8 @@ def auto_ml(
21002100
21012101
Args:
21022102
input_config (list[dict]): A list of Channel objects. Each channel contains "DataSource"
2103-
and "TargetAttributeName", "CompressionType" is an optional field.
2103+
and "TargetAttributeName", "CompressionType" and "SampleWeightAttributeName" are
2104+
optional fields.
21042105
output_config (dict): The S3 URI where you want to store the training results and
21052106
optional KMS key ID.
21062107
auto_ml_job_config (dict): A dict of AutoMLJob config, containing "StoppingCondition",
@@ -2167,7 +2168,8 @@ def _get_auto_ml_request(
21672168
21682169
Args:
21692170
input_config (list[dict]): A list of Channel objects. Each channel contains "DataSource"
2170-
and "TargetAttributeName", "CompressionType" is an optional field.
2171+
and "TargetAttributeName", "CompressionType" and "SampleWeightAttributeName" are
2172+
optional fields.
21712173
output_config (dict): The S3 URI where you want to store the training results and
21722174
optional KMS key ID.
21732175
auto_ml_job_config (dict): A dict of AutoMLJob config, containing "StoppingCondition",

tests/unit/sagemaker/automl/test_auto_ml.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
RESOURCE_POOLS = [{"InstanceType": INSTANCE_TYPE, "PoolSize": INSTANCE_COUNT}]
3535
ROLE = "DummyRole"
3636
TARGET_ATTRIBUTE_NAME = "target"
37+
SAMPLE_WEIGHT_ATTRIBUTE_NAME = "sampleWeight"
3738
REGION = "us-west-2"
3839
DEFAULT_S3_INPUT_DATA = "s3://{}/data".format(BUCKET_NAME)
3940
DEFAULT_S3_VALIDATION_DATA = "s3://{}/validation_data".format(BUCKET_NAME)
@@ -103,6 +104,7 @@
103104
}
104105
},
105106
"TargetAttributeName": "y",
107+
"SampleWeightAttributeName": "sampleWeight",
106108
},
107109
{
108110
"ChannelType": "validation",
@@ -115,6 +117,7 @@
115117
}
116118
},
117119
"TargetAttributeName": "y",
120+
"SampleWeightAttributeName": "sampleWeight",
118121
},
119122
],
120123
"OutputDataConfig": {"KmsKeyId": "string", "S3OutputPath": "s3://output_prefix"},
@@ -362,10 +365,12 @@ def test_auto_ml_validation_channel_name(sagemaker_session):
362365
target_attribute_name="target",
363366
compression="Gzip",
364367
channel_type="training",
368+
sample_weight_attribute_name="sampleWeight",
365369
)
366370
input_validation = AutoMLInput(
367371
inputs=DEFAULT_S3_VALIDATION_DATA,
368372
target_attribute_name="target",
373+
sample_weight_attribute_name="sampleWeight",
369374
compression="Gzip",
370375
channel_type="validation",
371376
)
@@ -384,6 +389,7 @@ def test_auto_ml_validation_channel_name(sagemaker_session):
384389
}
385390
},
386391
"TargetAttributeName": TARGET_ATTRIBUTE_NAME,
392+
"SampleWeightAttributeName": SAMPLE_WEIGHT_ATTRIBUTE_NAME,
387393
},
388394
{
389395
"ChannelType": "validation",
@@ -395,6 +401,7 @@ def test_auto_ml_validation_channel_name(sagemaker_session):
395401
}
396402
},
397403
"TargetAttributeName": TARGET_ATTRIBUTE_NAME,
404+
"SampleWeightAttributeName": SAMPLE_WEIGHT_ATTRIBUTE_NAME,
398405
},
399406
]
400407

@@ -617,7 +624,10 @@ def test_auto_ml_local_input(sagemaker_session):
617624

618625
def test_auto_ml_input(sagemaker_session):
619626
inputs = AutoMLInput(
620-
inputs=DEFAULT_S3_INPUT_DATA, target_attribute_name="target", compression="Gzip"
627+
inputs=DEFAULT_S3_INPUT_DATA,
628+
target_attribute_name="target",
629+
compression="Gzip",
630+
sample_weight_attribute_name="sampleWeight",
621631
)
622632
auto_ml = AutoML(
623633
role=ROLE,
@@ -636,6 +646,7 @@ def test_auto_ml_input(sagemaker_session):
636646
}
637647
},
638648
"TargetAttributeName": TARGET_ATTRIBUTE_NAME,
649+
"SampleWeightAttributeName": SAMPLE_WEIGHT_ATTRIBUTE_NAME,
639650
}
640651
]
641652

tests/unit/sagemaker/workflow/test_automl_step.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,14 @@ def test_single_automl_step(pipeline_session):
5656
target_attribute_name="target",
5757
compression="Gzip",
5858
channel_type="training",
59+
sample_weight_attribute_name="sampleWeight",
5960
)
6061
input_validation = AutoMLInput(
6162
inputs="s3://bucket/validation_data",
6263
target_attribute_name="target",
6364
compression="Gzip",
6465
channel_type="validation",
66+
sample_weight_attribute_name="sampleWeight",
6567
)
6668
inputs = [input_training, input_validation]
6769

@@ -114,6 +116,7 @@ def test_single_automl_step(pipeline_session):
114116
}
115117
},
116118
"TargetAttributeName": "target",
119+
"SampleWeightAttributeName": "sampleWeight",
117120
},
118121
{
119122
"ChannelType": "validation",
@@ -125,6 +128,7 @@ def test_single_automl_step(pipeline_session):
125128
}
126129
},
127130
"TargetAttributeName": "target",
131+
"SampleWeightAttributeName": "sampleWeight",
128132
},
129133
],
130134
"OutputDataConfig": {

tests/unit/test_session.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3175,6 +3175,7 @@ def test_wait_until_fail_access_denied_after_5_mins(patched_sleep):
31753175
{
31763176
"DataSource": {"S3DataSource": {"S3DataType": "S3Prefix", "S3Uri": S3_INPUT_URI}},
31773177
"TargetAttributeName": "y",
3178+
"SampleWeightAttributeName": "sampleWeight",
31783179
}
31793180
],
31803181
"OutputDataConfig": {"S3OutputPath": S3_OUTPUT},
@@ -3202,6 +3203,7 @@ def test_wait_until_fail_access_denied_after_5_mins(patched_sleep):
32023203
}
32033204
},
32043205
"TargetAttributeName": "y",
3206+
"SampleWeightAttributeName": "sampleWeight",
32053207
},
32063208
{
32073209
"ChannelType": "validation",
@@ -3213,6 +3215,7 @@ def test_wait_until_fail_access_denied_after_5_mins(patched_sleep):
32133215
}
32143216
},
32153217
"TargetAttributeName": "y",
3218+
"SampleWeightAttributeName": "sampleWeight",
32163219
},
32173220
],
32183221
"OutputDataConfig": {"S3OutputPath": S3_OUTPUT},
@@ -3254,6 +3257,7 @@ def test_auto_ml_pack_to_request(sagemaker_session):
32543257
{
32553258
"DataSource": {"S3DataSource": {"S3DataType": "S3Prefix", "S3Uri": S3_INPUT_URI}},
32563259
"TargetAttributeName": "y",
3260+
"SampleWeightAttributeName": "sampleWeight",
32573261
}
32583262
]
32593263

@@ -3288,6 +3292,7 @@ def test_create_auto_ml_with_sagemaker_config_injection(sagemaker_session):
32883292
{
32893293
"DataSource": {"S3DataSource": {"S3DataType": "S3Prefix", "S3Uri": S3_INPUT_URI}},
32903294
"TargetAttributeName": "y",
3295+
"SampleWeightAttributeName": "sampleWeight",
32913296
}
32923297
]
32933298

@@ -3350,6 +3355,7 @@ def test_auto_ml_pack_to_request_with_optional_args(sagemaker_session):
33503355
}
33513356
},
33523357
"TargetAttributeName": "y",
3358+
"SampleWeightAttributeName": "sampleWeight",
33533359
},
33543360
{
33553361
"ChannelType": "validation",
@@ -3361,6 +3367,7 @@ def test_auto_ml_pack_to_request_with_optional_args(sagemaker_session):
33613367
}
33623368
},
33633369
"TargetAttributeName": "y",
3370+
"SampleWeightAttributeName": "sampleWeight",
33643371
},
33653372
]
33663373

0 commit comments

Comments
 (0)