Skip to content

Commit 07b8c07

Browse files
authored
Support Inputs and Labels in the dict format (aws#345)
1 parent 3cfa8d5 commit 07b8c07

File tree

2 files changed

+47
-0
lines changed

2 files changed

+47
-0
lines changed

smdebug/tensorflow/keras.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import tensorflow.compat.v1 as tf
66
from tensorflow.python.distribute import values
77
from tensorflow.python.framework.indexed_slices import IndexedSlices
8+
from tensorflow.python.util import nest
89

910
# First Party
1011
from smdebug.core.modes import ModeKeys
@@ -497,6 +498,10 @@ def save_smdebug_logs(self, logs):
497498
else set()
498499
)
499500
for t_name, t_value in tensors_to_save:
501+
if isinstance(t_value, dict):
502+
# flatten the inputs and labels
503+
# since we cannot convert dicts into numpy
504+
t_value = nest.flatten(t_value)
500505
self._save_tensor_to_file(t_name, t_value, collections_to_write)
501506

502507
def _save_metrics(self, batch, logs, force_save=False):
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Third Party
2+
import numpy as np
3+
import tensorflow as tf
4+
5+
# First Party
6+
import smdebug.tensorflow as smd
7+
from smdebug.core.collection import CollectionKeys
8+
from smdebug.trials import create_trial
9+
10+
11+
def get_data():
12+
images = np.zeros((64, 224))
13+
labels = np.zeros((64, 5))
14+
inputs = {"Image_input": images}
15+
outputs = {"output-softmax": labels}
16+
return inputs, outputs
17+
18+
19+
def create_hook(trial_dir):
20+
hook = smd.KerasHook(trial_dir, save_all=True)
21+
return hook
22+
23+
24+
def create_model():
25+
input_layer = tf.keras.layers.Input(name="Image_input", shape=(224), dtype="float32")
26+
model = tf.keras.layers.Dense(5)(input_layer)
27+
model = tf.keras.layers.Activation("softmax", name="output-softmax")(model)
28+
model = tf.keras.models.Model(inputs=input_layer, outputs=[model])
29+
return model
30+
31+
32+
def test_support_dicts(out_dir):
33+
model = create_model()
34+
optimizer = tf.keras.optimizers.Adadelta(lr=1.0, rho=0.95, epsilon=None, decay=0.0)
35+
model.compile(loss="categorical_crossentropy", optimizer=optimizer)
36+
inputs, labels = get_data()
37+
smdebug_hook = create_hook(out_dir)
38+
model.fit(inputs, labels, batch_size=16, epochs=10, callbacks=[smdebug_hook])
39+
model.save(out_dir, save_format="tf")
40+
trial = create_trial(out_dir)
41+
assert trial.tensor_names(collection=CollectionKeys.INPUTS) == ["model_input"]
42+
assert trial.tensor_names(collection=CollectionKeys.OUTPUTS) == ["labels", "predictions"]

0 commit comments

Comments
 (0)