logging.py

import itertools
import logging
import logging.config  # needed for logging.config.dictConfig in init_logging below
import os

from torch.hub import _Faketqdm, tqdm

# Logging level for dynamo-generated graphs/bytecode/guards
logging.CODE = 15
logging.addLevelName(logging.CODE, "CODE")
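
# Minimal usage sketch (illustrative only, not part of this module): CODE sits
# between DEBUG (10) and INFO (20), so a logger set to logging.CODE surfaces
# generated-code records while filtering ordinary DEBUG noise, e.g.:
#   logging.getLogger("torch._dynamo").log(logging.CODE, "generated bytecode ...")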

# Disable the progress bar by default; this is not in the dynamo config
# because putting it there would create a circular import.
disable_progress = True


# Return all loggers that torchdynamo/torchinductor is responsible for
def get_loggers():
    return [
        logging.getLogger("torch._dynamo"),
        logging.getLogger("torch._inductor"),
    ]


# Set the level of all loggers that torchdynamo is responsible for
def set_loggers_level(level):
    """Write current log level"""
    for logger in get_loggers():
        logger.setLevel(level)


def get_loggers_level():
    """Read current log level"""
    return get_loggers()[0].level
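
# Usage sketch (illustrative; `previous` is a hypothetical local name):
#   previous = get_loggers_level()
#   set_loggers_level(logging.DEBUG)
#   ...  # run the workload under inspection
#   set_loggers_level(previous)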


LOGGING_CONFIG = {
    "version": 1,
    "formatters": {
        "torchdynamo_format": {
            "format": "[%(asctime)s] %(name)s: [%(levelname)s] %(message)s"
        },
    },
    "handlers": {
        "torchdynamo_console": {
            "class": "logging.StreamHandler",
            "level": "DEBUG",
            "formatter": "torchdynamo_format",
            "stream": "ext://sys.stderr",
        },
    },
    "loggers": {
        "torch._dynamo": {
            "level": "DEBUG",
            "handlers": ["torchdynamo_console"],
            "propagate": False,
        },
        "torch._inductor": {
            "level": "DEBUG",
            "handlers": ["torchdynamo_console"],
            "propagate": False,
        },
    },
    "disable_existing_loggers": False,
}


# Initialize torchdynamo loggers
def init_logging(log_level, log_file_name=None):
    if "PYTEST_CURRENT_TEST" not in os.environ:
        logging.config.dictConfig(LOGGING_CONFIG)
        if log_file_name is not None:
            log_file = logging.FileHandler(log_file_name)
            log_file.setLevel(log_level)
            for logger in get_loggers():
                logger.addHandler(log_file)

        # Note: any non-empty value of TORCH_COMPILE_DEBUG (even "0") is truthy here.
        if bool(os.environ.get("TORCH_COMPILE_DEBUG", False)):
            from .utils import get_debug_dir

            log_level = logging.DEBUG
            log_path = os.path.join(get_debug_dir(), "torchdynamo")
            if not os.path.exists(log_path):
                os.makedirs(log_path)

            log_file = logging.FileHandler(os.path.join(log_path, "debug.log"))
            log_file.setLevel(logging.DEBUG)
            logger = logging.getLogger("torch._dynamo")
            logger.addHandler(log_file)

        set_loggers_level(log_level)
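
# Call-site sketch (assumed; "dynamo.log" is a hypothetical file name):
#   init_logging(logging.INFO, log_file_name="dynamo.log")
# configures the console handler, tees records to dynamo.log, and sets both
# the torch._dynamo and torch._inductor loggers to INFO.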


# Creates a logging function that logs a message with a step # prepended.
# get_step_logger should be lazily called (i.e. at runtime, not at module-load
# time) so that step numbers are initialized properly, e.g.:
#
#   @functools.lru_cache(None)
#   def _step_logger():
#       return get_step_logger(logging.getLogger(...))
#
#   def fn():
#       _step_logger()(logging.INFO, "msg")
_step_counter = itertools.count(1)

# Update num_steps if more phases are added: Dynamo, AOT, Backend
# This is very inductor-centric
# _inductor.utils.has_triton() gives a circular import error here
if not disable_progress:
    try:
        import triton  # noqa: F401

        num_steps = 3
    except ImportError:
        num_steps = 2
    pbar = tqdm(total=num_steps, desc="torch.compile()", delay=0)


def get_step_logger(logger):
    if not disable_progress:
        pbar.update(1)
        if not isinstance(pbar, _Faketqdm):
            pbar.set_postfix_str(f"{logger.name}")

    step = next(_step_counter)

    def log(level, msg):
        logger.log(level, f"Step {step}: {msg}")

    return log
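
# End-to-end sketch (illustrative; the logger choice and message are assumptions):
#   dynamo_step = get_step_logger(logging.getLogger("torch._dynamo"))
#   dynamo_step(logging.INFO, "tracing started")
# logs "Step 1: tracing started" on the first call; each get_step_logger call
# claims the next value from the shared _step_counter at creation time.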