1
+ from typing import Self , Any
1
2
from enum import Enum
2
3
from torch_tensorrt ._C import (
3
4
_get_logging_prefix ,
@@ -22,19 +23,21 @@ class Level(Enum):
22
23
Graph = LogLevel .GRAPH
23
24
24
25
@staticmethod
25
- def _to_internal_level (external ) -> LogLevel :
26
+ def _to_internal_level (external : Self ) -> LogLevel :
26
27
if external == Level .InternalError :
27
28
return LogLevel .INTERNAL_ERROR
28
- if external == Level .Error :
29
+ elif external == Level .Error :
29
30
return LogLevel .ERROR
30
- if external == Level .Warning :
31
+ elif external == Level .Warning :
31
32
return LogLevel .WARNING
32
- if external == Level .Info :
33
+ elif external == Level .Info :
33
34
return LogLevel .INFO
34
- if external == Level .Debug :
35
+ elif external == Level .Debug :
35
36
return LogLevel .DEBUG
36
- if external == Level .Graph :
37
+ elif external == Level .Graph :
37
38
return LogLevel .GRAPH
39
+ else :
40
+ raise ValueError ("Unknown log severity" )
38
41
39
42
40
43
def get_logging_prefix () -> str :
@@ -43,10 +46,10 @@ def get_logging_prefix() -> str:
43
46
Returns:
44
47
str: Prefix used for logger
45
48
"""
46
- return _get_logging_prefix ()
49
+ return str ( _get_logging_prefix () )
47
50
48
51
49
- def set_logging_prefix (prefix : str ):
52
+ def set_logging_prefix (prefix : str ) -> None :
50
53
"""Set the prefix used when logging messages
51
54
52
55
Args:
@@ -64,7 +67,7 @@ def get_reportable_log_level() -> Level:
64
67
return Level (_get_reportable_log_level ())
65
68
66
69
67
- def set_reportable_log_level (level : Level ):
70
+ def set_reportable_log_level (level : Level ) -> None :
68
71
"""Set the level required for a message to be printed to the log
69
72
70
73
Args:
@@ -79,10 +82,10 @@ def get_is_colored_output_on() -> bool:
79
82
Returns:
80
83
bool: If colored output is one
81
84
"""
82
- return _get_is_colored_output_on ()
85
+ return bool ( _get_is_colored_output_on () )
83
86
84
87
85
- def set_is_colored_output_on (colored_output_on : bool ):
88
+ def set_is_colored_output_on (colored_output_on : bool ) -> None :
86
89
"""Enable or disable color in the log output
87
90
88
91
Args:
@@ -91,7 +94,7 @@ def set_is_colored_output_on(colored_output_on: bool):
91
94
_set_is_colored_output_on (colored_output_on )
92
95
93
96
94
- def log (level : Level , msg : str ):
97
+ def log (level : Level , msg : str ) -> None :
95
98
"""Add a new message to the log
96
99
97
100
Adds a new message to the log at a specified level. The message
@@ -120,11 +123,11 @@ class internal_errors:
120
123
outputs = model_torchtrt(inputs)
121
124
"""
122
125
123
- def __enter__ (self ):
126
+ def __enter__ (self ) -> Self :
124
127
self .external_lvl = get_reportable_log_level ()
125
128
set_reportable_log_level (Level .InternalError )
126
129
127
- def __exit__ (self , exc_type , exc_value , exc_tb ) :
130
+ def __exit__ (self , exc_type : Any , exc_value : Any , exc_tb : Any ) -> None :
128
131
set_reportable_log_level (self .external_lvl )
129
132
130
133
@@ -137,11 +140,11 @@ class errors:
137
140
outputs = model_torchtrt(inputs)
138
141
"""
139
142
140
- def __enter__ (self ):
143
+ def __enter__ (self ) -> Self :
141
144
self .external_lvl = get_reportable_log_level ()
142
145
set_reportable_log_level (Level .Error )
143
146
144
- def __exit__ (self , exc_type , exc_value , exc_tb ) :
147
+ def __exit__ (self , exc_type : Any , exc_value : Any , exc_tb : Any ) -> None :
145
148
set_reportable_log_level (self .external_lvl )
146
149
147
150
@@ -154,11 +157,11 @@ class warnings:
154
157
model_trt = torch_tensorrt.compile(model, **spec)
155
158
"""
156
159
157
- def __enter__ (self ):
160
+ def __enter__ (self ) -> Self :
158
161
self .external_lvl = get_reportable_log_level ()
159
162
set_reportable_log_level (Level .Warning )
160
163
161
- def __exit__ (self , exc_type , exc_value , exc_tb ) :
164
+ def __exit__ (self , exc_type : Any , exc_value : Any , exc_tb : Any ) -> None :
162
165
set_reportable_log_level (self .external_lvl )
163
166
164
167
@@ -171,11 +174,11 @@ class info:
171
174
model_trt = torch_tensorrt.compile(model, **spec)
172
175
"""
173
176
174
- def __enter__ (self ):
177
+ def __enter__ (self ) -> Self :
175
178
self .external_lvl = get_reportable_log_level ()
176
179
set_reportable_log_level (Level .Info )
177
180
178
- def __exit__ (self , exc_type , exc_value , exc_tb ) :
181
+ def __exit__ (self , exc_type : Any , exc_value : Any , exc_tb : Any ) -> None :
179
182
set_reportable_log_level (self .external_lvl )
180
183
181
184
@@ -188,11 +191,11 @@ class debug:
188
191
model_trt = torch_tensorrt.compile(model, **spec)
189
192
"""
190
193
191
- def __enter__ (self ):
194
+ def __enter__ (self ) -> Self :
192
195
self .external_lvl = get_reportable_log_level ()
193
196
set_reportable_log_level (Level .Debug )
194
197
195
- def __exit__ (self , exc_type , exc_value , exc_tb ) :
198
+ def __exit__ (self , exc_type : Any , exc_value : Any , exc_tb : Any ) -> None :
196
199
set_reportable_log_level (self .external_lvl )
197
200
198
201
@@ -206,9 +209,9 @@ class graphs:
206
209
model_trt = torch_tensorrt.compile(model, **spec)
207
210
"""
208
211
209
- def __enter__ (self ):
212
+ def __enter__ (self ) -> Self :
210
213
self .external_lvl = get_reportable_log_level ()
211
214
set_reportable_log_level (Level .Graph )
212
215
213
- def __exit__ (self , exc_type , exc_value , exc_tb ) :
216
+ def __exit__ (self , exc_type : Any , exc_value : Any , exc_tb : Any ) -> None :
214
217
set_reportable_log_level (self .external_lvl )
0 commit comments