Commit 9141233

Remove xgboost_latest_py_version fixture from airflow config test

1 parent e7bfaa1 commit 9141233
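
As the diff below shows, the test now passes py_version="py3" literally instead of taking it from the xgboost_latest_py_version fixture, which is removed from the test's signature. A minimal sketch of the resulting estimator construction, abridged to the arguments visible in this diff (DATA_DIR, ROLE, sagemaker_session, and the remaining fixture values come from the test module):

    import os

    from sagemaker.xgboost import XGBoost

    xgboost = XGBoost(
        entry_point=os.path.join(DATA_DIR, "dummy_script.py"),
        framework_version=xgboost_latest_version,  # still fixture-provided
        py_version="py3",  # hardcoded; previously xgboost_latest_py_version
        role=ROLE,
        sagemaker_session=sagemaker_session,
        instance_type=cpu_instance_type,
    )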

File tree

1 file changed: +104 -35 lines changed


tests/integ/test_airflow_config.py

Lines changed: 104 additions & 35 deletions
@@ -18,8 +18,12 @@
 import pytest
 import numpy as np
 from airflow import DAG
-from airflow.contrib.operators.sagemaker_training_operator import SageMakerTrainingOperator
-from airflow.contrib.operators.sagemaker_transform_operator import SageMakerTransformOperator
+from airflow.contrib.operators.sagemaker_training_operator import (
+    SageMakerTrainingOperator,
+)
+from airflow.contrib.operators.sagemaker_transform_operator import (
+    SageMakerTransformOperator,
+)
 from six.moves.urllib.parse import urlparse

 import tests.integ
@@ -69,7 +73,8 @@ def test_byo_airflow_config_uploads_data_source_to_s3_when_inputs_provided(

         data_source_location = "test-airflow-config-{}".format(sagemaker_timestamp())
         inputs = sagemaker_session.upload_data(
-            path=training_data_path, key_prefix=os.path.join(data_source_location, "train")
+            path=training_data_path,
+            key_prefix=os.path.join(data_source_location, "train"),
         )

         estimator = Estimator(
@@ -88,12 +93,16 @@ def test_byo_airflow_config_uploads_data_source_to_s3_when_inputs_provided(

         _assert_that_s3_url_contains_data(
             sagemaker_session,
-            training_config["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"],
+            training_config["InputDataConfig"][0]["DataSource"]["S3DataSource"][
+                "S3Uri"
+            ],
         )


 @pytest.mark.canary_quick
-def test_kmeans_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_instance_type):
+def test_kmeans_airflow_config_uploads_data_source_to_s3(
+    sagemaker_session, cpu_instance_type
+):
     with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS):
         kmeans = KMeans(
             role=ROLE,
@@ -121,11 +130,15 @@ def test_kmeans_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_

         _assert_that_s3_url_contains_data(
             sagemaker_session,
-            training_config["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"],
+            training_config["InputDataConfig"][0]["DataSource"]["S3DataSource"][
+                "S3Uri"
+            ],
         )


-def test_fm_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_instance_type):
+def test_fm_airflow_config_uploads_data_source_to_s3(
+    sagemaker_session, cpu_instance_type
+):
     with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS):
         fm = FactorizationMachines(
             role=ROLE,
@@ -141,20 +154,26 @@ def test_fm_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_inst
         )

         training_set = datasets.one_p_mnist()
-        records = fm.record_set(training_set[0][:200], training_set[1][:200].astype("float32"))
+        records = fm.record_set(
+            training_set[0][:200], training_set[1][:200].astype("float32")
+        )

         training_config = _build_airflow_workflow(
             estimator=fm, instance_type=cpu_instance_type, inputs=records
         )

         _assert_that_s3_url_contains_data(
             sagemaker_session,
-            training_config["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"],
+            training_config["InputDataConfig"][0]["DataSource"]["S3DataSource"][
+                "S3Uri"
+            ],
         )


 @pytest.mark.canary_quick
-def test_ipinsights_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_instance_type):
+def test_ipinsights_airflow_config_uploads_data_source_to_s3(
+    sagemaker_session, cpu_instance_type
+):
     with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS):
         data_path = os.path.join(DATA_DIR, "ipinsights")
         data_filename = "train.csv"
@@ -181,11 +200,15 @@ def test_ipinsights_airflow_config_uploads_data_source_to_s3(sagemaker_session,

         _assert_that_s3_url_contains_data(
             sagemaker_session,
-            training_config["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"],
+            training_config["InputDataConfig"][0]["DataSource"]["S3DataSource"][
+                "S3Uri"
+            ],
         )


-def test_knn_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_instance_type):
+def test_knn_airflow_config_uploads_data_source_to_s3(
+    sagemaker_session, cpu_instance_type
+):
     with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS):
         knn = KNN(
             role=ROLE,
@@ -198,15 +221,19 @@ def test_knn_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_ins
         )

         training_set = datasets.one_p_mnist()
-        records = knn.record_set(training_set[0][:200], training_set[1][:200].astype("float32"))
+        records = knn.record_set(
+            training_set[0][:200], training_set[1][:200].astype("float32")
+        )

         training_config = _build_airflow_workflow(
             estimator=knn, instance_type=cpu_instance_type, inputs=records
         )

         _assert_that_s3_url_contains_data(
             sagemaker_session,
-            training_config["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"],
+            training_config["InputDataConfig"][0]["DataSource"]["S3DataSource"][
+                "S3Uri"
+            ],
         )


@@ -215,7 +242,9 @@ def test_knn_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_ins
     reason="LDA image is not supported in certain regions",
 )
 @pytest.mark.canary_quick
-def test_lda_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_instance_type):
+def test_lda_airflow_config_uploads_data_source_to_s3(
+    sagemaker_session, cpu_instance_type
+):
     with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS):
         data_path = os.path.join(DATA_DIR, "lda")
         data_filename = "nips-train_1.pbr"
@@ -234,16 +263,25 @@ def test_lda_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_ins
         )

         records = prepare_record_set_from_local_files(
-            data_path, lda.data_location, len(all_records), feature_num, sagemaker_session
+            data_path,
+            lda.data_location,
+            len(all_records),
+            feature_num,
+            sagemaker_session,
         )

         training_config = _build_airflow_workflow(
-            estimator=lda, instance_type=cpu_instance_type, inputs=records, mini_batch_size=100
+            estimator=lda,
+            instance_type=cpu_instance_type,
+            inputs=records,
+            mini_batch_size=100,
         )

         _assert_that_s3_url_contains_data(
             sagemaker_session,
-            training_config["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"],
+            training_config["InputDataConfig"][0]["DataSource"]["S3DataSource"][
+                "S3Uri"
+            ],
         )


@@ -308,12 +346,16 @@ def test_linearlearner_airflow_config_uploads_data_source_to_s3(

         _assert_that_s3_url_contains_data(
             sagemaker_session,
-            training_config["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"],
+            training_config["InputDataConfig"][0]["DataSource"]["S3DataSource"][
+                "S3Uri"
+            ],
         )


 @pytest.mark.canary_quick
-def test_ntm_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_instance_type):
+def test_ntm_airflow_config_uploads_data_source_to_s3(
+    sagemaker_session, cpu_instance_type
+):
     with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS):
         data_path = os.path.join(DATA_DIR, "ntm")
         data_filename = "nips-train_1.pbr"
@@ -333,7 +375,11 @@ def test_ntm_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_ins
         )

         records = prepare_record_set_from_local_files(
-            data_path, ntm.data_location, len(all_records), feature_num, sagemaker_session
+            data_path,
+            ntm.data_location,
+            len(all_records),
+            feature_num,
+            sagemaker_session,
         )

         training_config = _build_airflow_workflow(
@@ -342,12 +388,16 @@ def test_ntm_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_ins

         _assert_that_s3_url_contains_data(
             sagemaker_session,
-            training_config["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"],
+            training_config["InputDataConfig"][0]["DataSource"]["S3DataSource"][
+                "S3Uri"
+            ],
         )


 @pytest.mark.canary_quick
-def test_pca_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_instance_type):
+def test_pca_airflow_config_uploads_data_source_to_s3(
+    sagemaker_session, cpu_instance_type
+):
     with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS):
         pca = PCA(
             role=ROLE,
@@ -369,12 +419,16 @@ def test_pca_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_ins

         _assert_that_s3_url_contains_data(
             sagemaker_session,
-            training_config["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"],
+            training_config["InputDataConfig"][0]["DataSource"]["S3DataSource"][
+                "S3Uri"
+            ],
         )


 @pytest.mark.canary_quick
-def test_rcf_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_instance_type):
+def test_rcf_airflow_config_uploads_data_source_to_s3(
+    sagemaker_session, cpu_instance_type
+):
     with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS):
         # Generate a thousand 14-dimensional datapoints.
         feature_num = 14
@@ -398,13 +452,18 @@ def test_rcf_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_ins

         _assert_that_s3_url_contains_data(
             sagemaker_session,
-            training_config["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"],
+            training_config["InputDataConfig"][0]["DataSource"]["S3DataSource"][
+                "S3Uri"
+            ],
         )


 @pytest.mark.canary_quick
 def test_chainer_airflow_config_uploads_data_source_to_s3(
-    sagemaker_local_session, cpu_instance_type, chainer_latest_version, chainer_latest_py_version
+    sagemaker_local_session,
+    cpu_instance_type,
+    chainer_latest_version,
+    chainer_latest_py_version,
 ):
     with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS):
         script_path = os.path.join(DATA_DIR, "chainer_mnist", "mnist.py")
@@ -498,10 +557,12 @@ def test_sklearn_airflow_config_uploads_data_source_to_s3(
         )

         train_input = sklearn.sagemaker_session.upload_data(
-            path=os.path.join(data_path, "train"), key_prefix="integ-test-data/sklearn_mnist/train"
+            path=os.path.join(data_path, "train"),
+            key_prefix="integ-test-data/sklearn_mnist/train",
         )
         test_input = sklearn.sagemaker_session.upload_data(
-            path=os.path.join(data_path, "test"), key_prefix="integ-test-data/sklearn_mnist/test"
+            path=os.path.join(data_path, "test"),
+            key_prefix="integ-test-data/sklearn_mnist/test",
         )

         training_config = _build_airflow_workflow(
@@ -537,7 +598,8 @@ def test_tf_airflow_config_uploads_data_source_to_s3(
            ],
        )
        inputs = tf.sagemaker_session.upload_data(
-            path=os.path.join(TF_MNIST_RESOURCE_PATH, "data"), key_prefix="scriptmode/mnist"
+            path=os.path.join(TF_MNIST_RESOURCE_PATH, "data"),
+            key_prefix="scriptmode/mnist",
        )

        training_config = _build_airflow_workflow(
@@ -552,13 +614,13 @@ def test_tf_airflow_config_uploads_data_source_to_s3(

 @pytest.mark.canary_quick
 def test_xgboost_airflow_config_uploads_data_source_to_s3(
-    sagemaker_session, cpu_instance_type, xgboost_latest_version, xgboost_latest_py_version
+    sagemaker_session, cpu_instance_type, xgboost_latest_version
 ):
     with timeout(seconds=AIRFLOW_CONFIG_TIMEOUT_IN_SECONDS):
         xgboost = XGBoost(
             entry_point=os.path.join(DATA_DIR, "dummy_script.py"),
             framework_version=xgboost_latest_version,
-            py_version=xgboost_latest_py_version,
+            py_version="py3",
             role=ROLE,
             sagemaker_session=sagemaker_session,
             instance_type=cpu_instance_type,
@@ -613,7 +675,9 @@ def _assert_that_s3_url_contains_data(sagemaker_session, s3_url):
     assert s3_request["KeyCount"] > 0


-def _build_airflow_workflow(estimator, instance_type, inputs=None, mini_batch_size=None):
+def _build_airflow_workflow(
+    estimator, instance_type, inputs=None, mini_batch_size=None
+):
     training_config = sm_airflow.training_config(
         estimator=estimator, inputs=inputs, mini_batch_size=mini_batch_size
     )
@@ -642,14 +706,19 @@ def _build_airflow_workflow(estimator, instance_type, inputs=None, mini_batch_si
         "provide_context": True,
     }

-    dag = DAG("tensorflow_example", default_args=default_args, schedule_interval="@once")
+    dag = DAG(
+        "tensorflow_example", default_args=default_args, schedule_interval="@once"
+    )

     train_op = SageMakerTrainingOperator(
         task_id="tf_training", config=training_config, wait_for_completion=True, dag=dag
     )

     transform_op = SageMakerTransformOperator(
-        task_id="transform_operator", config=transform_config, wait_for_completion=True, dag=dag
+        task_id="transform_operator",
+        config=transform_config,
+        wait_for_completion=True,
+        dag=dag,
     )

     transform_op.set_upstream(train_op)
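
Only the final assertion of _assert_that_s3_url_contains_data appears in this diff. A minimal sketch of how such a check can look, assuming the helper uses the urlparse import from the top of the file and boto3's list_objects_v2 via the session's S3 client (the actual helper body is not shown here):

    from six.moves.urllib.parse import urlparse


    def _assert_that_s3_url_contains_data(sagemaker_session, s3_url):
        # Split "s3://bucket/prefix" into bucket (netloc) and key prefix (path).
        parsed_url = urlparse(s3_url)
        s3_request = sagemaker_session.boto_session.client("s3").list_objects_v2(
            Bucket=parsed_url.netloc, Prefix=parsed_url.path.lstrip("/")
        )
        # The upload is verified by requiring at least one object under the prefix.
        assert s3_request["KeyCount"] > 0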
