Skip to content

Commit b0ae48e

Browse files
committed
chore(//py/torch_tensorrt): conforming logging to mypy
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 80818a9 commit b0ae48e

File tree

1 file changed

+27
-24
lines changed

1 file changed

+27
-24
lines changed

py/torch_tensorrt/logging.py

Lines changed: 27 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from typing import Self, Any
12
from enum import Enum
23
from torch_tensorrt._C import (
34
_get_logging_prefix,
@@ -22,19 +23,21 @@ class Level(Enum):
2223
Graph = LogLevel.GRAPH
2324

2425
@staticmethod
25-
def _to_internal_level(external) -> LogLevel:
26+
def _to_internal_level(external: Self) -> LogLevel:
2627
if external == Level.InternalError:
2728
return LogLevel.INTERNAL_ERROR
28-
if external == Level.Error:
29+
elif external == Level.Error:
2930
return LogLevel.ERROR
30-
if external == Level.Warning:
31+
elif external == Level.Warning:
3132
return LogLevel.WARNING
32-
if external == Level.Info:
33+
elif external == Level.Info:
3334
return LogLevel.INFO
34-
if external == Level.Debug:
35+
elif external == Level.Debug:
3536
return LogLevel.DEBUG
36-
if external == Level.Graph:
37+
elif external == Level.Graph:
3738
return LogLevel.GRAPH
39+
else:
40+
raise ValueError("Unknown log severity")
3841

3942

4043
def get_logging_prefix() -> str:
@@ -43,10 +46,10 @@ def get_logging_prefix() -> str:
4346
Returns:
4447
str: Prefix used for logger
4548
"""
46-
return _get_logging_prefix()
49+
return str(_get_logging_prefix())
4750

4851

49-
def set_logging_prefix(prefix: str):
52+
def set_logging_prefix(prefix: str) -> None:
5053
"""Set the prefix used when logging messages
5154
5255
Args:
@@ -64,7 +67,7 @@ def get_reportable_log_level() -> Level:
6467
return Level(_get_reportable_log_level())
6568

6669

67-
def set_reportable_log_level(level: Level):
70+
def set_reportable_log_level(level: Level) -> None:
6871
"""Set the level required for a message to be printed to the log
6972
7073
Args:
@@ -79,10 +82,10 @@ def get_is_colored_output_on() -> bool:
7982
Returns:
8083
bool: If colored output is one
8184
"""
82-
return _get_is_colored_output_on()
85+
return bool(_get_is_colored_output_on())
8386

8487

85-
def set_is_colored_output_on(colored_output_on: bool):
88+
def set_is_colored_output_on(colored_output_on: bool) -> None:
8689
"""Enable or disable color in the log output
8790
8891
Args:
@@ -91,7 +94,7 @@ def set_is_colored_output_on(colored_output_on: bool):
9194
_set_is_colored_output_on(colored_output_on)
9295

9396

94-
def log(level: Level, msg: str):
97+
def log(level: Level, msg: str) -> None:
9598
"""Add a new message to the log
9699
97100
Adds a new message to the log at a specified level. The message
@@ -120,11 +123,11 @@ class internal_errors:
120123
outputs = model_torchtrt(inputs)
121124
"""
122125

123-
def __enter__(self):
126+
def __enter__(self) -> Self:
124127
self.external_lvl = get_reportable_log_level()
125128
set_reportable_log_level(Level.InternalError)
126129

127-
def __exit__(self, exc_type, exc_value, exc_tb):
130+
def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None:
128131
set_reportable_log_level(self.external_lvl)
129132

130133

@@ -137,11 +140,11 @@ class errors:
137140
outputs = model_torchtrt(inputs)
138141
"""
139142

140-
def __enter__(self):
143+
def __enter__(self) -> Self:
141144
self.external_lvl = get_reportable_log_level()
142145
set_reportable_log_level(Level.Error)
143146

144-
def __exit__(self, exc_type, exc_value, exc_tb):
147+
def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None:
145148
set_reportable_log_level(self.external_lvl)
146149

147150

@@ -154,11 +157,11 @@ class warnings:
154157
model_trt = torch_tensorrt.compile(model, **spec)
155158
"""
156159

157-
def __enter__(self):
160+
def __enter__(self) -> Self:
158161
self.external_lvl = get_reportable_log_level()
159162
set_reportable_log_level(Level.Warning)
160163

161-
def __exit__(self, exc_type, exc_value, exc_tb):
164+
def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None:
162165
set_reportable_log_level(self.external_lvl)
163166

164167

@@ -171,11 +174,11 @@ class info:
171174
model_trt = torch_tensorrt.compile(model, **spec)
172175
"""
173176

174-
def __enter__(self):
177+
def __enter__(self) -> Self:
175178
self.external_lvl = get_reportable_log_level()
176179
set_reportable_log_level(Level.Info)
177180

178-
def __exit__(self, exc_type, exc_value, exc_tb):
181+
def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None:
179182
set_reportable_log_level(self.external_lvl)
180183

181184

@@ -188,11 +191,11 @@ class debug:
188191
model_trt = torch_tensorrt.compile(model, **spec)
189192
"""
190193

191-
def __enter__(self):
194+
def __enter__(self) -> Self:
192195
self.external_lvl = get_reportable_log_level()
193196
set_reportable_log_level(Level.Debug)
194197

195-
def __exit__(self, exc_type, exc_value, exc_tb):
198+
def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None:
196199
set_reportable_log_level(self.external_lvl)
197200

198201

@@ -206,9 +209,9 @@ class graphs:
206209
model_trt = torch_tensorrt.compile(model, **spec)
207210
"""
208211

209-
def __enter__(self):
212+
def __enter__(self) -> Self:
210213
self.external_lvl = get_reportable_log_level()
211214
set_reportable_log_level(Level.Graph)
212215

213-
def __exit__(self, exc_type, exc_value, exc_tb):
216+
def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None:
214217
set_reportable_log_level(self.external_lvl)

0 commit comments

Comments
 (0)