34
34
from torch .fx .node import Node
35
35
36
36
from torch .overrides import TorchFunctionMode
37
- from torch .testing ._internal .common_utils import torch_to_numpy_dtype_dict
38
37
from tosa import TosaGraph
39
38
40
39
logger = logging .getLogger (__name__ )
41
40
logger .setLevel (logging .CRITICAL )
42
41
42
+ # Copied from PyTorch.
43
+ # From torch/testing/_internal/common_utils.py:torch_to_numpy_dtype_dict
44
+ # To avoid a dependency on _internal stuff.
45
+ _torch_to_numpy_dtype_dict = {
46
+ torch .bool : np .bool_ ,
47
+ torch .uint8 : np .uint8 ,
48
+ torch .uint16 : np .uint16 ,
49
+ torch .uint32 : np .uint32 ,
50
+ torch .uint64 : np .uint64 ,
51
+ torch .int8 : np .int8 ,
52
+ torch .int16 : np .int16 ,
53
+ torch .int32 : np .int32 ,
54
+ torch .int64 : np .int64 ,
55
+ torch .float16 : np .float16 ,
56
+ torch .float32 : np .float32 ,
57
+ torch .float64 : np .float64 ,
58
+ torch .bfloat16 : np .float32 ,
59
+ torch .complex32 : np .complex64 ,
60
+ torch .complex64 : np .complex64 ,
61
+ torch .complex128 : np .complex128 ,
62
+ }
63
+
43
64
44
65
class QuantizationParams :
45
66
__slots__ = ["node_name" , "zp" , "scale" , "qmin" , "qmax" , "dtype" ]
@@ -336,7 +357,7 @@ def run_corstone(
336
357
output_dtype = node .meta ["val" ].dtype
337
358
tosa_ref_output = np .fromfile (
338
359
os .path .join (intermediate_path , f"out-{ i } .bin" ),
339
- torch_to_numpy_dtype_dict [output_dtype ],
360
+ _torch_to_numpy_dtype_dict [output_dtype ],
340
361
)
341
362
342
363
output_np .append (torch .from_numpy (tosa_ref_output ).reshape (output_shape ))
@@ -350,7 +371,7 @@ def prep_data_for_save(
350
371
):
351
372
if isinstance (data , torch .Tensor ):
352
373
data_np = np .array (data .detach (), order = "C" ).astype (
353
- torch_to_numpy_dtype_dict [data .dtype ]
374
+ _torch_to_numpy_dtype_dict [data .dtype ]
354
375
)
355
376
else :
356
377
data_np = np .array (data )
0 commit comments