Skip to content

Commit ba9f730

Browse files
authored
Merge pull request #905 from NVIDIA/feat_865_2
Feat 865
2 parents 4839b11 + 3e44ee5 commit ba9f730

File tree

2 files changed

+154
-0
lines changed

2 files changed

+154
-0
lines changed

py/torch_tensorrt/logging.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,3 +96,113 @@ def log(level: Level, msg: str):
9696
msg (str): Actual message text
9797
"""
9898
_log(Level._to_internal_level(level), msg)
99+
100+
InternalError = LogLevel.INTERNAL_ERROR
101+
Error = LogLevel.ERROR
102+
Warning = LogLevel.WARNING
103+
Info = LogLevel.INFO
104+
Debug = LogLevel.DEBUG
105+
Graph = LogLevel.GRAPH
106+
107+
108+
class internal_errors:
109+
"""Context-manager to limit displayed log messages to just internal errors
110+
111+
Example::
112+
113+
with torch_tensorrt.logging.internal_errors():
114+
outputs = model_torchtrt(inputs)
115+
"""
116+
117+
def __enter__(self):
118+
self.external_lvl = get_reportable_log_level()
119+
set_reportable_log_level(Level.InternalError)
120+
121+
def __exit__(self, exc_type, exc_value, exc_tb):
122+
set_reportable_log_level(self.external_lvl)
123+
124+
125+
class errors:
126+
"""Context-manager to limit displayed log messages to just errors and above
127+
128+
Example::
129+
130+
with torch_tensorrt.logging.errors():
131+
outputs = model_torchtrt(inputs)
132+
"""
133+
134+
def __enter__(self):
135+
self.external_lvl = get_reportable_log_level()
136+
set_reportable_log_level(Level.Error)
137+
138+
def __exit__(self, exc_type, exc_value, exc_tb):
139+
set_reportable_log_level(self.external_lvl)
140+
141+
142+
class warnings:
143+
"""Context-manager to limit displayed log messages to just warnings and above
144+
145+
Example::
146+
147+
with torch_tensorrt.logging.warnings():
148+
model_trt = torch_tensorrt.compile(model, **spec)
149+
"""
150+
151+
def __enter__(self):
152+
self.external_lvl = get_reportable_log_level()
153+
set_reportable_log_level(Level.Warning)
154+
155+
def __exit__(self, exc_type, exc_value, exc_tb):
156+
set_reportable_log_level(self.external_lvl)
157+
158+
159+
class info:
160+
"""Context-manager to display all info and greater severity messages
161+
162+
Example::
163+
164+
with torch_tensorrt.logging.info():
165+
model_trt = torch_tensorrt.compile(model, **spec)
166+
"""
167+
168+
def __enter__(self):
169+
self.external_lvl = get_reportable_log_level()
170+
set_reportable_log_level(Level.Info)
171+
172+
def __exit__(self, exc_type, exc_value, exc_tb):
173+
set_reportable_log_level(self.external_lvl)
174+
175+
176+
class debug:
177+
"""Context-manager to display full debug information through the logger
178+
179+
Example::
180+
181+
with torch_tensorrt.logging.debug():
182+
model_trt = torch_tensorrt.compile(model, **spec)
183+
"""
184+
185+
def __enter__(self):
186+
self.external_lvl = get_reportable_log_level()
187+
set_reportable_log_level(Level.Debug)
188+
189+
def __exit__(self, exc_type, exc_value, exc_tb):
190+
set_reportable_log_level(self.external_lvl)
191+
192+
193+
class graphs:
194+
"""Context-manager to display the results of intermediate lowering passes
195+
as well as full debug information through the logger
196+
197+
Example::
198+
199+
with torch_tensorrt.logging.graphs():
200+
model_trt = torch_tensorrt.compile(model, **spec)
201+
"""
202+
203+
def __enter__(self):
204+
self.external_lvl = get_reportable_log_level()
205+
set_reportable_log_level(Level.Graph)
206+
207+
def __exit__(self, exc_type, exc_value, exc_tb):
208+
set_reportable_log_level(self.external_lvl)

tests/py/test_api.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,50 @@ def test_is_colored_output_on(self):
323323
color = torchtrt.logging.get_is_colored_output_on()
324324
self.assertTrue(color)
325325

326+
def test_context_managers(self):
327+
base_lvl = torchtrt.logging.get_reportable_log_level()
328+
with torchtrt.logging.internal_errors():
329+
lvl = torchtrt.logging.get_reportable_log_level()
330+
self.assertEqual(torchtrt.logging.Level.InternalError, lvl)
331+
332+
lvl = torchtrt.logging.get_reportable_log_level()
333+
self.assertEqual(base_lvl, lvl)
334+
335+
with torchtrt.logging.errors():
336+
lvl = torchtrt.logging.get_reportable_log_level()
337+
self.assertEqual(torchtrt.logging.Level.Error, lvl)
338+
339+
lvl = torchtrt.logging.get_reportable_log_level()
340+
self.assertEqual(base_lvl, lvl)
341+
342+
with torchtrt.logging.warnings():
343+
lvl = torchtrt.logging.get_reportable_log_level()
344+
self.assertEqual(torchtrt.logging.Level.Warning, lvl)
345+
346+
lvl = torchtrt.logging.get_reportable_log_level()
347+
self.assertEqual(base_lvl, lvl)
348+
349+
with torchtrt.logging.info():
350+
lvl = torchtrt.logging.get_reportable_log_level()
351+
self.assertEqual(torchtrt.logging.Level.Info, lvl)
352+
353+
lvl = torchtrt.logging.get_reportable_log_level()
354+
self.assertEqual(base_lvl, lvl)
355+
356+
with torchtrt.logging.debug():
357+
lvl = torchtrt.logging.get_reportable_log_level()
358+
self.assertEqual(torchtrt.logging.Level.Debug, lvl)
359+
360+
lvl = torchtrt.logging.get_reportable_log_level()
361+
self.assertEqual(base_lvl, lvl)
362+
363+
with torchtrt.logging.graphs():
364+
lvl = torchtrt.logging.get_reportable_log_level()
365+
self.assertEqual(torchtrt.logging.Level.Graph, lvl)
366+
367+
lvl = torchtrt.logging.get_reportable_log_level()
368+
self.assertEqual(base_lvl, lvl)
369+
326370

327371
class TestDevice(unittest.TestCase):
328372

0 commit comments

Comments
 (0)