_dedup_tensors.py 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. # Copyright (c) Meta Platforms, Inc. and affiliates
  2. import dataclasses
  3. import logging
  4. from typing import Dict, List
  5. from torch.distributed.checkpoint.metadata import MetadataIndex
  6. from torch.distributed.checkpoint.planner import SavePlan
# Public API: `from <this module> import *` exports only dedup_tensors.
__all__ = ["dedup_tensors"]
  8. def init_logger() -> logging.Logger:
  9. logger = logging.getLogger(__name__)
  10. level = logging.INFO
  11. logger.setLevel(level)
  12. console = logging.StreamHandler()
  13. formatter = logging.Formatter(
  14. "%(asctime)s %(filename)s:%(lineno)s %(levelname)s p:%(processName)s t:%(threadName)s: %(message)s"
  15. )
  16. console.setFormatter(formatter)
  17. console.setLevel(level)
  18. logger.addHandler(console)
  19. logger.propagate = False
  20. return logger
  21. logger = init_logger()
  22. # TODO add docstring for dedup_tensors
  23. def dedup_tensors(all_plans: List[SavePlan]) -> List[SavePlan]:
  24. all_plans = list(all_plans)
  25. key_to_plan: Dict[MetadataIndex, List[int]] = {}
  26. for plan_idx, plan in enumerate(all_plans):
  27. for write_item in plan.items:
  28. key_to_plan.setdefault(write_item.index, []).append(plan_idx)
  29. replicated_items = {k: v for k, v in key_to_plan.items() if len(v) > 1}
  30. # Remove duplicates by always keeping the first entry.
  31. # Compute the per-rank remove set.
  32. plan_to_keys: Dict[int, List[MetadataIndex]] = {}
  33. for key, plans in replicated_items.items():
  34. for plan_idx in plans[1:]:
  35. plan_to_keys.setdefault(plan_idx, []).append(key)
  36. logger.info(f"Duplicate keys to remove: {plan_to_keys}")
  37. for plan_idx, keys in plan_to_keys.items():
  38. key_set = set(keys)
  39. # rewrite items and remove elements
  40. new_items = [
  41. write_item
  42. for write_item in all_plans[plan_idx].items
  43. if write_item.index not in key_set
  44. ]
  45. all_plans[plan_idx] = dataclasses.replace(
  46. all_plans[plan_idx], items=new_items
  47. )
  48. return all_plans