checkpoint_utils.py

# Copyright (c) Meta Platforms, Inc. and affiliates
import shutil
import tempfile
from functools import wraps
from typing import Any, Callable, Optional

import torch.distributed as dist


def with_temp_dir(
    func: Optional[Callable] = None,
) -> Optional[Callable]:
    """
    Decorator that creates a shared temporary directory for distributed
    checkpoint tests and exposes it to the test as ``self.temp_dir``.
    """
    assert func is not None

    @wraps(func)
    def wrapper(self, *args: object, **kwargs: Any) -> None:
        # Only rank 0 creates the temp directory; the other ranks
        # receive its path via the broadcast below.
        if dist.get_rank() == 0:
            temp_dir = tempfile.mkdtemp()
            print(f"Using temp directory: {temp_dir}")
        else:
            temp_dir = ""
        object_list = [temp_dir]

        # Broadcast temp_dir to all the other ranks
        dist.broadcast_object_list(object_list)
        self.temp_dir = object_list[0]

        try:
            func(self, *args, **kwargs)
        finally:
            # Only rank 0 created the directory, so only rank 0 removes it.
            if dist.get_rank() == 0:
                shutil.rmtree(self.temp_dir, ignore_errors=True)

    return wrapper
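

# --- Usage sketch (not part of the original file) ---
# A minimal example of how the decorator is meant to be applied, assuming
# the default process group has already been initialized on every rank and
# that all ranks share a filesystem, so the rank-0 path is valid everywhere.
# ``_ExampleCheckpointTest``, ``test_roundtrip``, and ``path`` are
# hypothetical names used only for illustration.

import os


class _ExampleCheckpointTest:
    temp_dir: str  # set by the decorator before the test body runs

    @with_temp_dir
    def test_roundtrip(self) -> None:
        # Every rank sees the same directory, broadcast from rank 0.
        path = os.path.join(self.temp_dir, f"shard_{dist.get_rank()}.pt")
        ...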