Skip to content

Commit 045d89e

Browse files
lakshya97rahul003
authored andcommitted
making a quick variance fix (aws#99)
1 parent 3e28383 commit 045d89e

File tree

2 files changed

+14
-11
lines changed

2 files changed

+14
-11
lines changed

tests/pytorch/test_reduce_config.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def train(model, device, optimizer, num_steps=500, save_steps=[]):
5151

5252
def test_reduce_config():
5353
reset_collections()
54-
global_reduce_config = ReductionConfig(reductions=["max", "mean"])
54+
global_reduce_config = ReductionConfig(reductions=["max", "mean", "variance"])
5555
global_save_config = SaveConfig(save_steps=[0,1,2,3])
5656

5757
ts.get_collection("ReluActivation").include(["relu*"])
@@ -76,12 +76,14 @@ def test_reduce_config():
7676
tname = tr.tensors_matching_regex('Net_conv[0-9]+.weight')[0]
7777
print(tr.tensors())
7878

79-
# Global reduction with max and mean
79+
# Global reduction with max and mean and variance
8080
weight_tensor = tr.tensor(tname)
8181
max_val = weight_tensor.reduction_value(step_num=1, abs=False, reduction_name='max')
8282
assert max_val != None
8383
mean_val = weight_tensor.reduction_value(step_num=1, abs=False, reduction_name='mean')
8484
assert mean_val != None
85+
variance_val = weight_tensor.reduction_value(step_num=1, abs=False, reduction_name='variance')
86+
assert variance_val != None
8587

8688
# custom reduction at step 4 with reduction = 'min and abs reduction = 'max'
8789
tname = tr.tensors_matching_regex('relu0_input_0')[0]

tornasole/pytorch/util.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,30 +4,31 @@
44
from tornasole.core.reductions import get_numpy_reduction
55

66

7-
def get_aggregated_data(aggregation_name, tensor_data, tensor_name, abs=False):
8-
reduction_name = aggregation_name
7+
def get_aggregated_data(reduction_name, tensor_data, tensor_name, abs=False):
98
if isinstance(tensor_data, np.ndarray):
109
return get_numpy_reduction(reduction_name, tensor_data, abs)
1110
if abs:
1211
tensor_data = torch.abs(tensor_data)
1312

1413
if reduction_name in ALLOWED_REDUCTIONS:
15-
assert hasattr(torch.Tensor, aggregation_name)
16-
f = getattr(torch.Tensor, aggregation_name)
14+
if reduction_name == "variance":
15+
reduction_name = "var"
16+
assert hasattr(torch.Tensor, reduction_name)
17+
f = getattr(torch.Tensor, reduction_name)
1718
op = f(tensor_data)
1819
return op
1920
elif reduction_name in ALLOWED_NORMS:
20-
if aggregation_name in ['l1', 'l2']:
21-
ord = int(aggregation_name[1])
21+
if reduction_name in ['l1', 'l2']:
22+
ord = int(reduction_name[1])
2223
else:
2324
raise RuntimeError("Invalid normalization operation {0} for torch.Tensor".format(reduction_name))
2425
op = torch.norm(tensor_data, p=ord)
2526
return op
26-
elif hasattr(torch, aggregation_name):
27-
f = getattr(torch, aggregation_name)
27+
elif hasattr(torch, reduction_name):
28+
f = getattr(torch, reduction_name)
2829
op = f(tensor_data)
2930
return op
30-
raise RuntimeError("Invalid aggregation_name {0}".format(aggregation_name))
31+
raise RuntimeError("Invalid reduction_name {0}".format(reduction_name))
3132

3233

3334
def make_numpy_array(x):

0 commit comments

Comments
 (0)