_cuda_trace.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. import logging
  2. from typing import Callable, Generic, List
  3. from typing_extensions import ParamSpec # Python 3.10+
  4. logger = logging.getLogger(__name__)
  5. P = ParamSpec("P")
  6. class CallbackRegistry(Generic[P]):
  7. def __init__(self, name: str):
  8. self.name = name
  9. self.callback_list: List[Callable[P, None]] = []
  10. def add_callback(self, cb: Callable[P, None]) -> None:
  11. self.callback_list.append(cb)
  12. def fire_callbacks(self, *args: P.args, **kwargs: P.kwargs) -> None:
  13. for cb in self.callback_list:
  14. try:
  15. cb(*args, **kwargs)
  16. except Exception as e:
  17. logger.exception(
  18. f"Exception in callback for {self.name} registered with CUDA trace"
  19. )
  20. CUDAEventCreationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
  21. "CUDA event creation"
  22. )
  23. CUDAEventDeletionCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
  24. "CUDA event deletion"
  25. )
  26. CUDAEventRecordCallbacks: "CallbackRegistry[int, int]" = CallbackRegistry(
  27. "CUDA event record"
  28. )
  29. CUDAEventWaitCallbacks: "CallbackRegistry[int, int]" = CallbackRegistry(
  30. "CUDA event wait"
  31. )
  32. CUDAMemoryAllocationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
  33. "CUDA memory allocation"
  34. )
  35. CUDAMemoryDeallocationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
  36. "CUDA memory deallocation"
  37. )
  38. CUDAStreamCreationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
  39. "CUDA stream creation"
  40. )
  41. CUDADeviceSynchronizationCallbacks: "CallbackRegistry[[]]" = CallbackRegistry(
  42. "CUDA device synchronization"
  43. )
  44. CUDAStreamSynchronizationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
  45. "CUDA stream synchronization"
  46. )
  47. CUDAEventSynchronizationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
  48. "CUDA event synchronization"
  49. )
  50. def register_callback_for_cuda_event_creation(cb: Callable[[int], None]) -> None:
  51. CUDAEventCreationCallbacks.add_callback(cb)
  52. def register_callback_for_cuda_event_deletion(cb: Callable[[int], None]) -> None:
  53. CUDAEventDeletionCallbacks.add_callback(cb)
  54. def register_callback_for_cuda_event_record(cb: Callable[[int, int], None]) -> None:
  55. CUDAEventRecordCallbacks.add_callback(cb)
  56. def register_callback_for_cuda_event_wait(cb: Callable[[int, int], None]) -> None:
  57. CUDAEventWaitCallbacks.add_callback(cb)
  58. def register_callback_for_cuda_memory_allocation(cb: Callable[[int], None]) -> None:
  59. CUDAMemoryAllocationCallbacks.add_callback(cb)
  60. def register_callback_for_cuda_memory_deallocation(cb: Callable[[int], None]) -> None:
  61. CUDAMemoryDeallocationCallbacks.add_callback(cb)
  62. def register_callback_for_cuda_stream_creation(cb: Callable[[int], None]) -> None:
  63. CUDAStreamCreationCallbacks.add_callback(cb)
  64. def register_callback_for_cuda_device_synchronization(cb: Callable[[], None]) -> None:
  65. CUDADeviceSynchronizationCallbacks.add_callback(cb)
  66. def register_callback_for_cuda_stream_synchronization(
  67. cb: Callable[[int], None]
  68. ) -> None:
  69. CUDAStreamSynchronizationCallbacks.add_callback(cb)
  70. def register_callback_for_cuda_event_synchronization(cb: Callable[[int], None]) -> None:
  71. CUDAEventSynchronizationCallbacks.add_callback(cb)