import itertools
import logging
import logging.config
import os

from torch.hub import _Faketqdm, tqdm
# logging level for dynamo generated graphs/bytecode/guards
# 15 sits between DEBUG (10) and INFO (20), so CODE output can be enabled
# without turning on full DEBUG noise.
logging.CODE = 15
logging.addLevelName(logging.CODE, "CODE")

# Disable progress bar by default, not in dynamo config because otherwise get a circular import
disable_progress = True
# Return all loggers that torchdynamo/torchinductor is responsible for
def get_loggers():
    """Return the loggers managed by torchdynamo/torchinductor."""
    managed_names = ("torch._dynamo", "torch._inductor")
    return [logging.getLogger(name) for name in managed_names]
# Set the level of all loggers that torchdynamo is responsible for
def set_loggers_level(level):
    """Write current log level

    Applies *level* to every logger returned by ``get_loggers``.
    """
    for managed_logger in get_loggers():
        managed_logger.setLevel(level)
def get_loggers_level():
    """Read current log level

    NOTE: reports the level of the first managed logger only; the managed
    loggers are presumed to share one level (see ``set_loggers_level``).
    """
    first_logger = get_loggers()[0]
    return first_logger.level
# dictConfig-style configuration for the torchdynamo/torchinductor loggers:
# one stderr stream handler shared by both loggers, with a common formatter.
LOGGING_CONFIG = {
    "version": 1,
    "formatters": {
        "torchdynamo_format": {
            "format": "[%(asctime)s] %(name)s: [%(levelname)s] %(message)s"
        },
    },
    "handlers": {
        "torchdynamo_console": {
            "class": "logging.StreamHandler",
            # Handler passes everything through; effective verbosity is
            # controlled at the logger level (see set_loggers_level).
            "level": "DEBUG",
            "formatter": "torchdynamo_format",
            "stream": "ext://sys.stderr",
        },
    },
    "loggers": {
        "torch._dynamo": {
            "level": "DEBUG",
            "handlers": ["torchdynamo_console"],
            # Don't forward records to the root logger's handlers.
            "propagate": False,
        },
        "torch._inductor": {
            "level": "DEBUG",
            "handlers": ["torchdynamo_console"],
            "propagate": False,
        },
    },
    # Keep any loggers that were configured before dictConfig runs.
    "disable_existing_loggers": False,
}
# initialize torchdynamo loggers
def init_logging(log_level, log_file_name=None):
    """Initialize the torchdynamo/torchinductor loggers.

    Args:
        log_level: level applied to all managed loggers (overridden to
            DEBUG when the TORCH_COMPILE_DEBUG environment variable is set).
        log_file_name: optional path; when given, a FileHandler at
            ``log_level`` is attached to every managed logger.

    Under pytest (PYTEST_CURRENT_TEST set) all handler configuration is
    skipped so pytest's own log capture keeps working; only the logger
    levels are set.
    """
    if "PYTEST_CURRENT_TEST" not in os.environ:
        logging.config.dictConfig(LOGGING_CONFIG)
        if log_file_name is not None:
            log_file = logging.FileHandler(log_file_name)
            log_file.setLevel(log_level)
            for logger in get_loggers():
                logger.addHandler(log_file)

        # Treat unset, empty, and "0" as disabled; previously
        # bool(os.environ.get(...)) treated "0" as enabled, which is a bug.
        if os.environ.get("TORCH_COMPILE_DEBUG", "") not in ("", "0"):
            from .utils import get_debug_dir

            log_level = logging.DEBUG
            log_path = os.path.join(get_debug_dir(), "torchdynamo")
            # exist_ok avoids an exists()/makedirs() race with other processes
            os.makedirs(log_path, exist_ok=True)

            log_file = logging.FileHandler(os.path.join(log_path, "debug.log"))
            log_file.setLevel(logging.DEBUG)
            logger = logging.getLogger("torch._dynamo")
            logger.addHandler(log_file)

    set_loggers_level(log_level)
# Creates a logging function that logs a message with a step # prepended.
# get_step_logger should be lazily called (i.e. at runtime, not at module-load time)
# so that step numbers are initialized properly. e.g.:

# @functools.lru_cache(None)
# def _step_logger():
#     return get_step_logger(logging.getLogger(...))

# def fn():
#     _step_logger()(logging.INFO, "msg")

# Monotonically increasing step counter shared by all step loggers.
_step_counter = itertools.count(1)

# Update num_steps if more phases are added: Dynamo, AOT, Backend
# This is very inductor centric
# _inductor.utils.has_triton() gives a circular import error here
if not disable_progress:
    try:
        # Probe for triton to decide whether the Backend phase exists.
        import triton  # noqa: F401

        num_steps = 3
    except ImportError:
        num_steps = 2
    # NOTE: pbar only exists when disable_progress is False;
    # get_step_logger guards every use with the same flag.
    pbar = tqdm(total=num_steps, desc="torch.compile()", delay=0)
def get_step_logger(logger):
    """Return a ``log(level, msg)`` callable that prefixes *msg* with the
    next global step number (and advances the shared progress bar when
    progress display is enabled)."""
    if not disable_progress:
        pbar.update(1)
        if not isinstance(pbar, _Faketqdm):
            pbar.set_postfix_str(f"{logger.name}")

    current_step = next(_step_counter)

    def _log_with_step(level, msg):
        logger.log(level, f"Step {current_step}: {msg}")

    return _log_with_step