12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849 |
- from typing import Dict, Tuple, Any
- import traceback as tb
# Only the exception type is part of this module's public API.
__all__ = ["CheckpointException"]

# A captured failure: the exception object paired with the stack frames
# extracted from its traceback.
WRAPPED_EXCEPTION = Tuple[BaseException, tb.StackSummary]
def _wrap_exception(exc: BaseException) -> WRAPPED_EXCEPTION:
    """
    Capture an exception together with its traceback frames.

    Args:
        exc: The exception to capture.

    Returns:
        A ``(exc, frames)`` tuple where ``frames`` is the
        ``traceback.StackSummary`` extracted from ``exc.__traceback__``.
    """
    frames = tb.extract_tb(exc.__traceback__)
    return exc, frames
def _is_wrapped_exception(obj: Any) -> bool:
    """
    Check whether ``obj`` has the shape of a wrapped exception.

    Args:
        obj: Any value.

    Returns:
        True only for a 2-element tuple whose first element is a
        ``BaseException`` and whose second is a ``traceback.StackSummary``.
    """
    if isinstance(obj, tuple) and len(obj) == 2:
        head, tail = obj[0], obj[1]
        return isinstance(head, BaseException) and isinstance(tail, tb.StackSummary)
    return False
class CheckpointException(BaseException):
    """
    Exception raised if failure was detected as part of a checkpoint load or save.
    """

    def __init__(self, msg: str, failures: Dict[int, WRAPPED_EXCEPTION]):
        # Both values are forwarded so they appear in ``self.args``.
        super().__init__(msg, failures)
        self._failures = failures

    @property
    def failures(self) -> Dict[int, WRAPPED_EXCEPTION]:
        """
        Returns:
            Dict of failed nodes and their associated exception.
            Keys are node ranks and values are ``(exception, StackSummary)``
            pairs describing each rank's failure.
        """
        return self._failures

    def __str__(self):
        """
        Render every rank's failure as a traceback-style report.

        Returns:
            A header line listing the failed ranks, followed, for each rank,
            by a ``Traceback ... (RANK n)`` line, the formatted stack frames
            (when present) and the formatted exception line.
        """
        # Collect fragments and join once, instead of shadowing the builtin
        # ``str`` and growing a string with repeated ``+=``.
        parts = [f"CheckpointException ranks:{self._failures.keys()}\n"]
        for rank, (exc, trace) in self._failures.items():
            parts.append(f"Traceback (most recent call last): (RANK {rank})\n")
            # A failure may carry no captured frames; emit only the exception line then.
            if trace is not None:
                parts.extend(tb.format_list(trace))
            parts.extend(tb.format_exception_only(type(exc), value=exc))
        return "".join(parts)
|