|
4 | 4 | from tornasole.core.reductions import get_numpy_reduction
|
5 | 5 |
|
6 | 6 |
|
7 |
| -def get_aggregated_data(aggregation_name, tensor_data, tensor_name, abs=False): |
8 |
| - reduction_name = aggregation_name |
| 7 | +def get_aggregated_data(reduction_name, tensor_data, tensor_name, abs=False): |
9 | 8 | if isinstance(tensor_data, np.ndarray):
|
10 | 9 | return get_numpy_reduction(reduction_name, tensor_data, abs)
|
11 | 10 | if abs:
|
12 | 11 | tensor_data = torch.abs(tensor_data)
|
13 | 12 |
|
14 | 13 | if reduction_name in ALLOWED_REDUCTIONS:
|
15 |
| - assert hasattr(torch.Tensor, aggregation_name) |
16 |
| - f = getattr(torch.Tensor, aggregation_name) |
| 14 | + if reduction_name == "variance": |
| 15 | + reduction_name = "var" |
| 16 | + assert hasattr(torch.Tensor, reduction_name) |
| 17 | + f = getattr(torch.Tensor, reduction_name) |
17 | 18 | op = f(tensor_data)
|
18 | 19 | return op
|
19 | 20 | elif reduction_name in ALLOWED_NORMS:
|
20 |
| - if aggregation_name in ['l1', 'l2']: |
21 |
| - ord = int(aggregation_name[1]) |
| 21 | + if reduction_name in ['l1', 'l2']: |
| 22 | + ord = int(reduction_name[1]) |
22 | 23 | else:
|
23 | 24 | raise RuntimeError("Invalid normalization operation {0} for torch.Tensor".format(reduction_name))
|
24 | 25 | op = torch.norm(tensor_data, p=ord)
|
25 | 26 | return op
|
26 |
| - elif hasattr(torch, aggregation_name): |
27 |
| - f = getattr(torch, aggregation_name) |
| 27 | + elif hasattr(torch, reduction_name): |
| 28 | + f = getattr(torch, reduction_name) |
28 | 29 | op = f(tensor_data)
|
29 | 30 | return op
|
30 |
| - raise RuntimeError("Invalid aggregation_name {0}".format(aggregation_name)) |
| 31 | + raise RuntimeError("Invalid reduction_name {0}".format(reduction_name)) |
31 | 32 |
|
32 | 33 |
|
33 | 34 | def make_numpy_array(x):
|
|
0 commit comments