Skip to content

Commit 301cbdd

Browse files
authored
Modify Asserts to Work with TF 2.1.0 and TF 2.0.0 (aws#380)
1 parent b1f8756 commit 301cbdd

File tree

4 files changed

+28
-18
lines changed

4 files changed

+28
-18
lines changed

tests/tensorflow2/test_keras.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -558,7 +558,7 @@ def test_include_regex(out_dir, tf_eager_mode):
558558

559559
tr = create_trial_fast_refresh(out_dir)
560560
tnames = tr.tensor_names(collection="custom_coll")
561-
assert len(tnames) == 12
561+
assert len(tnames) == (12 if is_tf_2_2() else 4)
562562
for tname in tnames:
563563
assert tr.tensor(tname).value(0) is not None
564564

@@ -729,10 +729,7 @@ def test_keras_fit_pure_eager(out_dir, tf_eager_mode):
729729
helper_keras_fit(trial_dir=out_dir, hook=hook, eager=tf_eager_mode, run_eagerly=True)
730730

731731
trial = smd.create_trial(path=out_dir)
732-
if is_tf_2_2():
733-
assert len(trial.tensor_names()) == 27
734-
else:
735-
assert len(trial.tensor_names()) == (20 if is_tf_2_3() else 21)
732+
assert len(trial.tensor_names()) == (27 if is_tf_2_2() else 13)
736733
assert len(trial.tensor_names(collection=CollectionKeys.BIASES)) == 2
737734
assert len(trial.tensor_names(collection=CollectionKeys.WEIGHTS)) == 2
738735
assert len(trial.tensor_names(collection=CollectionKeys.OPTIMIZER_VARIABLES)) == 5

tests/tensorflow2/test_model_subclassing.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import tensorflow as tf
33
from tensorflow.keras.layers import BatchNormalization, Conv2D, Dense, Flatten
44
from tensorflow.keras.models import Model
5+
from tests.tensorflow2.utils import is_tf_2_2
56

67
# First Party
78
import smdebug.tensorflow as smd
@@ -78,7 +79,12 @@ def test_subclassed_model(out_dir):
7879
trial = smd.create_trial(out_dir)
7980
assert len(trial.tensor_names(collection=smd.CollectionKeys.LAYERS)) == 8
8081

81-
assert trial.tensor_names(collection=smd.CollectionKeys.INPUTS) == ["model_input"]
82-
assert trial.tensor_names(collection=smd.CollectionKeys.OUTPUTS) == ["labels", "predictions"]
8382
assert trial.tensor_names(collection=smd.CollectionKeys.LOSSES) == ["loss"]
84-
assert len(trial.tensor_names(collection=smd.CollectionKeys.GRADIENTS)) == 6
83+
if is_tf_2_2():
84+
# Feature to save model inputs and outputs was first added for TF 2.2.0
85+
assert trial.tensor_names(collection=smd.CollectionKeys.INPUTS) == ["model_input"]
86+
assert trial.tensor_names(collection=smd.CollectionKeys.OUTPUTS) == [
87+
"labels",
88+
"predictions",
89+
]
90+
assert len(trial.tensor_names(collection=smd.CollectionKeys.GRADIENTS)) == 6

tests/tensorflow2/test_support_dicts.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# Third Party
22
import numpy as np
3+
import pytest
34
import tensorflow as tf
5+
from tests.tensorflow2.utils import is_tf_2_2
46

57
# First Party
68
import smdebug.tensorflow as smd
@@ -29,6 +31,10 @@ def create_model():
2931
return model
3032

3133

34+
@pytest.mark.skipif(
35+
is_tf_2_2() is False,
36+
reason="Feature to save model inputs and outputs was first added for TF 2.2.0",
37+
)
3238
def test_support_dicts(out_dir):
3339
model = create_model()
3440
optimizer = tf.keras.optimizers.Adadelta(lr=1.0, rho=0.95, epsilon=None, decay=0.0)
Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,32 @@
11
# Third Party
22
import numpy as np
3-
from packaging import version
3+
import pytest
44
from tensorflow.python.framework.dtypes import _NP_TO_TF
5+
from tests.tensorflow2.utils import is_tf_2_2
56

67
# First Party
78
from smdebug.core.tfevent.util import _get_proto_dtype
89

910

11+
@pytest.mark.skipif(
12+
is_tf_2_2() is False, reason="Brain Float Is Unavailable in lower versions of TF"
13+
)
1014
def test_tensorflow2_datatypes():
1115
# _NP_TO_TF contains all the mappings
1216
# of numpy to tf types
1317
try:
14-
from tensorflow import __version__ as tf_version
18+
from tensorflow.python import _pywrap_bfloat16
1519

16-
if version.parse(tf_version) >= version.parse("2.0.0"):
17-
from tensorflow.python import _pywrap_bfloat16
18-
19-
# TF 2.x.x Implements a Custom Numpy Datatype for Brain Floating Type
20-
# Which is currently only supported on TPUs
21-
_np_bfloat16 = _pywrap_bfloat16.TF_bfloat16_type()
22-
_NP_TO_TF.pop(_np_bfloat16)
20+
# TF 2.x.x Implements a Custom Numpy Datatype for Brain Floating Type
21+
# Which is currently only supported on TPUs
22+
_np_bfloat16 = _pywrap_bfloat16.TF_bfloat16_type()
23+
_NP_TO_TF.pop(_np_bfloat16)
2324
except (ModuleNotFoundError, ValueError, ImportError):
2425
pass
2526

2627
for _type in _NP_TO_TF:
2728
try:
2829
_get_proto_dtype(np.dtype(_type))
2930
except Exception:
30-
assert False
31+
assert False, f"{_type} not supported"
3132
assert True

0 commit comments

Comments
 (0)