@@ -18,8 +18,12 @@
 import pytest
 import numpy as np
 from airflow import DAG
-from airflow.contrib.operators.sagemaker_training_operator import SageMakerTrainingOperator
-from airflow.contrib.operators.sagemaker_transform_operator import SageMakerTransformOperator
+from airflow.contrib.operators.sagemaker_training_operator import (
+    SageMakerTrainingOperator,
+)
+from airflow.contrib.operators.sagemaker_transform_operator import (
+    SageMakerTransformOperator,
+)
 from six.moves.urllib.parse import urlparse

 import tests.integ
@@ -69,7 +73,8 @@ def test_byo_airflow_config_uploads_data_source_to_s3_when_inputs_provided(

         data_source_location = "test-airflow-config-{}".format(sagemaker_timestamp())
         inputs = sagemaker_session.upload_data(
-            path=training_data_path, key_prefix=os.path.join(data_source_location, "train")
+            path=training_data_path,
+            key_prefix=os.path.join(data_source_location, "train"),
         )

         estimator = Estimator(
@@ -88,12 +93,16 @@ def test_byo_airflow_config_uploads_data_source_to_s3_when_inputs_provided(

         _assert_that_s3_url_contains_data(
             sagemaker_session,
-            training_config["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"],
+            training_config["InputDataConfig"][0]["DataSource"]["S3DataSource"][
+                "S3Uri"
+            ],
         )


 @pytest.mark.canary_quick
-def test_kmeans_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_instance_type):
+def test_kmeans_airflow_config_uploads_data_source_to_s3(
+    sagemaker_session, cpu_instance_type
+):
     with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS):
         kmeans = KMeans(
             role=ROLE,
@@ -121,11 +130,15 @@ def test_kmeans_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_

         _assert_that_s3_url_contains_data(
             sagemaker_session,
-            training_config["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"],
+            training_config["InputDataConfig"][0]["DataSource"]["S3DataSource"][
+                "S3Uri"
+            ],
         )


-def test_fm_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_instance_type):
+def test_fm_airflow_config_uploads_data_source_to_s3(
+    sagemaker_session, cpu_instance_type
+):
     with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS):
         fm = FactorizationMachines(
             role=ROLE,
@@ -141,20 +154,26 @@ def test_fm_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_inst
         )

         training_set = datasets.one_p_mnist()
-        records = fm.record_set(training_set[0][:200], training_set[1][:200].astype("float32"))
+        records = fm.record_set(
+            training_set[0][:200], training_set[1][:200].astype("float32")
+        )

         training_config = _build_airflow_workflow(
             estimator=fm, instance_type=cpu_instance_type, inputs=records
         )

         _assert_that_s3_url_contains_data(
             sagemaker_session,
-            training_config["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"],
+            training_config["InputDataConfig"][0]["DataSource"]["S3DataSource"][
+                "S3Uri"
+            ],
         )


 @pytest.mark.canary_quick
-def test_ipinsights_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_instance_type):
+def test_ipinsights_airflow_config_uploads_data_source_to_s3(
+    sagemaker_session, cpu_instance_type
+):
     with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS):
         data_path = os.path.join(DATA_DIR, "ipinsights")
         data_filename = "train.csv"
@@ -181,11 +200,15 @@ def test_ipinsights_airflow_config_uploads_data_source_to_s3(sagemaker_session,

         _assert_that_s3_url_contains_data(
             sagemaker_session,
-            training_config["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"],
+            training_config["InputDataConfig"][0]["DataSource"]["S3DataSource"][
+                "S3Uri"
+            ],
         )


-def test_knn_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_instance_type):
+def test_knn_airflow_config_uploads_data_source_to_s3(
+    sagemaker_session, cpu_instance_type
+):
     with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS):
         knn = KNN(
             role=ROLE,
@@ -198,15 +221,19 @@ def test_knn_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_ins
         )

         training_set = datasets.one_p_mnist()
-        records = knn.record_set(training_set[0][:200], training_set[1][:200].astype("float32"))
+        records = knn.record_set(
+            training_set[0][:200], training_set[1][:200].astype("float32")
+        )

         training_config = _build_airflow_workflow(
             estimator=knn, instance_type=cpu_instance_type, inputs=records
         )

         _assert_that_s3_url_contains_data(
             sagemaker_session,
-            training_config["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"],
+            training_config["InputDataConfig"][0]["DataSource"]["S3DataSource"][
+                "S3Uri"
+            ],
         )


@@ -215,7 +242,9 @@ def test_knn_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_ins
     reason="LDA image is not supported in certain regions",
 )
 @pytest.mark.canary_quick
-def test_lda_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_instance_type):
+def test_lda_airflow_config_uploads_data_source_to_s3(
+    sagemaker_session, cpu_instance_type
+):
     with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS):
         data_path = os.path.join(DATA_DIR, "lda")
         data_filename = "nips-train_1.pbr"
@@ -234,16 +263,25 @@ def test_lda_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_ins
         )

         records = prepare_record_set_from_local_files(
-            data_path, lda.data_location, len(all_records), feature_num, sagemaker_session
+            data_path,
+            lda.data_location,
+            len(all_records),
+            feature_num,
+            sagemaker_session,
         )

         training_config = _build_airflow_workflow(
-            estimator=lda, instance_type=cpu_instance_type, inputs=records, mini_batch_size=100
+            estimator=lda,
+            instance_type=cpu_instance_type,
+            inputs=records,
+            mini_batch_size=100,
         )

         _assert_that_s3_url_contains_data(
             sagemaker_session,
-            training_config["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"],
+            training_config["InputDataConfig"][0]["DataSource"]["S3DataSource"][
+                "S3Uri"
+            ],
         )


@@ -308,12 +346,16 @@ def test_linearlearner_airflow_config_uploads_data_source_to_s3(

         _assert_that_s3_url_contains_data(
             sagemaker_session,
-            training_config["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"],
+            training_config["InputDataConfig"][0]["DataSource"]["S3DataSource"][
+                "S3Uri"
+            ],
         )


 @pytest.mark.canary_quick
-def test_ntm_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_instance_type):
+def test_ntm_airflow_config_uploads_data_source_to_s3(
+    sagemaker_session, cpu_instance_type
+):
     with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS):
         data_path = os.path.join(DATA_DIR, "ntm")
         data_filename = "nips-train_1.pbr"
@@ -333,7 +375,11 @@ def test_ntm_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_ins
         )

         records = prepare_record_set_from_local_files(
-            data_path, ntm.data_location, len(all_records), feature_num, sagemaker_session
+            data_path,
+            ntm.data_location,
+            len(all_records),
+            feature_num,
+            sagemaker_session,
         )

         training_config = _build_airflow_workflow(
@@ -342,12 +388,16 @@ def test_ntm_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_ins

         _assert_that_s3_url_contains_data(
             sagemaker_session,
-            training_config["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"],
+            training_config["InputDataConfig"][0]["DataSource"]["S3DataSource"][
+                "S3Uri"
+            ],
         )


 @pytest.mark.canary_quick
-def test_pca_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_instance_type):
+def test_pca_airflow_config_uploads_data_source_to_s3(
+    sagemaker_session, cpu_instance_type
+):
     with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS):
         pca = PCA(
             role=ROLE,
@@ -369,12 +419,16 @@ def test_pca_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_ins

         _assert_that_s3_url_contains_data(
             sagemaker_session,
-            training_config["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"],
+            training_config["InputDataConfig"][0]["DataSource"]["S3DataSource"][
+                "S3Uri"
+            ],
         )


 @pytest.mark.canary_quick
-def test_rcf_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_instance_type):
+def test_rcf_airflow_config_uploads_data_source_to_s3(
+    sagemaker_session, cpu_instance_type
+):
     with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS):
         # Generate a thousand 14-dimensional datapoints.
         feature_num = 14
@@ -398,13 +452,18 @@ def test_rcf_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_ins

         _assert_that_s3_url_contains_data(
             sagemaker_session,
-            training_config["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"],
+            training_config["InputDataConfig"][0]["DataSource"]["S3DataSource"][
+                "S3Uri"
+            ],
         )


 @pytest.mark.canary_quick
 def test_chainer_airflow_config_uploads_data_source_to_s3(
-    sagemaker_local_session, cpu_instance_type, chainer_latest_version, chainer_latest_py_version
+    sagemaker_local_session,
+    cpu_instance_type,
+    chainer_latest_version,
+    chainer_latest_py_version,
 ):
     with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS):
         script_path = os.path.join(DATA_DIR, "chainer_mnist", "mnist.py")
@@ -498,10 +557,12 @@ def test_sklearn_airflow_config_uploads_data_source_to_s3(
         )

         train_input = sklearn.sagemaker_session.upload_data(
-            path=os.path.join(data_path, "train"), key_prefix="integ-test-data/sklearn_mnist/train"
+            path=os.path.join(data_path, "train"),
+            key_prefix="integ-test-data/sklearn_mnist/train",
         )
         test_input = sklearn.sagemaker_session.upload_data(
-            path=os.path.join(data_path, "test"), key_prefix="integ-test-data/sklearn_mnist/test"
+            path=os.path.join(data_path, "test"),
+            key_prefix="integ-test-data/sklearn_mnist/test",
         )

         training_config = _build_airflow_workflow(
@@ -537,7 +598,8 @@ def test_tf_airflow_config_uploads_data_source_to_s3(
             ],
         )
         inputs = tf.sagemaker_session.upload_data(
-            path=os.path.join(TF_MNIST_RESOURCE_PATH, "data"), key_prefix="scriptmode/mnist"
+            path=os.path.join(TF_MNIST_RESOURCE_PATH, "data"),
+            key_prefix="scriptmode/mnist",
         )

         training_config = _build_airflow_workflow(
@@ -552,13 +614,13 @@ def test_tf_airflow_config_uploads_data_source_to_s3(

 @pytest.mark.canary_quick
 def test_xgboost_airflow_config_uploads_data_source_to_s3(
-    sagemaker_session, cpu_instance_type, xgboost_latest_version, xgboost_latest_py_version
+    sagemaker_session, cpu_instance_type, xgboost_latest_version
 ):
     with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS):
         xgboost = XGBoost(
             entry_point=os.path.join(DATA_DIR, "dummy_script.py"),
             framework_version=xgboost_latest_version,
-            py_version=xgboost_latest_py_version,
+            py_version="py3",
             role=ROLE,
             sagemaker_session=sagemaker_session,
             instance_type=cpu_instance_type,
@@ -613,7 +675,9 @@ def _assert_that_s3_url_contains_data(sagemaker_session, s3_url):
     assert s3_request["KeyCount"] > 0


-def _build_airflow_workflow(estimator, instance_type, inputs=None, mini_batch_size=None):
+def _build_airflow_workflow(
+    estimator, instance_type, inputs=None, mini_batch_size=None
+):
     training_config = sm_airflow.training_config(
         estimator=estimator, inputs=inputs, mini_batch_size=mini_batch_size
     )
@@ -642,14 +706,19 @@ def _build_airflow_workflow(estimator, instance_type, inputs=None, mini_batch_si
         "provide_context": True,
     }

-    dag = DAG("tensorflow_example", default_args=default_args, schedule_interval="@once")
+    dag = DAG(
+        "tensorflow_example", default_args=default_args, schedule_interval="@once"
+    )

     train_op = SageMakerTrainingOperator(
         task_id="tf_training", config=training_config, wait_for_completion=True, dag=dag
     )

     transform_op = SageMakerTransformOperator(
-        task_id="transform_operator", config=transform_config, wait_for_completion=True, dag=dag
+        task_id="transform_operator",
+        config=transform_config,
+        wait_for_completion=True,
+        dag=dag,
     )

     transform_op.set_upstream(train_op)
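
Note: only the final assertion of _assert_that_s3_url_contains_data appears in
the hunks above. A minimal sketch of the full helper, assuming it parses the S3
URL with the already-imported urlparse and lists the prefix through the
session's boto3 client (the exact body is not part of this diff):

    def _assert_that_s3_url_contains_data(sagemaker_session, s3_url):
        # Split "s3://bucket/prefix" into bucket and key prefix.
        parsed_url = urlparse(s3_url)
        s3_request = sagemaker_session.boto_session.client("s3").list_objects_v2(
            Bucket=parsed_url.netloc, Prefix=parsed_url.path.lstrip("/")
        )
        # A non-empty listing means the training input actually landed in S3.
        assert s3_request["KeyCount"] > 0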