@@ -372,13 +372,15 @@ def plot_metric(result: List[float], metric_name: str):
372
372
373
373
def calculate_mse (ref_values : ProgramOutput , values : ProgramOutput ):
374
374
def mean_squared_error (a : torch .Tensor , b : torch .Tensor ):
375
- return round ((torch .pow ((a - b ). to ( torch . float32 ) , 2 )).mean ().item (), 2 )
375
+ return round ((torch .pow ((a - b ), 2 )).mean ().item (), 2 )
376
376
377
377
results = []
378
378
for ref_value , value in zip (ref_values , values ):
379
379
# TODO T171811011: extend the implementation of each metrics function to support value types other than tensor type
380
380
if isinstance (ref_value , torch .Tensor ) and isinstance (value , torch .Tensor ):
381
- results .append (mean_squared_error (ref_value , value ))
381
+ results .append (
382
+ mean_squared_error (ref_value .to (torch .float32 ), value .to (torch .float32 ))
383
+ )
382
384
else :
383
385
results .append (None )
384
386
@@ -387,8 +389,6 @@ def mean_squared_error(a: torch.Tensor, b: torch.Tensor):
387
389
388
390
def calculate_snr (ref_values : ProgramOutput , values : ProgramOutput ):
389
391
def signal_to_noise (signal : torch .Tensor , noise : torch .Tensor ):
390
- signal = signal .type (torch .float32 )
391
- noise = noise .type (torch .float32 )
392
392
signal_power = torch .mean (torch .pow (signal , 2 ))
393
393
noise_power = torch .mean (torch .pow (noise , 2 ))
394
394
snr = 10 * torch .log10 (signal_power / noise_power )
@@ -398,8 +398,10 @@ def signal_to_noise(signal: torch.Tensor, noise: torch.Tensor):
398
398
for ref_value , value in zip (ref_values , values ):
399
399
# TODO T171811011: extend the implementation of each metrics function to support value types other than tensor type
400
400
if isinstance (ref_value , torch .Tensor ) and isinstance (value , torch .Tensor ):
401
- diff = ref_value - value
402
- snr = signal_to_noise (ref_value , diff )
401
+ ref_value_fp = ref_value .to (torch .float32 )
402
+ value_fp = value .to (torch .float32 )
403
+ diff = ref_value_fp - value_fp
404
+ snr = signal_to_noise (ref_value_fp , diff )
403
405
results .append (snr )
404
406
else :
405
407
results .append (None )
@@ -429,7 +431,9 @@ def cosine_similarity(tensor1: torch.Tensor, tensor2: torch.Tensor):
429
431
for ref_value , value in zip (ref_values , values ):
430
432
# TODO T171811011: extend the implementation of each metrics function to support value types other than tensor type
431
433
if isinstance (ref_value , torch .Tensor ) and isinstance (value , torch .Tensor ):
432
- results .append (cosine_similarity (ref_value , value ))
434
+ results .append (
435
+ cosine_similarity (ref_value .to (torch .float32 ), value .to (torch .float32 ))
436
+ )
433
437
else :
434
438
results .append (None )
435
439
0 commit comments