123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899 |
- import logging
- from typing import Callable, Generic, List
- from typing_extensions import ParamSpec # Python 3.10+
- logger = logging.getLogger(__name__)
- P = ParamSpec("P")
class CallbackRegistry(Generic[P]):
    """An ordered, named collection of callbacks that are fired together.

    Every callback registered on a registry must accept the argument list
    that the registry's ``ParamSpec`` ``P`` is bound to. Exceptions raised by
    individual callbacks are logged and suppressed, so one failing callback
    does not prevent the remaining callbacks from running.
    """

    def __init__(self, name: str):
        # Human-readable label; used to identify this registry in log output.
        self.name = name
        # Callbacks in registration order; fire_callbacks runs them in order.
        self.callback_list: List[Callable[P, None]] = []

    def add_callback(self, cb: Callable[P, None]) -> None:
        """Append *cb* to this registry.

        Raises:
            TypeError: if *cb* is not callable. Checking here surfaces the
                mistake at registration time instead of as a logged error on
                every later ``fire_callbacks`` call.
        """
        if not callable(cb):
            raise TypeError(
                f"callback for {self.name} must be callable, "
                f"got {type(cb).__name__}"
            )
        self.callback_list.append(cb)

    def fire_callbacks(self, *args: P.args, **kwargs: P.kwargs) -> None:
        """Invoke every registered callback with ``*args`` and ``**kwargs``.

        Callbacks run in registration order. An exception raised by a
        callback is logged with its traceback and swallowed; the remaining
        callbacks still run and nothing propagates to the caller.
        """
        # Iterate over a snapshot: a callback that registers another callback
        # must not make the new one run (or loop forever) in this same firing.
        for cb in tuple(self.callback_list):
            try:
                cb(*args, **kwargs)
            except Exception:
                # Deliberately best-effort. Lazy %-args defer formatting to the
                # logging framework; the emitted text is unchanged.
                logger.exception(
                    "Exception in callback for %s registered with CUDA trace",
                    self.name,
                )
# Registries for CUDA event lifecycle notifications. Per the annotations,
# creation/deletion callbacks take one int and record/wait callbacks take two.
CUDAEventCreationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
    name="CUDA event creation"
)
CUDAEventDeletionCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
    name="CUDA event deletion"
)
CUDAEventRecordCallbacks: "CallbackRegistry[int, int]" = CallbackRegistry(
    name="CUDA event record"
)
CUDAEventWaitCallbacks: "CallbackRegistry[int, int]" = CallbackRegistry(
    name="CUDA event wait"
)
# Registries for memory allocation/deallocation and stream creation; each
# one's callbacks take a single int argument (per the annotations).
CUDAMemoryAllocationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
    name="CUDA memory allocation"
)
CUDAMemoryDeallocationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
    name="CUDA memory deallocation"
)
CUDAStreamCreationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
    name="CUDA stream creation"
)
# Registries for synchronization notifications. Device-level callbacks take
# no arguments ("CallbackRegistry[[]]"); stream- and event-level take one int.
CUDADeviceSynchronizationCallbacks: "CallbackRegistry[[]]" = CallbackRegistry(
    name="CUDA device synchronization"
)
CUDAStreamSynchronizationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
    name="CUDA stream synchronization"
)
CUDAEventSynchronizationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
    name="CUDA event synchronization"
)
def register_callback_for_cuda_event_creation(
    cb: Callable[[int], None]
) -> None:
    """Subscribe *cb* to the CUDA event creation registry.

    *cb* is appended to ``CUDAEventCreationCallbacks`` and is called with a
    single int each time that registry fires (presumably an event handle —
    confirm against the code that fires it).
    """
    registry = CUDAEventCreationCallbacks
    registry.add_callback(cb)
def register_callback_for_cuda_event_deletion(
    cb: Callable[[int], None]
) -> None:
    """Subscribe *cb* to the CUDA event deletion registry.

    *cb* is appended to ``CUDAEventDeletionCallbacks`` and is called with a
    single int each time that registry fires.
    """
    registry = CUDAEventDeletionCallbacks
    registry.add_callback(cb)
def register_callback_for_cuda_event_record(
    cb: Callable[[int, int], None]
) -> None:
    """Subscribe *cb* to the CUDA event record registry.

    *cb* is appended to ``CUDAEventRecordCallbacks`` and is called with two
    ints each time that registry fires (presumably event and stream handles —
    confirm against the code that fires it).
    """
    registry = CUDAEventRecordCallbacks
    registry.add_callback(cb)
def register_callback_for_cuda_event_wait(
    cb: Callable[[int, int], None]
) -> None:
    """Subscribe *cb* to the CUDA event wait registry.

    *cb* is appended to ``CUDAEventWaitCallbacks`` and is called with two ints
    each time that registry fires.
    """
    registry = CUDAEventWaitCallbacks
    registry.add_callback(cb)
def register_callback_for_cuda_memory_allocation(
    cb: Callable[[int], None]
) -> None:
    """Subscribe *cb* to the CUDA memory allocation registry.

    *cb* is appended to ``CUDAMemoryAllocationCallbacks`` and is called with a
    single int each time that registry fires.
    """
    registry = CUDAMemoryAllocationCallbacks
    registry.add_callback(cb)
def register_callback_for_cuda_memory_deallocation(
    cb: Callable[[int], None]
) -> None:
    """Subscribe *cb* to the CUDA memory deallocation registry.

    *cb* is appended to ``CUDAMemoryDeallocationCallbacks`` and is called with
    a single int each time that registry fires.
    """
    registry = CUDAMemoryDeallocationCallbacks
    registry.add_callback(cb)
def register_callback_for_cuda_stream_creation(
    cb: Callable[[int], None]
) -> None:
    """Subscribe *cb* to the CUDA stream creation registry.

    *cb* is appended to ``CUDAStreamCreationCallbacks`` and is called with a
    single int each time that registry fires.
    """
    registry = CUDAStreamCreationCallbacks
    registry.add_callback(cb)
def register_callback_for_cuda_device_synchronization(
    cb: Callable[[], None]
) -> None:
    """Subscribe *cb* to the CUDA device synchronization registry.

    *cb* is appended to ``CUDADeviceSynchronizationCallbacks`` and is called
    with no arguments each time that registry fires.
    """
    registry = CUDADeviceSynchronizationCallbacks
    registry.add_callback(cb)
def register_callback_for_cuda_stream_synchronization(
    cb: Callable[[int], None]
) -> None:
    """Subscribe *cb* to the CUDA stream synchronization registry.

    *cb* is appended to ``CUDAStreamSynchronizationCallbacks`` and is called
    with a single int each time that registry fires.
    """
    registry = CUDAStreamSynchronizationCallbacks
    registry.add_callback(cb)
def register_callback_for_cuda_event_synchronization(
    cb: Callable[[int], None]
) -> None:
    """Subscribe *cb* to the CUDA event synchronization registry.

    *cb* is appended to ``CUDAEventSynchronizationCallbacks`` and is called
    with a single int each time that registry fires.
    """
    registry = CUDAEventSynchronizationCallbacks
    registry.add_callback(cb)
|