|
16 | 16 | MISSING_EVENT_FILE_RETRY_LIMIT,
|
17 | 17 | MISSING_EVENT_FILE_RETRY_LIMIT_KEY,
|
18 | 18 | )
|
19 |
| -from smdebug.core.locations import IndexFileLocationUtils, TensorLocation |
| 19 | +from smdebug.core.locations import IndexFileLocationUtils, TensorLocation, TensorShape |
20 | 20 | from smdebug.core.logger import get_logger
|
21 | 21 | from smdebug.core.modes import ModeKeys
|
22 | 22 | from smdebug.core.s3_utils import list_s3_objects
|
@@ -120,12 +120,22 @@ def fetch_tensor_value(self, tensor_location: TensorLocation):
|
    def list_event_files(self, start_after_prefix):
        """List event files located after *start_after_prefix*.

        Body is a no-op here; presumably concrete readers (local / S3)
        override this with a real listing — confirm against subclasses.
        """
        pass
|
122 | 122 |
|
123 |
| - @abstractmethod |
124 | 123 | def load_tensor_data_from_index_files(
|
125 | 124 | self, start_after_key=None, range_steps=None
|
126 | 125 | ) -> Tuple[Dict[str, Dict[int, Dict[str, TensorLocation]]], str]:
|
127 | 126 | """Return a triply nested dict referring to tensor data."""
|
128 | 127 |
|
| 128 | + responses, steps, last_index_token, workers = self.read_index_files( |
| 129 | + start_after_key, range_steps |
| 130 | + ) |
| 131 | + |
| 132 | + tensor_data = {} |
| 133 | + for step, response, worker in zip(steps, responses, workers): |
| 134 | + tensor_data = self._update_tensors_from_json( |
| 135 | + tensor_data, step, response, self.path, worker |
| 136 | + ) |
| 137 | + return tensor_data, last_index_token |
| 138 | + |
    @abstractmethod
    def _is_event_file_present(self, file_name) -> bool:
        """Return True if the event file *file_name* exists.

        Abstract: each concrete reader decides how to probe for the file
        (e.g. local filesystem vs. S3 — confirm against subclasses).
        """
        pass
