|
13 | 13 | from __future__ import absolute_import
|
14 | 14 |
|
15 | 15 | import pytest
|
16 |
| -from mock import Mock |
| 16 | +from mock import Mock, patch |
17 | 17 | from sagemaker import AutoML, AutoMLJob, AutoMLInput, CandidateEstimator
|
18 | 18 |
|
19 | 19 | MODEL_DATA = "s3://bucket/model.tar.gz"
|
20 | 20 | MODEL_IMAGE = "mi"
|
21 | 21 | ENTRY_POINT = "blah.py"
|
22 | 22 |
|
| 23 | +TIMESTAMP = "2017-11-06-14:14:15.671" |
23 | 24 | BUCKET_NAME = "mybucket"
|
24 | 25 | INSTANCE_COUNT = 1
|
25 | 26 | INSTANCE_TYPE = "ml.c5.2xlarge"
|
|
31 | 32 | DEFAULT_OUTPUT_PATH = "s3://{}/".format(BUCKET_NAME)
|
32 | 33 | LOCAL_DATA_PATH = "file://data"
|
33 | 34 | DEFAULT_MAX_CANDIDATES = 500
|
| 35 | +DEFAULT_JOB_NAME = "automl-{}".format(TIMESTAMP) |
34 | 36 |
|
35 | 37 | JOB_NAME = "default-job-name"
|
36 | 38 | JOB_NAME_2 = "banana-auto-ml-job"
|
@@ -281,34 +283,38 @@ def test_auto_ml_additional_optional_params(sagemaker_session):
|
281 | 283 | }
|
282 | 284 |
|
283 | 285 |
|
284 |
| -def test_auto_ml_default_fit(sagemaker_session): |
| 286 | +@patch("time.strftime", return_value=TIMESTAMP) |
| 287 | +def test_auto_ml_default_fit(strftime, sagemaker_session): |
285 | 288 | auto_ml = AutoML(
|
286 | 289 | role=ROLE, target_attribute_name=TARGET_ATTRIBUTE_NAME, sagemaker_session=sagemaker_session
|
287 | 290 | )
|
288 | 291 | inputs = DEFAULT_S3_INPUT_DATA
|
289 | 292 | auto_ml.fit(inputs)
|
290 | 293 | sagemaker_session.auto_ml.assert_called_once()
|
291 | 294 | _, args = sagemaker_session.auto_ml.call_args
|
292 |
| - assert args["input_config"] == [ |
293 |
| - { |
294 |
| - "DataSource": { |
295 |
| - "S3DataSource": {"S3DataType": "S3Prefix", "S3Uri": DEFAULT_S3_INPUT_DATA} |
| 295 | + assert args == { |
| 296 | + "input_config": [ |
| 297 | + { |
| 298 | + "DataSource": { |
| 299 | + "S3DataSource": {"S3DataType": "S3Prefix", "S3Uri": DEFAULT_S3_INPUT_DATA} |
| 300 | + }, |
| 301 | + "TargetAttributeName": TARGET_ATTRIBUTE_NAME, |
| 302 | + } |
| 303 | + ], |
| 304 | + "output_config": {"S3OutputPath": DEFAULT_OUTPUT_PATH}, |
| 305 | + "auto_ml_job_config": { |
| 306 | + "CompletionCriteria": {"MaxCandidates": DEFAULT_MAX_CANDIDATES}, |
| 307 | + "SecurityConfig": { |
| 308 | + "EnableInterContainerTrafficEncryption": ENCRYPT_INTER_CONTAINER_TRAFFIC |
296 | 309 | },
|
297 |
| - "TargetAttributeName": TARGET_ATTRIBUTE_NAME, |
298 |
| - } |
299 |
| - ] |
300 |
| - assert args["output_config"] == {"S3OutputPath": DEFAULT_OUTPUT_PATH} |
301 |
| - assert args["auto_ml_job_config"] == { |
302 |
| - "CompletionCriteria": {"MaxCandidates": DEFAULT_MAX_CANDIDATES}, |
303 |
| - "SecurityConfig": { |
304 |
| - "EnableInterContainerTrafficEncryption": ENCRYPT_INTER_CONTAINER_TRAFFIC |
305 | 310 | },
|
| 311 | + "role": ROLE, |
| 312 | + "job_name": DEFAULT_JOB_NAME, |
| 313 | + "problem_type": None, |
| 314 | + "job_objective": None, |
| 315 | + "generate_candidate_definitions_only": GENERATE_CANDIDATE_DEFINITIONS_ONLY, |
| 316 | + "tags": None, |
306 | 317 | }
|
307 |
| - assert args["role"] == ROLE |
308 |
| - assert args["problem_type"] is None |
309 |
| - assert args["job_objective"] is None |
310 |
| - assert args["generate_candidate_definitions_only"] == GENERATE_CANDIDATE_DEFINITIONS_ONLY |
311 |
| - assert args["tags"] is None |
312 | 318 |
|
313 | 319 |
|
314 | 320 | def test_auto_ml_local_input(sagemaker_session):
|
|
0 commit comments