Skip to content

Commit b29c1a9

Browse files
authored
Remove duplicated tests in zero code change tests for TF, and make them pytest compatible (aws#110)
* make pytest compatible * Remove tests copy * Make pytorch zcc test pytest compatible * Add init file to make .tf_utils import succeed * Fix test after increase in default save interval
1 parent bf05d6f commit b29c1a9

File tree

6 files changed

+52
-1079
lines changed

6 files changed

+52
-1079
lines changed

tests/zero_code_change/pytorch_integration_tests.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from smdebug.core.utils import SagemakerSimulator, ScriptSimulator
2222

2323

24-
def test_pytorch(script_mode: bool, use_loss_module=False):
24+
def test_pytorch(script_mode: bool = False, use_loss_module=False):
2525
smd.del_hook()
2626

2727
sim_class = ScriptSimulator if script_mode else SagemakerSimulator
@@ -82,6 +82,10 @@ def test_pytorch(script_mode: bool, use_loss_module=False):
8282
)
8383

8484

85+
def test_pytorch_loss_module(script_mode: bool = False):
86+
test_pytorch(script_mode=script_mode, use_loss_module=True)
87+
88+
8589
if __name__ == "__main__":
8690
parser = argparse.ArgumentParser()
8791
parser.add_argument(

tests/zero_code_change/tensorflow_integration_tests.py

Lines changed: 47 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,13 @@
2222
import tensorflow_datasets as tfds
2323
from tests.tensorflow.hooks.test_mirrored_strategy import test_basic
2424
from tests.tensorflow.keras.test_keras_mirrored import test_tf_keras
25-
from tf_utils import (
25+
26+
# First Party
27+
import smdebug.tensorflow as smd
28+
from smdebug.core.utils import SagemakerSimulator
29+
30+
# Local
31+
from .tf_utils import (
2632
get_data,
2733
get_estimator,
2834
get_input_fns,
@@ -31,12 +37,8 @@
3137
get_train_op_and_placeholders,
3238
)
3339

34-
# First Party
35-
import smdebug.tensorflow as smd
36-
from smdebug.core.utils import SagemakerSimulator
3740

38-
39-
def test_estimator(script_mode: bool):
41+
def test_estimator(script_mode: bool = False):
4042
""" Works as intended. """
4143
smd.del_hook()
4244
tf.reset_default_graph()
@@ -134,7 +136,15 @@ def test_estimator_gradients_zcc(nested=False, mirrored=False):
134136
assert len(trial.modes()) == 2
135137

136138

137-
def test_linear_classifier(script_mode: bool):
139+
def test_estimator_gradients_zcc_nested():
140+
test_estimator_gradients_zcc(nested=True)
141+
142+
143+
def test_estimator_gradients_zcc_mirrored():
144+
test_estimator_gradients_zcc(nested=False, mirrored=True)
145+
146+
147+
def test_linear_classifier(script_mode: bool = False):
138148
""" Works as intended. """
139149
smd.del_hook()
140150
tf.reset_default_graph()
@@ -160,11 +170,20 @@ def test_linear_classifier(script_mode: bool):
160170
assert len(trial.tensor_names()) > 0, "Tensors were not saved."
161171

162172

163-
def test_monitored_session(script_mode: bool):
173+
def test_monitored_session(script_mode: bool = False):
164174
""" Works as intended. """
165175
smd.del_hook()
166176
tf.reset_default_graph()
167-
with SagemakerSimulator() as sim:
177+
json_file_contents = """
178+
{
179+
"S3OutputPath": "s3://sagemaker-test",
180+
"LocalPath": "/opt/ml/output/tensors",
181+
"HookParameters" : {
182+
"save_interval": "100"
183+
}
184+
}
185+
"""
186+
with SagemakerSimulator(json_file_contents=json_file_contents) as sim:
168187
train_op, X, Y = get_train_op_and_placeholders()
169188
init = tf.compat.v1.global_variables_initializer()
170189
mnist = get_data()
@@ -195,6 +214,9 @@ def test_monitored_session_gradients_zcc():
195214
{
196215
"S3OutputPath": "s3://sagemaker-test",
197216
"LocalPath": "/opt/ml/output/tensors",
217+
"HookParameters" : {
218+
"save_interval": "100"
219+
},
198220
"CollectionConfigurations": [
199221
{
200222
"CollectionName": "gradients"
@@ -227,7 +249,7 @@ def test_monitored_session_gradients_zcc():
227249
assert len(trial.tensor_names(collection="gradients")) > 0
228250

229251

230-
def test_keras_v1(script_mode: bool):
252+
def test_keras_v1(script_mode: bool = False):
231253
""" Works as intended. """
232254
smd.del_hook()
233255
tf.reset_default_graph()
@@ -258,7 +280,7 @@ def test_keras_v1(script_mode: bool):
258280
assert len(trial.tensor_names()) > 0, "Tensors were not saved."
259281

260282

261-
def test_keras_gradients(script_mode: bool, tf_optimizer: bool = False):
283+
def test_keras_gradients(script_mode: bool = False, tf_optimizer: bool = False):
262284
""" Works as intended. """
263285
smd.del_hook()
264286
tf.reset_default_graph()
@@ -320,6 +342,10 @@ def test_keras_gradients(script_mode: bool, tf_optimizer: bool = False):
320342
assert len(trial.tensor_names(collection="optimizer_variables")) > 0
321343

322344

345+
def test_keras_gradients_tf_opt(script_mode: bool = False):
346+
test_keras_gradients(script_mode=script_mode, tf_optimizer=True)
347+
348+
323349
def test_keras_gradients_mirrored(include_workers="one"):
324350
""" Works as intended. """
325351
smd.del_hook()
@@ -366,7 +392,11 @@ def test_keras_gradients_mirrored(include_workers="one"):
366392
test_tf_keras("/opt/ml/output/tensors", zcc=True, include_workers=include_workers)
367393

368394

369-
def test_keras_to_estimator(script_mode: bool):
395+
def test_keras_gradients_mirrored_all_workers():
396+
test_keras_gradients_mirrored(include_workers="all")
397+
398+
399+
def test_keras_to_estimator(script_mode: bool = False):
370400
""" Works as intended. """
371401
import tensorflow.compat.v1.keras as keras
372402

@@ -426,14 +456,14 @@ def input_fn():
426456
test_monitored_session_gradients_zcc()
427457
test_estimator(script_mode=script_mode)
428458
if not script_mode:
429-
test_estimator_gradients_zcc(nested=True)
430-
test_estimator_gradients_zcc(nested=False)
431-
test_estimator_gradients_zcc(nested=False, mirrored=True)
459+
test_estimator_gradients_zcc()
460+
test_estimator_gradients_zcc_nested()
461+
test_estimator_gradients_zcc_mirrored()
432462
test_linear_classifier(script_mode=script_mode)
433463
test_keras_v1(script_mode=script_mode)
434464
test_keras_gradients(script_mode=script_mode)
435-
test_keras_gradients(script_mode=script_mode, tf_optimizer=True)
465+
test_keras_gradients_tf_opt(script_mode=script_mode)
436466
test_keras_to_estimator(script_mode=script_mode)
437467
if not script_mode:
438-
test_keras_gradients_mirrored(include_workers="all")
468+
test_keras_gradients_mirrored_all_workers()
439469
test_keras_gradients_mirrored()

0 commit comments

Comments
 (0)