|
1 | 1 | from enum import Enum
|
2 |
| -from hashlib import new |
3 |
| -from imp import new_module |
4 | 2 | from torch_tensorrt._C import _get_logging_prefix, _set_logging_prefix, \
|
5 | 3 | _get_reportable_log_level, _set_reportable_log_level, \
|
6 | 4 | _get_is_colored_output_on, _set_is_colored_output_on, \
|
@@ -99,55 +97,109 @@ def log(level: Level, msg: str):
|
99 | 97 | """
|
100 | 98 | _log(Level._to_internal_level(level), msg)
|
101 | 99 |
|
102 |
| - |
103 | 100 | InternalError = LogLevel.INTERNAL_ERROR
|
104 | 101 | Error = LogLevel.ERROR
|
105 | 102 | Warning = LogLevel.WARNING
|
106 | 103 | Info = LogLevel.INFO
|
107 | 104 | Debug = LogLevel.DEBUG
|
108 | 105 | Graph = LogLevel.GRAPH
|
109 | 106 |
|
| 107 | + |
110 | 108 | class internal_errors:
|
| 109 | + """Context-manager to limit displayed log messages to just internal errors |
| 110 | +
|
| 111 | + Example:: |
| 112 | +
|
| 113 | + with torch_tensorrt.logging.internal_errors(): |
| 114 | + outputs = model_torchtrt(inputs) |
| 115 | + """ |
| 116 | + |
111 | 117 | def __enter__(self):
|
112 | 118 | self.external_lvl = get_reportable_log_level()
|
113 | 119 | set_reportable_log_level(Level.InternalError)
|
114 | 120 |
|
115 | 121 | def __exit__(self, exc_type, exc_value, exc_tb):
|
116 | 122 | set_reportable_log_level(self.external_lvl)
|
117 | 123 |
|
| 124 | + |
118 | 125 | class errors:
|
| 126 | + """Context-manager to limit displayed log messages to just errors and above |
| 127 | +
|
| 128 | + Example:: |
| 129 | +
|
| 130 | + with torch_tensorrt.logging.errors(): |
| 131 | + outputs = model_torchtrt(inputs) |
| 132 | + """ |
| 133 | + |
119 | 134 | def __enter__(self):
|
120 | 135 | self.external_lvl = get_reportable_log_level()
|
121 | 136 | set_reportable_log_level(Level.Error)
|
122 | 137 |
|
123 | 138 | def __exit__(self, exc_type, exc_value, exc_tb):
|
124 | 139 | set_reportable_log_level(self.external_lvl)
|
125 | 140 |
|
| 141 | + |
126 | 142 | class warnings:
|
| 143 | + """Context-manager to limit displayed log messages to just warnings and above |
| 144 | +
|
| 145 | + Example:: |
| 146 | +
|
| 147 | + with torch_tensorrt.logging.warnings(): |
| 148 | + model_trt = torch_tensorrt.compile(model, **spec) |
| 149 | + """ |
| 150 | + |
127 | 151 | def __enter__(self):
|
128 | 152 | self.external_lvl = get_reportable_log_level()
|
129 | 153 | set_reportable_log_level(Level.Warning)
|
130 | 154 |
|
131 | 155 | def __exit__(self, exc_type, exc_value, exc_tb):
|
132 | 156 | set_reportable_log_level(self.external_lvl)
|
133 | 157 |
|
| 158 | + |
134 | 159 | class info:
|
| 160 | + """Context-manager to display all info and greater severity messages |
| 161 | +
|
| 162 | + Example:: |
| 163 | +
|
| 164 | + with torch_tensorrt.logging.info(): |
| 165 | + model_trt = torch_tensorrt.compile(model, **spec) |
| 166 | + """ |
| 167 | + |
135 | 168 | def __enter__(self):
|
136 | 169 | self.external_lvl = get_reportable_log_level()
|
137 | 170 | set_reportable_log_level(Level.Info)
|
138 | 171 |
|
139 | 172 | def __exit__(self, exc_type, exc_value, exc_tb):
|
140 | 173 | set_reportable_log_level(self.external_lvl)
|
141 | 174 |
|
| 175 | + |
142 | 176 | class debug:
|
| 177 | + """Context-manager to display full debug information through the logger |
| 178 | +
|
| 179 | + Example:: |
| 180 | +
|
| 181 | + with torch_tensorrt.logging.debug(): |
| 182 | + model_trt = torch_tensorrt.compile(model, **spec) |
| 183 | + """ |
| 184 | + |
143 | 185 | def __enter__(self):
|
144 | 186 | self.external_lvl = get_reportable_log_level()
|
145 | 187 | set_reportable_log_level(Level.Debug)
|
146 | 188 |
|
147 | 189 | def __exit__(self, exc_type, exc_value, exc_tb):
|
148 | 190 | set_reportable_log_level(self.external_lvl)
|
149 | 191 |
|
| 192 | + |
150 | 193 | class graphs:
|
| 194 | + """Context-manager to display the results of intermediate lowering passes |
| 195 | + as well as full debug information through the logger |
| 196 | +
|
| 197 | + Example:: |
| 198 | +
|
| 199 | + with torch_tensorrt.logging.graphs(): |
| 200 | + model_trt = torch_tensorrt.compile(model, **spec) |
| 201 | + """ |
| 202 | + |
151 | 203 | def __enter__(self):
|
152 | 204 | self.external_lvl = get_reportable_log_level()
|
153 | 205 | set_reportable_log_level(Level.Graph)
|
|
0 commit comments