#!/usr/bin/python3

import itertools
from typing import List

import torch
from torch.autograd.profiler_legacy import profile

from . import (
    _disable_server_process_global_profiler,
    _enable_server_process_global_profiler,
)


__all__: List[str] = []


class _server_process_global_profile(profile):
    """
    It has the same API as ``torch.autograd.profiler.profile`` class,
    except that it enables profiling on all threads running RPC server request callbacks.

    Context manager that manages autograd profiler state and holds a summary of results.
    Under the hood it just records events of functions being executed in C++ and
    exposes those events to Python. You can wrap any code into it and it will
    only report runtime of PyTorch functions.
    Note: profiler is thread local and is automatically propagated into the async tasks.

    Args:
        enabled (bool, optional): Setting this to False makes this context manager a no-op.
            Default: ``True``.

        use_cuda (bool, optional): Enables timing of CUDA events as well using the cudaEvent API.
            Adds approximately 4us of overhead to each tensor operation.
            Default: ``False``

        record_shapes (bool, optional): If shapes recording is set, information
            about input dimensions will be collected. This allows one to see which
            dimensions have been used under the hood and further group by them
            using prof.key_averages(group_by_input_shape=True). Please note that
            shape recording might skew your profiling data. It is recommended to
            use separate runs with and without shape recording to validate the timing.
            Most likely the skew will be negligible for bottom-most events (in the case
            of nested function calls). But for higher-level functions the total
            self CPU time might be artificially increased because of the shape
            collection.

        profile_memory (bool, optional): Whether to report memory usage, default: ``False``

    .. warning::
        Enabling memory profiling incurs additional profiler overhead.

    .. warning::
        Due to some CUDA multiprocessing limitations (multiprocessing-cuda-note_),
        one cannot use the profiler with ``use_cuda = True`` to benchmark
        DataLoaders with ``num_workers > 0``. If you wish to benchmark data loading,
        please use ``use_cuda = False`` or ``num_workers = 0``.

    Example:
        >>> # xdoctest: +SKIP
        >>> # On worker 0:
        >>> import torch
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker0", rank=0, world_size=2)
        >>> dst_worker_name = "worker1"
        >>> x, y = torch.tensor(1), torch.tensor(2)
        >>> outer_profile_rref = rpc.remote(dst_worker_name, rpc._server_process_global_profile)
        >>> outer_profile_rref.rpc_sync().__enter__()
        >>> rpc.rpc_sync(dst_worker_name, torch.add, (x, y))
        >>> inner_profile_rref = rpc.remote(dst_worker_name, rpc._server_process_global_profile)
        >>> inner_profile_rref.rpc_sync().__enter__()
        >>> rpc.rpc_sync(dst_worker_name, torch.sub, (x, y))
        >>> inner_profile_rref.rpc_sync().__exit__(None, None, None)
        >>> outer_profile_rref.rpc_sync().__exit__(None, None, None)
        >>> print(inner_profile_rref.rpc_sync().key_averages())
        ---------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------
        Name       Self CPU total % Self CPU total   CPU total %      CPU total        CPU time avg     Number of Calls
        ---------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------
        sub        85.06%           76.275us         100.00%          89.667us         89.667us         1
        empty      14.94%           13.392us         14.94%           13.392us         13.392us         1
        ---------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------
        Self CPU time total: 89.667us
        >>> print(outer_profile_rref.rpc_sync().key_averages())
        ---------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------
        Name       Self CPU total % Self CPU total   CPU total %      CPU total        CPU time avg     Number of Calls
        ---------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------
        sub        35.65%           76.275us         41.91%           89.667us         89.667us         1
        empty      12.67%           27.101us         12.67%           27.101us         13.551us         2
        add        51.68%           110.550us        58.09%           124.259us        124.259us        1
        ---------  ---------------  ---------------  ---------------  ---------------  ---------------  ---------------
        Self CPU time total: 213.926us
        >>> rpc.shutdown()

        >>> # On worker 1:
        >>> import torch.distributed.rpc as rpc
        >>> rpc.init_rpc("worker1", rank=1, world_size=2)
        >>> # wait for worker 0 to finish work, and then shutdown.
        >>> rpc.shutdown()
    """

    # No ``__init__`` override: construction is exactly that of the base
    # legacy ``profile`` class (the previous override only forwarded to super()).

    def __enter__(self):
        """
        Turn on server-side process-global profiling.

        This enables thread-local profiler on all RPC threads running server-side
        request callbacks.

        Returns:
            ``self`` when profiling is enabled; ``None`` when ``enabled=False``
            (the context manager is then a no-op).

        Raises:
            RuntimeError: if this profiler has already been entered.
        """
        if not self.enabled:
            return

        if self.entered:  # type: ignore[has-type]
            raise RuntimeError("autograd profiler traces are not reentrant")

        profiler_kind = (
            torch.autograd.ProfilerState.CUDA
            if self.use_cuda
            else torch.autograd.ProfilerState.CPU
        )
        # The three positional ``False`` flags after ``profile_memory`` are left
        # disabled; presumably with_stack / with_flops / with_modules — confirm
        # against the ``torch.autograd.ProfilerConfig`` signature in use.
        profiler_config = torch.autograd.ProfilerConfig(
            profiler_kind,
            self.record_shapes,
            self.profile_memory,
            False,
            False,
            False,
            torch.profiler._ExperimentalConfig(),
        )
        _enable_server_process_global_profiler(profiler_config)
        # Mark as entered only after enabling succeeded, so a failure while
        # building the config or enabling the profiler does not leave this
        # object stuck in the "entered" state and reject a retry as reentrant.
        self.entered = True
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Turn off server-side process-global profiling.

        Aggregate all profiling events recorded by RPC threads.

        These attributes are assigned on exiting context.

        Attributes:
            function_events (torch.autograd.profiler.EventList). It's a list that has
            helper methods, like 1) show record items in a pretty-print table.
            2) do averaging by grouping on keys. 3) and more.

            process_global_function_events
            (List[List[torch.autograd.profiler.FunctionEvent]]). It has one inner
            list per profiling result returned by the server-side profiler, i.e. one
            per RPC request handling within the profiling range. Each inner list
            holds that handling's ``FunctionEvent`` elements, sorted by start time.

        Returns:
            ``False`` so that exceptions raised inside the context propagate
            (``None`` when ``enabled=False``).
        """
        if not self.enabled:
            return

        process_global_events = _disable_server_process_global_profiler()

        # Every element in this list is a thread profiling result from an RPC request handling.
        process_global_function_events = []
        for thread_local_events in process_global_events:
            # Parse from ``Event``s to ``FunctionEvent``s.
            thread_local_function_events = (
                torch.autograd.profiler_legacy._parse_legacy_records(
                    thread_local_events
                )
            )
            # Order by start time; among events starting at the same time, the
            # one ending later (the enclosing interval) comes first.
            thread_local_function_events.sort(
                key=lambda function_event: (
                    function_event.time_range.start,
                    -function_event.time_range.end,
                )
            )
            process_global_function_events.append(thread_local_function_events)

        flattened_function_events = list(
            itertools.chain.from_iterable(process_global_function_events)
        )
        self.function_events = torch.autograd.profiler_util.EventList(
            flattened_function_events,
            use_cuda=self.use_cuda,
            profile_memory=self.profile_memory,
        )
        self.function_events._build_tree()

        self.process_global_function_events = process_global_function_events

        return False