Skip to content

Commit 11dc516

Browse files
committed
refactor(//py): Give better names to the CtxManagers
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 12e470f commit 11dc516

File tree

2 files changed

+49
-7
lines changed

2 files changed

+49
-7
lines changed

py/torch_tensorrt/logging.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -107,51 +107,50 @@ def log(level: Level, msg: str):
107107
Debug = LogLevel.DEBUG
108108
Graph = LogLevel.GRAPH
109109

110-
class InternalErrors:
110+
class internal_errors:
111111
def __enter__(self):
112112
self.external_lvl = get_reportable_log_level()
113113
set_reportable_log_level(Level.InternalError)
114114

115115
def __exit__(self, exc_type, exc_value, exc_tb):
116116
set_reportable_log_level(self.external_lvl)
117117

118-
class Errors:
118+
class errors:
119119
def __enter__(self):
120120
self.external_lvl = get_reportable_log_level()
121121
set_reportable_log_level(Level.Error)
122122

123123
def __exit__(self, exc_type, exc_value, exc_tb):
124124
set_reportable_log_level(self.external_lvl)
125125

126-
class Warnings:
126+
class warnings:
127127
def __enter__(self):
128128
self.external_lvl = get_reportable_log_level()
129129
set_reportable_log_level(Level.Warning)
130130

131131
def __exit__(self, exc_type, exc_value, exc_tb):
132132
set_reportable_log_level(self.external_lvl)
133133

134-
class Info:
134+
class info:
135135
def __enter__(self):
136136
self.external_lvl = get_reportable_log_level()
137137
set_reportable_log_level(Level.Info)
138138

139139
def __exit__(self, exc_type, exc_value, exc_tb):
140140
set_reportable_log_level(self.external_lvl)
141141

142-
class Debug:
142+
class debug:
143143
def __enter__(self):
144144
self.external_lvl = get_reportable_log_level()
145145
set_reportable_log_level(Level.Debug)
146146

147147
def __exit__(self, exc_type, exc_value, exc_tb):
148148
set_reportable_log_level(self.external_lvl)
149149

150-
class Graphs:
150+
class graphs:
151151
def __enter__(self):
152152
self.external_lvl = get_reportable_log_level()
153153
set_reportable_log_level(Level.Graph)
154154

155155
def __exit__(self, exc_type, exc_value, exc_tb):
156156
set_reportable_log_level(self.external_lvl)
157-

tests/py/test_api.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,49 @@ def test_is_colored_output_on(self):
323323
color = torchtrt.logging.get_is_colored_output_on()
324324
self.assertTrue(color)
325325

326+
def test_context_managers(self):
327+
base_lvl = torchtrt.logging.get_reportable_log_level()
328+
with torchtrt.logging.internal_errors():
329+
lvl = torchtrt.logging.get_reportable_log_level()
330+
self.assertEqual(torchtrt.logging.Level.InternalError, lvl)
331+
332+
lvl = torchtrt.logging.get_reportable_log_level()
333+
self.assertEqual(base_lvl, lvl)
334+
335+
with torchtrt.logging.errors():
336+
lvl = torchtrt.logging.get_reportable_log_level()
337+
self.assertEqual(torchtrt.logging.Level.Error, lvl)
338+
339+
lvl = torchtrt.logging.get_reportable_log_level()
340+
self.assertEqual(base_lvl, lvl)
341+
342+
with torchtrt.logging.warnings():
343+
lvl = torchtrt.logging.get_reportable_log_level()
344+
self.assertEqual(torchtrt.logging.Level.Warning, lvl)
345+
346+
lvl = torchtrt.logging.get_reportable_log_level()
347+
self.assertEqual(base_lvl, lvl)
348+
349+
with torchtrt.logging.info():
350+
lvl = torchtrt.logging.get_reportable_log_level()
351+
self.assertEqual(torchtrt.logging.Level.Info, lvl)
352+
353+
lvl = torchtrt.logging.get_reportable_log_level()
354+
self.assertEqual(base_lvl, lvl)
355+
356+
with torchtrt.logging.debug():
357+
lvl = torchtrt.logging.get_reportable_log_level()
358+
self.assertEqual(torchtrt.logging.Level.Debug, lvl)
359+
360+
lvl = torchtrt.logging.get_reportable_log_level()
361+
self.assertEqual(base_lvl, lvl)
362+
363+
with torchtrt.logging.graphs():
364+
lvl = torchtrt.logging.get_reportable_log_level()
365+
self.assertEqual(torchtrt.logging.Level.Graph, lvl)
366+
367+
lvl = torchtrt.logging.get_reportable_log_level()
368+
self.assertEqual(base_lvl, lvl)
326369

327370
class TestDevice(unittest.TestCase):
328371

0 commit comments

Comments
 (0)