Skip to content

Commit 871eb74

Browse files
authored
Add attributes for PT ZCC (aws#289)
1 parent 7d56497 commit 871eb74

File tree

4 files changed

+20
-1
lines changed

4 files changed

+20
-1
lines changed

tests/pytorch/test_distributed_training.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
"""
99
import numpy as nn
1010
import os
11+
import pytest
1112
import torch
1213
import torch.distributed as dist
1314
from torch.multiprocessing import Process
@@ -110,6 +111,7 @@ def init_processes(rank, size, fn, backend="gloo"):
110111
fn(rank, size)
111112

112113

114+
@pytest.mark.slow # 0:05 to run
113115
def test_run_net_single_process():
114116
"""Runs a single linear layer."""
115117
ts.reset_collections()

tests/pytorch/test_loss.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def forward(self, x):
3131
return x
3232

3333

34+
@pytest.mark.slow # 0:05 to run
3435
def test_register_loss():
3536
"""Test that the loss is saved as a tensor."""
3637
ts.reset_collections()

tornasole/pytorch/collection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def _register_default_collections(self):
2424
self.get(CollectionKeys.WEIGHTS).include("^(?!gradient).*weight")
2525
self.get(CollectionKeys.BIASES).include("^(?!gradient).*bias")
2626
self.get(CollectionKeys.GRADIENTS).include("^gradient")
27-
self.get(CollectionKeys.LOSSES).include("Loss")
27+
self.get(CollectionKeys.LOSSES).include("[Ll]oss")
2828

2929
def create_collection(self, name):
3030
super().create_collection(name, cls=Collection)

tornasole/pytorch/hook.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
from copy import deepcopy
2+
import types
3+
from typing import Callable, Union
24
import torch
35
import torch.distributed as dist
46
from tornasole.core.json_config import (
@@ -53,6 +55,9 @@ def __init__(
5355
self.model = None
5456
self.exported_model = False
5557

58+
self.has_registered_module = False
59+
self.has_registered_loss_module = False
60+
5661
set_hook(self)
5762

5863
def get_num_workers(self):
@@ -155,6 +160,14 @@ def forward_pre_hook(self, module, inputs):
155160
self.export_collections()
156161
self.exported_collections = True
157162

163+
def record_tensor_value(self, tensor_name: str, tensor_value: torch.Tensor) -> None:
164+
"""Used for registering functional directly, such as F.mse_loss()."""
165+
assert isinstance(
166+
tensor_value, torch.Tensor
167+
), f"tensor_value={tensor_value} must be torch.Tensor"
168+
169+
self._write_outputs(tensor_name, tensor_value)
170+
158171
# This hook is invoked by trainer after running the forward pass.
159172
def forward_hook(self, module, inputs, outputs):
160173
if not self._get_collections_to_save_for_step():
@@ -228,6 +241,8 @@ def register_hook(self, module):
228241
# Capture the gradient for each parameter in the net
229242
self._backward_apply(module)
230243

244+
self.has_registered_module = True
245+
231246
def register_loss(self, loss_module):
232247
"""Register something like `criterion = nn.CrossEntropyLoss()`."""
233248
# Typechecking
@@ -240,6 +255,7 @@ def register_loss(self, loss_module):
240255
self.module_maps[loss_module] = name
241256
# Add a callback to the forward pass
242257
loss_module.register_forward_hook(self.forward_hook)
258+
self.has_registered_loss_module = True
243259

244260
@staticmethod
245261
def _get_reduction_of_data(reduction_name, tensor_value, tensor_name, abs):

0 commit comments

Comments
 (0)