Commit dc8647e

split profilers (#6261)
1 parent: efda48f

4 files changed: +295, -273 lines changed

pytorch_lightning/profiler/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -198,9 +198,9 @@ def custom_processing_step(self, data):
     AdvancedProfiler,
     BaseProfiler,
     PassThroughProfiler,
-    PyTorchProfiler,
     SimpleProfiler,
 )
+from pytorch_lightning.profiler.pytorch import PyTorchProfiler
 
 __all__ = [
     'BaseProfiler',
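
The user-facing import path is unchanged by this split: ``PyTorchProfiler`` is still re-exported from the ``pytorch_lightning.profiler`` package, it merely moves from ``profilers.py`` into the new ``pytorch.py`` module. A minimal sketch of what user code can look like after the split (assuming a pytorch-lightning build containing this commit):

    # Both imports resolve to the same class after the split.
    from pytorch_lightning.profiler import PyTorchProfiler
    from pytorch_lightning.profiler.pytorch import PyTorchProfiler  # new module path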

pytorch_lightning/profiler/profilers.py

Lines changed: 1 addition & 271 deletions
@@ -14,24 +14,19 @@
 """Profiler to check if there are any bottlenecks in your code."""
 
 import cProfile
-import inspect
 import io
 import os
 import pstats
 import time
 from abc import ABC, abstractmethod
 from collections import defaultdict
 from contextlib import contextmanager
-from typing import List, Optional, Union
+from typing import Optional, Union
 
 import numpy as np
-import torch
 
 from pytorch_lightning import _logger as log
-from pytorch_lightning.utilities import rank_zero_only
 from pytorch_lightning.utilities.cloud_io import get_filesystem
-from pytorch_lightning.utilities.distributed import rank_zero_warn
-from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 
 class BaseProfiler(ABC):
@@ -294,268 +289,3 @@ def __del__(self):
         """Close profiler's stream."""
         if self.output_file:
             self.output_file.close()
-
-
-class PyTorchProfiler(BaseProfiler):
-
-    PROFILED_FUNCTIONS = ("training_step_and_backward", "validation_step", "test_step")
-    AVAILABLE_SORT_KEYS = (
-        "cpu_time",
-        "cuda_time",
-        "cpu_time_total",
-        "cuda_time_total",
-        "cpu_memory_usage",
-        "cuda_memory_usage",
-        "self_cpu_memory_usage",
-        "self_cuda_memory_usage",
-        "count",
-    )
-
-    def __init__(
-        self,
-        output_filename: Optional[str] = None,
-        enabled: bool = True,
-        use_cuda: bool = False,
-        record_shapes: bool = False,
-        profile_memory: bool = False,
-        group_by_input_shapes: bool = False,
-        with_stack: bool = False,
-        use_kineto: bool = False,
-        use_cpu: bool = True,
-        emit_nvtx: bool = False,
-        export_to_chrome: bool = False,
-        path_to_export_trace: str = None,
-        row_limit: int = 20,
-        sort_by_key: Optional[str] = None,
-        profiled_functions: Optional[List] = None,
-        local_rank: Optional[int] = None,
-    ):
-        """
-        This profiler uses PyTorch's Autograd Profiler and lets you inspect the cost of
-        different operators inside your model, both on the CPU and GPU.
-
-        Args:
-
-            output_filename: optionally save profile results to file instead of printing
-                to std out when training is finished. When using ``ddp``,
-                each rank will stream the profiled operations to its own file
-                with the extension ``_{rank}.txt``
-
-            enabled: Setting this to False makes this context manager a no-op.
-
-            use_cuda: Enables timing of CUDA events as well, using the cudaEvent API.
-                Adds approximately 4us of overhead to each tensor operation.
-
-            record_shapes: If shapes recording is set, information about input dimensions will be collected.
-
-            profile_memory: Whether to report memory usage, default: False (Introduced in PyTorch 1.6.0)
-
-            group_by_input_shapes: Include operator input shapes and group calls by shape.
-
-            with_stack: record source information (file and line number) for the ops (Introduced in PyTorch 1.7.0)
-
-            use_kineto: experimental support for the Kineto profiler (Introduced in PyTorch 1.8.0)
-
-            use_cpu: used together with ``use_kineto=True``; setting it to ``False`` lowers the overhead
-                for GPU-only profiling (Introduced in PyTorch 1.8.0)
-
-            emit_nvtx: Context manager that makes every autograd operation emit an NVTX range.
-                Run::
-
-                    nvprof --profile-from-start off -o trace_name.prof -- <regular command here>
-
-                To visualize, you can either use::
-
-                    nvvp trace_name.prof
-                    torch.autograd.profiler.load_nvprof(path)
-
-            export_to_chrome: Whether to export the sequence of profiled operators for Chrome.
-                It will generate a ``.json`` file which can be read by Chrome.
-
-            path_to_export_trace: Directory path to export ``.json`` traces when using ``export_to_chrome=True``.
-                By default, they will be saved in the directory where the script is run.
-
-            row_limit: Limit the number of rows in a table, ``0`` is a special value that
-                removes the limit completely.
-
-            sort_by_key: Key used to sort the profiled table.
-
-            profiled_functions: list of function names to profile; a context manager is created
-                around each of them. Any other function is passed through.
-
-            local_rank: When running in distributed setting, local_rank is used for each process
-                to write to their own file if `output_fname` is provided.
-
-        Raises:
-            MisconfigurationException:
-                If arg ``sort_by_key`` is not present in ``AVAILABLE_SORT_KEYS``, or
-                if log file is not a ``.txt`` file.
-            ValueError:
-                If you attempt to stop recording an action which was never started.
-        """
-
-        self.profiled_actions = {}
-        self.enabled = enabled
-        self.profiled_functions = profiled_functions or self.PROFILED_FUNCTIONS
-        self.use_cuda = use_cuda
-        self.record_shapes = record_shapes
-        self.profile_memory = profile_memory
-        self.sort_by_key = sort_by_key or ("cuda_time_total" if self.use_cuda else "cpu_time_total")
-        self.with_stack = with_stack
-        self.group_by_input_shapes = group_by_input_shapes and record_shapes
-        self.use_kineto = use_kineto
-        self.use_cpu = use_cpu
-        self.row_limit = row_limit
-        self.emit_nvtx = emit_nvtx
-        self.export_to_chrome = export_to_chrome
-        self.path_to_export_trace = path_to_export_trace
-
-        if export_to_chrome and path_to_export_trace is None:
-            rank_zero_warn(
-                "The exported trace will be saved locally as `path_to_export_trace` is empty."
-                " Note: Each function will generate its own traced file."
-            )
-
-        if self.sort_by_key not in self.AVAILABLE_SORT_KEYS:
-            raise MisconfigurationException(
-                f"Found sort_by_key: {sort_by_key}. Should be within {self.AVAILABLE_SORT_KEYS}. "
-            )
-
-        self.profiled_actions = {}
-        self.context_names = {}
-        self.running_stack = []
-        self.profiler = None
-
-        self.output_fname = output_filename
-        self.output_file = None
-        if local_rank is not None:
-            self.on_train_start(local_rank=local_rank)
-            self.on_train_start = super().on_train_start
-
-    def on_train_start(self, local_rank: Optional[int] = None):
-        self.local_rank = local_rank
-
-        # when logging to `log.info`, only perform profiling on rank 0
-        if local_rank != 0 and self.output_fname is None:
-            self.wrap_functions_into_rank_zero_only()
-
-        if self.output_fname:
-            if local_rank is not None:
-                if '.txt' not in self.output_fname:
-                    raise MisconfigurationException("Log file should be .txt file.")
-
-                self.output_fname = self.output_fname.replace(".txt", f"_{self.local_rank}.txt")
-
-            fs = get_filesystem(self.output_fname)
-            self.output_file = fs.open(self.output_fname, "w")
-
-        streaming_out = [self.output_file.write] if self.output_file else [log.info]
-        super().__init__(output_streams=streaming_out)
-
-    def wrap_functions_into_rank_zero_only(self):
-        self.start = rank_zero_only(self.start)
-        self.stop = rank_zero_only(self.stop)
-        self.summary = rank_zero_only(self.summary)
-        self.describe = rank_zero_only(self.describe)
-
-    def start(self, action_name: str) -> None:
-        if action_name not in self.profiled_functions:
-            return
-
-        if len(self.running_stack) > 0:
-            self._stop(self.running_stack[-1])
-        self.running_stack.append(action_name)
-
-        self.context_names[action_name] = "/".join(self.running_stack)
-
-        self._start(action_name)
-
-    def _start(self, action_name: str) -> None:
-        if self.emit_nvtx:
-            self._create_profiler(action_name, torch.cuda.profiler.profile, enter=False)
-            self._create_profiler(action_name, torch.autograd.profiler.emit_nvtx)
-        else:
-            self._create_profiler(action_name, torch.autograd.profiler.profile)
-
-    def _create_profiler(self, action_name, profiler, enter=True):
-        init_args = inspect.signature(profiler.__init__).parameters
-        profiler_args = {k: v for k, v in vars(self).items() if k in init_args}
-        pr = profiler(**profiler_args)
-        if enter:
-            pr = pr.__enter__()
-        self.profiler = pr
-
-    def _stop(self, action_name: str) -> None:
-        if self.profiler is None:
-            return
-
-        self.profiler.__exit__(exc_type=None, exc_val=None, exc_tb=None)
-
-        function_events = self.profiler.function_events
-        self.profiler = None
-        for name in self.running_stack:
-            if name not in self.profiled_actions:
-                self.profiled_actions[name] = function_events
-            else:
-                self.profiled_actions[name] += function_events
-
-    def stop(self, action_name: str) -> None:
-        if action_name not in self.profiled_functions:
-            return
-
-        if len(self.running_stack) == 0 or self.running_stack[-1] != action_name:
-            raise ValueError(  # pragma: no-cover
-                f"Attempting to stop recording an action ({action_name}) which was never started."
-            )
-        self._stop(action_name)
-        self.running_stack.pop()
-        # restore running profiler
-        if len(self.running_stack) > 0:
-            self._start(self.running_stack[-1])
-
-    def summary(self) -> str:
-        recorded_stats = {}
-        output_string = ''
-        local_rank = '0' if self.local_rank is None else self.local_rank
-
-        if not self.enabled:
-            return output_string
-
-        for action_name, function_events in self.profiled_actions.items():
-
-            # next line is a workaround for a pytorch issue (fixed on master, still present
-            # on 1.7). Without it the code fails with `AssertionError: There is already a CPU
-            # parent event for detach`
-            function_events.populate_cpu_children = lambda: None
-
-            if self.export_to_chrome:
-                filename = f"{action_name}_{local_rank}_trace.json"
-                path_to_trace = filename if self.path_to_export_trace is None \
-                    else os.path.join(self.path_to_export_trace, filename)
-                function_events.export_chrome_trace(path_to_trace)
-
-            if self.emit_nvtx:
-                return output_string
-
-            else:
-                data = function_events.key_averages(group_by_input_shapes=self.group_by_input_shapes)
-                table = data.table(sort_by=self.sort_by_key, row_limit=self.row_limit)
-                recorded_stats[action_name] = table
-
-        # log to standard out
-        output_string = f"{os.linesep}Profiler Report{os.linesep}"
-        for action, stats in recorded_stats.items():
-            output_string += (f"{os.linesep}Profile stats for: {action} rank: {local_rank} {os.linesep}{stats}")
-
-        return output_string
-
-    def describe(self):
-        """Logs a profile report after the conclusion of the training run."""
-        super().describe()
-        if self.output_file:
-            self.output_file.flush()
-
-    def __del__(self):
-        """Close profiler's stream."""
-        if self.output_file:
-            self.output_file.close()
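
For context, here is a minimal usage sketch of the class this hunk removes (it now lives in ``pytorch_lightning/profiler/pytorch.py``). It assumes a pytorch-lightning version containing this commit; ``MyLightningModule`` is a placeholder for any user model:

    from pytorch_lightning import Trainer
    from pytorch_lightning.profiler import PyTorchProfiler

    # Profile the default hooks (training_step_and_backward, validation_step,
    # test_step), sort the report by total CPU time, and cap the table at 10 rows.
    profiler = PyTorchProfiler(
        output_filename="profiler_report.txt",  # with ddp, each rank writes _{rank}.txt
        sort_by_key="cpu_time_total",           # must be one of AVAILABLE_SORT_KEYS
        row_limit=10,
    )

    trainer = Trainer(profiler=profiler, max_epochs=1)
    # trainer.fit(MyLightningModule())  # placeholder LightningModule

As the removed ``_create_profiler`` shows, the class inspects the signature of the chosen ``torch.autograd.profiler`` context manager and forwards only the constructor arguments it accepts, which is how one set of ``__init__`` flags can drive ``profile``, ``emit_nvtx``, and ``torch.cuda.profiler.profile`` alike.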
