1
- from typing import List , Dict , Any
1
+ from typing import List , Dict , Any , Self , Optional
2
2
import torch
3
3
import os
4
4
@@ -16,15 +16,15 @@ class CalibrationAlgo(Enum):
16
16
MINMAX_CALIBRATION = _C .CalibrationAlgo .MINMAX_CALIBRATION
17
17
18
18
19
- def get_cache_mode_batch (self ) :
19
+ def get_cache_mode_batch (self : object ) -> None :
20
20
return None
21
21
22
22
23
- def get_batch_size (self ) :
23
+ def get_batch_size (self : object ) -> int :
24
24
return 1
25
25
26
26
27
- def get_batch (self , names ) :
27
+ def get_batch (self : object , _ : Any ) -> Optional [ List [ int ]] :
28
28
if self .current_batch_idx + self .batch_size > len (self .data_loader .dataset ):
29
29
return None
30
30
@@ -39,27 +39,30 @@ def get_batch(self, names):
39
39
return inputs_gpu
40
40
41
41
42
- def read_calibration_cache (self ) :
42
+ def read_calibration_cache (self : object ) -> bytes :
43
43
if self .cache_file and self .use_cache :
44
44
if os .path .exists (self .cache_file ):
45
45
with open (self .cache_file , "rb" ) as f :
46
- return f .read ()
46
+ b : bytes = f .read ()
47
+ return b
48
+ else :
49
+ raise FileNotFoundError (self .cache_file )
47
50
else :
48
51
return b""
49
52
50
53
51
- def write_calibration_cache (self , cache ) :
54
+ def write_calibration_cache (self : object , cache : bytes ) -> None :
52
55
if self .cache_file :
53
56
with open (self .cache_file , "wb" ) as f :
54
57
f .write (cache )
55
58
else :
56
- return b""
59
+ return
57
60
58
61
59
62
# deepcopy (which involves pickling) is performed on the compile_spec internally during compilation.
60
63
# We register this __reduce__ function for pickler to identity the calibrator object returned by DataLoaderCalibrator during deepcopy.
61
64
# This should be the object's local name relative to the module https://docs.python.org/3/library/pickle.html#object.__reduce__
62
- def __reduce__ (self ) :
65
+ def __reduce__ (self : object ) -> str :
63
66
return self .__class__ .__name__
64
67
65
68
@@ -75,10 +78,10 @@ class DataLoaderCalibrator(object):
75
78
device: device on which calibration data is copied to.
76
79
"""
77
80
78
- def __init__ (self , ** kwargs ):
81
+ def __init__ (self , ** kwargs : Any ):
79
82
pass
80
83
81
- def __new__ (cls , * args , ** kwargs ) :
84
+ def __new__ (cls : Self , * args : Any , ** kwargs : Any ) -> Self :
82
85
dataloader = args [0 ]
83
86
algo_type = kwargs .get ("algo_type" , CalibrationAlgo .ENTROPY_CALIBRATION_2 )
84
87
cache_file = kwargs .get ("cache_file" , None )
@@ -158,10 +161,10 @@ class CacheCalibrator(object):
158
161
algo_type: choice of calibration algorithm.
159
162
"""
160
163
161
- def __init__ (self , ** kwargs ):
164
+ def __init__ (self , ** kwargs : Any ):
162
165
pass
163
166
164
- def __new__ (cls , * args , ** kwargs ) :
167
+ def __new__ (cls : Self , * args : Any , ** kwargs : Any ) -> Self :
165
168
cache_file = args [0 ]
166
169
algo_type = kwargs .get ("algo_type" , CalibrationAlgo .ENTROPY_CALIBRATION_2 )
167
170
0 commit comments