Skip to content

Commit 491b933

Browse files
committed
chore(//py/torch_tensorrt): ptq mypy compliance
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent bd19b41 commit 491b933

File tree

1 file changed

+16
-13
lines changed

1 file changed

+16
-13
lines changed

py/torch_tensorrt/ptq.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List, Dict, Any
1+
from typing import List, Dict, Any, Self, Optional
22
import torch
33
import os
44

@@ -16,15 +16,15 @@ class CalibrationAlgo(Enum):
1616
MINMAX_CALIBRATION = _C.CalibrationAlgo.MINMAX_CALIBRATION
1717

1818

19-
def get_cache_mode_batch(self):
19+
def get_cache_mode_batch(self: object) -> None:
2020
return None
2121

2222

23-
def get_batch_size(self):
23+
def get_batch_size(self: object) -> int:
2424
return 1
2525

2626

27-
def get_batch(self, names):
27+
def get_batch(self: object, _: Any) -> Optional[List[int]]:
2828
if self.current_batch_idx + self.batch_size > len(self.data_loader.dataset):
2929
return None
3030

@@ -39,27 +39,30 @@ def get_batch(self, names):
3939
return inputs_gpu
4040

4141

42-
def read_calibration_cache(self):
42+
def read_calibration_cache(self: object) -> bytes:
4343
if self.cache_file and self.use_cache:
4444
if os.path.exists(self.cache_file):
4545
with open(self.cache_file, "rb") as f:
46-
return f.read()
46+
b: bytes = f.read()
47+
return b
48+
else:
49+
raise FileNotFoundError(self.cache_file)
4750
else:
4851
return b""
4952

5053

51-
def write_calibration_cache(self, cache):
54+
def write_calibration_cache(self: object, cache: bytes) -> None:
5255
if self.cache_file:
5356
with open(self.cache_file, "wb") as f:
5457
f.write(cache)
5558
else:
56-
return b""
59+
return
5760

5861

5962
# deepcopy (which involves pickling) is performed on the compile_spec internally during compilation.
6063
# We register this __reduce__ function for pickler to identity the calibrator object returned by DataLoaderCalibrator during deepcopy.
6164
# This should be the object's local name relative to the module https://docs.python.org/3/library/pickle.html#object.__reduce__
62-
def __reduce__(self):
65+
def __reduce__(self: object) -> str:
6366
return self.__class__.__name__
6467

6568

@@ -75,10 +78,10 @@ class DataLoaderCalibrator(object):
7578
device: device on which calibration data is copied to.
7679
"""
7780

78-
def __init__(self, **kwargs):
81+
def __init__(self, **kwargs: Any):
7982
pass
8083

81-
def __new__(cls, *args, **kwargs):
84+
def __new__(cls: Self, *args: Any, **kwargs: Any) -> Self:
8285
dataloader = args[0]
8386
algo_type = kwargs.get("algo_type", CalibrationAlgo.ENTROPY_CALIBRATION_2)
8487
cache_file = kwargs.get("cache_file", None)
@@ -158,10 +161,10 @@ class CacheCalibrator(object):
158161
algo_type: choice of calibration algorithm.
159162
"""
160163

161-
def __init__(self, **kwargs):
164+
def __init__(self, **kwargs: Any):
162165
pass
163166

164-
def __new__(cls, *args, **kwargs):
167+
def __new__(cls: Self, *args: Any, **kwargs: Any) -> Self:
165168
cache_file = args[0]
166169
algo_type = kwargs.get("algo_type", CalibrationAlgo.ENTROPY_CALIBRATION_2)
167170

0 commit comments

Comments
 (0)