__init__.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
#!/usr/bin/env python3
  2. # Copyright (c) Facebook, Inc. and its affiliates.
  3. # All rights reserved.
  4. #
  5. # This source code is licensed under the BSD-style license found in the
  6. # LICENSE file in the root directory of this source tree.
"""
This module contains event-processing mechanisms that are integrated with the
standard Python ``logging`` module.

Example of usage:

::

    from torch.distributed.elastic import events
    event = events.Event(name="test_event", source=events.EventSource.WORKER, metadata={...})
    events.record(event, destination="console")
"""
  15. import inspect
  16. import logging
  17. import os
  18. import socket
  19. import traceback
  20. from enum import Enum
  21. from typing import Dict, Optional
  22. from torch.distributed.elastic.events.handlers import get_logging_handler
  23. from .api import ( # noqa: F401
  24. Event,
  25. EventMetadataValue,
  26. EventSource,
  27. NodeState,
  28. RdzvEvent,
  29. )
  30. _events_loggers: Dict[str, logging.Logger] = {}
  31. def _get_or_create_logger(destination: str = "null") -> logging.Logger:
  32. """
  33. Constructs python logger based on the destination type or extends if provided.
  34. Available destination could be found in ``handlers.py`` file.
  35. The constructed logger does not propagate messages to the upper level loggers,
  36. e.g. root logger. This makes sure that a single event can be processed once.
  37. Args:
  38. destination: The string representation of the event handler.
  39. Available handlers found in ``handlers`` module
  40. """
  41. global _events_loggers
  42. if destination not in _events_loggers:
  43. _events_logger = logging.getLogger(f"torchelastic-events-{destination}")
  44. _events_logger.setLevel(os.environ.get("LOGLEVEL", "INFO"))
  45. # Do not propagate message to the root logger
  46. _events_logger.propagate = False
  47. logging_handler = get_logging_handler(destination)
  48. _events_logger.addHandler(logging_handler)
  49. # Add the logger to the global dictionary
  50. _events_loggers[destination] = _events_logger
  51. return _events_loggers[destination]
  52. def record(event: Event, destination: str = "null") -> None:
  53. _get_or_create_logger(destination).info(event.serialize())
  54. def record_rdzv_event(event: RdzvEvent) -> None:
  55. _get_or_create_logger("dynamic_rendezvous").info(event.serialize())
  56. def construct_and_record_rdzv_event(
  57. run_id: str,
  58. message: str,
  59. node_state: NodeState,
  60. name: str = "",
  61. hostname: str = "",
  62. pid: Optional[int] = None,
  63. master_endpoint: str = "",
  64. local_id: Optional[int] = None,
  65. rank: Optional[int] = None,
  66. ) -> None:
  67. # We don't want to perform an extra computation if not needed.
  68. if isinstance(get_logging_handler("dynamic_rendezvous"), logging.NullHandler):
  69. return
  70. # Set up parameters.
  71. if not hostname:
  72. hostname = socket.getfqdn()
  73. if not pid:
  74. pid = os.getpid()
  75. # Determines which file called this function.
  76. callstack = inspect.stack()
  77. filename = "no_file"
  78. if len(callstack) > 1:
  79. stack_depth_1 = callstack[1]
  80. filename = os.path.basename(stack_depth_1.filename)
  81. if not name:
  82. name = stack_depth_1.function
  83. # Delete the callstack variable. If kept, this can mess with python's
  84. # garbage collector as we are holding on to stack frame information in
  85. # the inspect module.
  86. del callstack
  87. # Set up error trace if this is an exception
  88. if node_state == NodeState.FAILED:
  89. error_trace = traceback.format_exc()
  90. else:
  91. error_trace = ""
  92. # Initialize event object
  93. event = RdzvEvent(
  94. name=f"{filename}:{name}",
  95. run_id=run_id,
  96. message=message,
  97. hostname=hostname,
  98. pid=pid,
  99. node_state=node_state,
  100. master_endpoint=master_endpoint,
  101. rank=rank,
  102. local_id=local_id,
  103. error_trace=error_trace,
  104. )
  105. # Finally, record the event.
  106. record_rdzv_event(event)