@@ -203,8 +213,10 @@ def _validate(index_dict):
|
203 | 213 | raise IndexReaderException("meta section is not present")
|
204 | 214 | if len(index_dict["meta"]) == 0:
|
205 | 215 | raise IndexReaderException("meta section is empty")
|
206 |
| - if "tensor_payload" not in index_dict: |
207 |
| - raise IndexReaderException("tensor_payload section is not present") |
| 216 | + if "tensor_payload" not in index_dict and "shape_payload" not in index_dict: |
| 217 | + raise IndexReaderException( |
| 218 | + "neither tensor_payload nor shape_payload sections are present" |
| 219 | + ) |
208 | 220 |
|
209 | 221 | def _update_tensors_from_json(
|
210 | 222 | self, index_tensors_dict, step, response: bytes, path, worker
|
@@ -233,28 +245,41 @@ def _update_tensors_from_json(
|
233 | 245 | mode = index_meta["mode"]
|
234 | 246 | mode = ModeKeys[mode.strip()]
|
235 | 247 | mode_step = index_meta["mode_step"]
|
236 |
| - event_file_name = os.path.join(path, index_meta["event_file_name"]) |
237 |
| - tensors = index_dict["tensor_payload"] |
238 |
| - for tensor in tensors: |
239 |
| - tensor_name = tensor["tensorname"] |
240 |
| - start_idx = tensor["start_idx"] |
241 |
| - length = tensor["length"] |
242 |
| - tensor_location = TensorLocation( |
243 |
| - tensor_name, mode, mode_step, event_file_name, start_idx, length, worker |
244 |
| - ) |
| 248 | + |
| 249 | + to_update_index_dict = [] |
| 250 | + |
| 251 | + if "tensor_payload" in index_dict and len(index_dict["tensor_payload"]): |
| 252 | + event_file_name = os.path.join(path, index_meta["event_file_name"]) |
| 253 | + for tensor in index_dict["tensor_payload"]: |
| 254 | + tensor_name = tensor["tensorname"] |
| 255 | + start_idx = tensor["start_idx"] |
| 256 | + length = tensor["length"] |
| 257 | + tensor_location = TensorLocation( |
| 258 | + tensor_name, mode, mode_step, event_file_name, start_idx, length, worker |
| 259 | + ) |
| 260 | + to_update_index_dict.append((tensor_name, step, tensor_location)) |
| 261 | + |
| 262 | + if "shape_payload" in index_dict and len(index_dict["shape_payload"]): |
| 263 | + for tensor in index_dict["shape_payload"]: |
| 264 | + tensor_name = tensor["tensorname"] |
| 265 | + original_name = tensor["originalname"] |
| 266 | + shape = tensor["shape"] |
| 267 | + ts = TensorShape(tensor_name, mode, mode_step, shape, original_name) |
| 268 | + to_update_index_dict.append((tensor_name, step, ts)) |
| 269 | + |
| 270 | + for tu in to_update_index_dict: |
| 271 | + tensor_name, step, obj = tu |
| 272 | + if isinstance(obj, TensorLocation): |
| 273 | + obj_dict = {"tensor_location": obj} |
| 274 | + elif isinstance(obj, TensorShape): |
| 275 | + obj_dict = {"tensor_shape": obj} |
245 | 276 | if tensor_name in index_tensors_dict:
|
246 | 277 | if step in index_tensors_dict[tensor_name]:
|
247 |
| - index_tensors_dict[tensor_name][step].update( |
248 |
| - {worker: {"tensor_location": tensor_location}} |
249 |
| - ) |
| 278 | + index_tensors_dict[tensor_name][step].update({worker: obj_dict}) |
250 | 279 | else:
|
251 |
| - index_tensors_dict[tensor_name].update( |
252 |
| - {step: {worker: {"tensor_location": tensor_location}}} |
253 |
| - ) |
| 280 | + index_tensors_dict[tensor_name].update({step: {worker: obj_dict}}) |
254 | 281 | else:
|
255 |
| - index_tensors_dict[tensor_name] = { |
256 |
| - step: {worker: {"tensor_location": tensor_location}} |
257 |
| - } |
| 282 | + index_tensors_dict[tensor_name] = {step: {worker: obj_dict}} |
258 | 283 | return index_tensors_dict
|
259 | 284 |
|
260 | 285 |
|
@@ -285,22 +310,6 @@ def fetch_tensor_value(self, tensor_location: TensorLocation) -> np.ndarray:
|
285 | 310 | tensor_name, step, tensor_data, mode, mode_step = tensor_tuple
|
286 | 311 | return tensor_data
|
287 | 312 |
|
288 |
| - def load_tensor_data_from_index_files( |
289 |
| - self, start_after_key=None, range_steps=None |
290 |
| - ) -> Tuple[Dict[str, Dict[int, Dict[str, TensorLocation]]], str]: |
291 |
| - """Return a triply nested dict referring to tensor data.""" |
292 |
| - |
293 |
| - responses, steps, last_index_token, workers = self.read_index_files( |
294 |
| - start_after_key, range_steps |
295 |
| - ) |
296 |
| - |
297 |
| - tensor_data = {} |
298 |
| - for step, response, worker in zip(steps, responses, workers): |
299 |
| - tensor_data = self._update_tensors_from_json( |
300 |
| - tensor_data, step, response, self.path, worker |
301 |
| - ) |
302 |
| - return tensor_data, last_index_token |
303 |
| - |
304 | 313 | def read_index_files(
|
305 | 314 | self, start_after_key: str, range_steps=None
|
306 | 315 | ) -> Tuple[List[bytes], list, str, List[str]]:
|
@@ -398,21 +407,6 @@ def fetch_tensor_value(self, tensor_location: TensorLocation) -> np.ndarray:
|
398 | 407 | tensor_name, step, tensor_data, mode, mode_step = tensor_tuple
|
399 | 408 | return tensor_data
|
400 | 409 |
|
401 |
| - def load_tensor_data_from_index_files( |
402 |
| - self, start_after_key=None, range_steps=None |
403 |
| - ) -> Tuple[Dict[str, Dict[int, Dict[str, TensorLocation]]], str]: |
404 |
| - """Return a triply nested dict referring to tensor data.""" |
405 |
| - |
406 |
| - responses, steps, last_index_token, workers = self.read_index_files( |
407 |
| - start_after_key, range_steps |
408 |
| - ) |
409 |
| - tensor_data = {} |
410 |
| - for step, response, worker in zip(steps, responses, workers): |
411 |
| - tensor_data = self._update_tensors_from_json( |
412 |
| - tensor_data, step, response, self.path, worker |
413 |
| - ) |
414 |
| - return tensor_data, last_index_token |
415 |
| - |
416 | 410 | def read_index_files(
|
417 | 411 | self, start_after_key: str, range_steps=None
|
418 | 412 | ) -> Tuple[List[bytes], list, str, List[str]]:
|
|
0 commit comments