Commit 0dceec1

Add graph export support for MXNet and Pytorch (aws#247)
* Save histograms for weights and gradients
* Use standard TF summary function
* undo line break changes
* fix cases when bool tensor was being passed to add_histogram, and fix tests
* Fix region bug and update tb_writer construction
* Include summaries if any write_histogram was set to True
* Refactor writers in core
* set default step to 0
* Use new writer in hook
* Cherry picking change of refactor writers
* set default step to 0
* remove histogram related stuff
* rename IndexUtil
* Fix imports
* remove import of re
* Fix import of summary proto
* Fix step usage in writers
* Fix step usage by event file writer
* Remove directory in tensorboard directory, and add collection name as prefix for summaries created
* Fix import errors
* Fix resnet example which did not have str2bool args
* Fix core test
* Fix core test
* Indentation and move some code to a new function
* Merged Vikas' branch on tb data read
* Add untested support to read tensorboard data
* Write mode and mode_step for summaries, and fix the error of multiple global steps being assigned to same train step
* remove unnecessary file
* remove test script
* Remove changes to imagenet script
* working scalars
* Change path of tornasole event files
* Have new index file per mode for tensorboard events
* Move tensor values to different file
* move to outside tensors folder
* Change frequencies for tf examples
* Introduce CollectionKeys
* Merging export as json
* Make histogram a reduction config property, and add save_raw_tensor field to reduction config. Verified the usage for tensorflow. Also some cleanup with respect to save config in save manager
* Fix bug in loading collections
* Fix writing tensorboard data in global mode
* Add graph support to pytorch models. Copied some new protos, and a couple of files from torch.tensorboard.
* Working graph export for mxnet
* Save graph correctly for mxnet
* undo utils change worker pid
* fix import
* fix import
* do not flush index writer
* remove data files
* Fix save config issue
* make save_histogram a property of collection
* Fix save config bugs, and add scalar support to TF
* Skip summaries whose tensors are unreachable in graph, and avoid adding histogram when original collection is not included
* Move histogram creation to writer instead of event_file_writer, refactor should_save_collection in save manager, add save_scalar methods to MXNet and Pytorch
* WIP tensor scalar support
* undo add of data
* remove test
* use correct writer
* Make saving scalars work, and added type checks
* Writing scalars and tensors supported. tested in tensorboard. need to test through trials
* WIP testing steps
* remove save scalar and tensor for now because of step number issues. work on trial loading tensorboard data and come back to this
* Working reads in non index mode
* Tensorboard reads working with indexing
* cleanup index file location function
* Make pytorch tests working
* Reduce length of test_estimator_modes, and add tf tensorboard test
* Add basic scalar summary test
* Untested completed reads of tensorboard data
* Add more tensorboard tests for trial
* fix test when reading event files for tensorboard from s3
* Fixed a reduction test
* Fix reduction test in TF
* Fix merge of a test
* fix logger import, and default save/reduction config in save manager
* Fix reduction save_raw_tensor in TF
* Some cleanup of prepare and collection includes
* fix tf tests
* Fix all tests
* Add tensorboard index test
* Fix tensorboard test wrt optimizer_variables
* not save histogram for strings
* remove when nan support
* add hash
* Fix collection checks in xgboost
* add xgboost tests
* Typo
* Update hook.py (aws#243)
* reduce length of test and add / to prefix
* WIP move to tornasole hist summaries for TF
* Change collections_to_save_for_step, make TF use custom histograms, refactor to _save_tensor method for all frameworks
* rename to save_for_tensor
* undo some files
* undo some files
* Update tests.sh
* remove pytorch graph support
* remove mxnet graph support
* Revert "remove mxnet graph support" This reverts commit 56754da7b44ce7276cf6c9830fd7b0308061ef55.
* Revert "remove pytorch graph support" This reverts commit d5c49def8fb369f95282b384dc0bc8a9928ae941.
* remove old files
* fix export of models
* Create __init__.py
1 parent 4a9f80b commit 0dceec1

16 files changed: +1731 -29 lines changed

setup.py

Lines changed: 8 additions & 5 deletions
@@ -8,11 +8,14 @@


 def compile_summary_protobuf():
-    proto_path = 'tornasole/core/tfevent/proto'
-    proto_files = os.path.join(proto_path, '*.proto')
-    cmd = 'protoc ' + proto_files + ' --python_out=.'
-    print('compiling protobuf files in {}'.format(proto_path))
-    return os.system('set -ex &&' + cmd)
+    proto_paths = ['tornasole/core/tfevent/proto', 'tornasole/pytorch/proto']
+    cmd = 'set -ex && protoc '
+    for proto_path in proto_paths:
+        proto_files = os.path.join(proto_path, '*.proto')
+        cmd += proto_files + ' '
+        print('compiling protobuf files in {}'.format(proto_path))
+    cmd += ' --python_out=.'
+    return os.system(cmd)


 def get_framework_packages(f):
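
Note on the setup.py change: the new `compile_summary_protobuf` concatenates one `*.proto` glob per directory into a single `protoc` invocation. The sketch below isolates only the string-building step so the resulting command can be inspected without running `protoc`. The helper name `build_protoc_command` and its `python_out` parameter are assumptions made for illustration and are not part of this commit.

```python
import os


def build_protoc_command(proto_paths, python_out='.'):
    """Build (but do not run) one protoc command covering every proto directory.

    Hypothetical helper: it mirrors the string concatenation in the diff above.
    """
    cmd = 'set -ex && protoc '
    for proto_path in proto_paths:
        # The shell, not Python, expands the *.proto glob when the command runs.
        cmd += os.path.join(proto_path, '*.proto') + ' '
    cmd += ' --python_out=' + python_out
    return cmd


command = build_protoc_command(
    ['tornasole/core/tfevent/proto', 'tornasole/pytorch/proto'])
print(command)
```

Keeping construction separate from `os.system` is one possible way to unit test the command; the commit itself builds and runs the string in one function.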

tornasole/mxnet/graph.py

Lines changed: 96 additions & 0 deletions
@@ -0,0 +1,96 @@
+from mxnet.ndarray import NDArray
+from mxnet.symbol import Symbol
+from mxnet.gluon import HybridBlock
+import json
+
+from tornasole.core.tfevent.proto.graph_pb2 import GraphDef
+from tornasole.core.tfevent.proto.node_def_pb2 import NodeDef
+from tornasole.core.tfevent.proto.versions_pb2 import VersionDef
+from tornasole.core.tfevent.proto.attr_value_pb2 import AttrValue
+
+
+def _scoped_name(scope_name, node_name):
+    return '/'.join([scope_name, node_name])
+
+
+def _get_nodes_from_symbol(sym):
+    """Given a symbol and shapes, return a list of `NodeDef`s for visualizing the
+    the graph in TensorBoard."""
+    if not isinstance(sym, Symbol):
+        raise TypeError('sym must be an `mxnet.symbol.Symbol`,'
+                        ' received type {}'.format(str(type(sym))))
+    conf = json.loads(sym.tojson())
+    nodes = conf['nodes']
+    data2op = {} # key: data id, value: list of ops to whom data is an input
+    for i, node in enumerate(nodes):
+        if node['op'] != 'null': # node is an operator
+            input_list = node['inputs']
+            for idx in input_list:
+                if idx[0] == 0: # do not include 'data' node in the op scope
+                    continue
+                if idx[0] in data2op:
+                    # nodes[idx[0]] is a data as an input to op nodes[i]
+                    data2op[idx[0]].append(i)
+                else:
+                    data2op[idx[0]] = [i]
+
+    # In the following, we group data with operators they belong to
+    # by attaching them with operator names as scope names.
+    # The parameters with the operator name as the prefix will be
+    # assigned with the scope name of that operator. For example,
+    # a convolution op has name 'conv', while its weight and bias
+    # have name 'conv_weight' and 'conv_bias'. In the end, the operator
+    # has scope name 'conv' prepended to its name, i.e. 'conv/conv'.
+    # The parameters are named 'conv/conv_weight' and 'conv/conv_bias'.
+    node_defs = []
+    for i, node in enumerate(nodes):
+        node_name = node['name']
+        op_name = node['op']
+        kwargs = {'op': op_name, 'name': node_name}
+        if op_name != 'null': # node is an operator
+            inputs = []
+            input_list = node['inputs']
+            for idx in input_list:
+                input_node = nodes[idx[0]]
+                input_node_name = input_node['name']
+                if input_node['op'] != 'null':
+                    inputs.append(_scoped_name(input_node_name, input_node_name))
+                elif idx[0] in data2op and len(data2op[idx[0]]) == 1 and data2op[idx[0]][0] == i:
+                    # the data is only as an input to nodes[i], no else
+                    inputs.append(_scoped_name(node_name, input_node_name))
+                else: # the data node has no scope name, e.g. 'data' as the input node
+                    inputs.append(input_node_name)
+            kwargs['input'] = inputs
+            kwargs['name'] = _scoped_name(node_name, node_name)
+        elif i in data2op and len(data2op[i]) == 1:
+            # node is a data node belonging to one op, find out which operator this node belongs to
+            op_node_name = nodes[data2op[i][0]]['name']
+            kwargs['name'] = _scoped_name(op_node_name, node_name)
+
+        if 'attrs' in node:
+            # TensorBoard would escape quotation marks, replace it with space
+            attr = json.dumps(node['attrs'], sort_keys=True).replace("\"", ' ')
+            attr = {'param': AttrValue(s=attr.encode(encoding='utf-8'))}
+            kwargs['attr'] = attr
+        node_def = NodeDef(**kwargs)
+        node_defs.append(node_def)
+    return node_defs
+
+
+def _sym2pb(sym):
+    """Converts an MXNet symbol to its graph protobuf definition."""
+    return GraphDef(node=_get_nodes_from_symbol(sym), versions=VersionDef(producer=100))
+
+
+def _net2pb(net):
+    if isinstance(net, HybridBlock):
+        # TODO(junwu): may need a more approprite way to get symbol from a HybridBlock
+        if not net._cached_graph:
+            raise RuntimeError(
+                "Please first call net.hybridize() and then run forward with "
+                "this net at least once before calling add_graph().")
+        net = net._cached_graph[1]
+    elif not isinstance(net, Symbol):
+        raise TypeError('only accepts mxnet.gluon.HybridBlock and mxnet.symbol.Symbol '
+                        'as input network, received type {}'.format(str(type(net))))
+    return _sym2pb(net)
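
Note on the scoping comment in `_get_nodes_from_symbol`: the naming rule can be tried out without MXNet or protobuf. The sketch below is a simplified stand-in that works on a hand-written node list shaped like the `nodes` array of `Symbol.tojson()`; `scoped_names` and the sample `nodes` are assumptions made for illustration, and only display names are computed (no `NodeDef` objects, inputs, or attributes). Like the code in the diff, it treats node index 0 as the unscoped `data` input.

```python
def scoped_name(scope_name, node_name):
    return '/'.join([scope_name, node_name])


def scoped_names(nodes):
    """Return the TensorBoard display name for each MXNet-style JSON node."""
    # Map each data (op == 'null') node index to the operators that consume it.
    data2op = {}
    for i, node in enumerate(nodes):
        if node['op'] != 'null':
            for idx in node['inputs']:
                if idx[0] == 0:  # index 0 is assumed to be the 'data' input
                    continue
                data2op.setdefault(idx[0], []).append(i)

    names = []
    for i, node in enumerate(nodes):
        name = node['name']
        if node['op'] != 'null':
            # Operators are scoped under their own name: 'conv' -> 'conv/conv'.
            names.append(scoped_name(name, name))
        elif i in data2op and len(data2op[i]) == 1:
            # A parameter used by exactly one operator joins that operator's scope.
            names.append(scoped_name(nodes[data2op[i][0]]['name'], name))
        else:
            names.append(name)
    return names


# Hand-written sample, not taken from a real model.
nodes = [
    {'op': 'null', 'name': 'data', 'inputs': []},
    {'op': 'null', 'name': 'conv_weight', 'inputs': []},
    {'op': 'null', 'name': 'conv_bias', 'inputs': []},
    {'op': 'Convolution', 'name': 'conv',
     'inputs': [[0, 0, 0], [1, 0, 0], [2, 0, 0]]},
]
names = scoped_names(nodes)
print(names)  # ['data', 'conv/conv_weight', 'conv/conv_bias', 'conv/conv']
```

A parameter shared by two operators would keep its bare name, since `len(data2op[i]) == 1` no longer holds for it.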

tornasole/mxnet/hook.py

Lines changed: 10 additions & 12 deletions
@@ -6,7 +6,7 @@
 from tornasole.mxnet.mxnet_collection import get_collection_manager
 from tornasole.mxnet.singleton_utils import set_hook
 from tornasole.mxnet.utils import get_reduction_of_data, make_numpy_array
-# from tornasole.mxnet.graph import _net2pb
+from tornasole.mxnet.graph import _net2pb

 DEFAULT_INCLUDE_COLLECTIONS = [CollectionKeys.LOSSES]

@@ -81,16 +81,15 @@ def log_param(self, param):
                            tensor_value=param.grad(param.list_ctx()[0]))

     def _export_model(self):
-        pass
-        # if self.model is not None:
-        #     try:
-        #         self._get_tb_writer().write_graph(_net2pb(self.model))
-        #     except (RuntimeError, TypeError) as e:
-        #         self.logger.warning(
-        #             f'Could not export model graph for tensorboard '
-        #             f'due to the mxnet exception: {e}')
-        # else:
-        #     self.logger.warning('Tornasole does not know the model')
+        if self.model is not None:
+            try:
+                self._get_tb_writer().write_graph(_net2pb(self.model))
+            except (RuntimeError, TypeError) as e:
+                self.logger.warning(
+                    f'Could not export model graph for tensorboard '
+                    f'due to the mxnet exception: {e}')
+        else:
+            self.logger.warning('Tornasole does not know the model')

     # This hook is invoked by trainer prior to running the forward pass.
     def forward_pre_hook(self, block, inputs):
@@ -119,7 +118,6 @@ def forward_pre_hook(self, block, inputs):

         if self.last_saved_step is not None and not self.exported_collections:
             self.export_collections()
-            self._export_model()
             self.exported_collections = True

         self.last_block = block

tornasole/pytorch/_proto_graph.py

Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
+#Taken from https://github.com/pytorch/pytorch/blob/c749be9e9f8dd3db8b3582e93f917bd47e8e9e20/torch/utils/tensorboard/_proto_graph.py
+
+from tornasole.core.tfevent.proto.node_def_pb2 import NodeDef
+from tornasole.core.tfevent.proto.attr_value_pb2 import AttrValue
+from tornasole.core.tfevent.proto.tensor_shape_pb2 import TensorShapeProto
+
+
+def attr_value_proto(dtype, shape, s):
+    """Creates a dict of objects matching
+    https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/attr_value.proto
+    specifically designed for a NodeDef. The values have been
+    reverse engineered from standard TensorBoard logged data.
+    """
+    attr = {}
+    if s is not None:
+        attr['attr'] = AttrValue(s=s.encode(encoding='utf_8'))
+    if shape is not None:
+        shapeproto = tensor_shape_proto(shape)
+        attr['_output_shapes'] = AttrValue(list=AttrValue.ListValue(shape=[shapeproto]))
+    return attr
+
+
+def tensor_shape_proto(outputsize):
+    """Creates an object matching
+    https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/tensor_shape.proto
+    """
+    return TensorShapeProto(dim=[TensorShapeProto.Dim(size=d) for d in outputsize])
+
+
+def node_proto(name,
+               op='UnSpecified',
+               input=None,
+               dtype=None,
+               shape=None, # type: tuple
+               outputsize=None,
+               attributes=''
+               ):
+    """Creates an object matching
+    https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/node_def.proto
+    """
+    if input is None:
+        input = []
+    if not isinstance(input, list):
+        input = [input]
+    return NodeDef(
+        name=name.encode(encoding='utf_8'),
+        op=op,
+        input=input,
+        attr=attr_value_proto(dtype, outputsize, attributes)
+    )
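
Note on `node_proto`: the function normalizes `input` to a list and builds the attribute map from `outputsize` and `attributes`; as written in the diff, the `dtype` and `shape` arguments do not influence the result. The sketch below reproduces that flow with plain dictionaries standing in for the protobuf messages, so it runs without the generated `*_pb2` modules. The names `attr_value_sketch` and `node_sketch`, and the sample node values, are assumptions made for illustration.

```python
def attr_value_sketch(shape, s):
    """Plain-dict stand-in for attr_value_proto (no protobuf objects)."""
    attr = {}
    if s is not None:
        attr['attr'] = s.encode('utf_8')
    if shape is not None:
        # One output shape, expressed as a list of per-dimension sizes.
        attr['_output_shapes'] = [{'dim': [{'size': d} for d in shape]}]
    return attr


def node_sketch(name, op='UnSpecified', input=None, outputsize=None, attributes=''):
    """Plain-dict stand-in for node_proto."""
    if input is None:
        input = []
    if not isinstance(input, list):
        input = [input]  # a single input name becomes a one-element list
    return {
        'name': name,
        'op': op,
        'input': input,
        'attr': attr_value_sketch(outputsize, attributes),
    }


# Hypothetical node, chosen only to exercise the normalization.
node = node_sketch('fc/out', op='aten::addmm', input='fc/in', outputsize=(1, 10))
print(node['input'])                   # ['fc/in']
print(node['attr']['_output_shapes'])  # [{'dim': [{'size': 1}, {'size': 10}]}]
```

Because `attributes` defaults to an empty string rather than `None`, the `'attr'` key is present even when no attributes are supplied.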
