api.py 1.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. from typing import Dict, Tuple, Any
  2. import traceback as tb
  3. WRAPPED_EXCEPTION = Tuple[BaseException, tb.StackSummary]
  4. __all__ = ["CheckpointException"]
  5. def _wrap_exception(exc: BaseException) -> WRAPPED_EXCEPTION:
  6. return (exc, tb.extract_tb(exc.__traceback__))
  7. def _is_wrapped_exception(obj: Any) -> bool:
  8. if not isinstance(obj, tuple):
  9. return False
  10. if len(obj) != 2:
  11. return False
  12. return isinstance(obj[0], BaseException) and isinstance(
  13. obj[1], tb.StackSummary
  14. )
  15. class CheckpointException(BaseException):
  16. """
  17. Exception raised if failure was detected as part of a checkpoint load or save.
  18. """
  19. def __init__(self, msg: str, failures: Dict[int, WRAPPED_EXCEPTION]):
  20. super().__init__(msg, failures)
  21. self._failures = failures
  22. @property
  23. def failures(self) -> Dict[int, WRAPPED_EXCEPTION]:
  24. """
  25. Returns:
  26. Dict of failed nodes and their associated exception.
  27. Keys are node ranks and values are exceptions
  28. """
  29. return self._failures
  30. def __str__(self):
  31. str = f"CheckpointException ranks:{self._failures.keys()}\n"
  32. for rank, exc_pair in self._failures.items():
  33. exc, trace = exc_pair
  34. str += f"Traceback (most recent call last): (RANK {rank})\n"
  35. if trace is not None:
  36. str += "".join(tb.format_list(trace))
  37. str += "".join(tb.format_exception_only(type(exc), value=exc))
  38. return str