# Copyright (c) Meta Platforms, Inc. and affiliates
import dataclasses
import logging
from typing import Dict, List

from torch.distributed.checkpoint.metadata import MetadataIndex
from torch.distributed.checkpoint.planner import SavePlan

__all__ = ["dedup_tensors"]


def init_logger() -> logging.Logger:
    logger = logging.getLogger(__name__)
    level = logging.INFO
    logger.setLevel(level)
    console = logging.StreamHandler()
    formatter = logging.Formatter(
        "%(asctime)s %(filename)s:%(lineno)s %(levelname)s p:%(processName)s t:%(threadName)s: %(message)s"
    )
    console.setFormatter(formatter)
    console.setLevel(level)
    logger.addHandler(console)
    logger.propagate = False
    return logger


logger = init_logger()
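
# Example record produced by the logger configured above; the timestamp,
# filename/line, and process/thread names here are purely illustrative:
#   2024-01-01 12:00:00,000 dedup.py:42 INFO p:MainProcess t:MainThread: Duplicate keys to remove: {...}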


def dedup_tensors(all_plans: List[SavePlan]) -> List[SavePlan]:
    """Remove duplicate write items across a list of ``SavePlan``s.

    An item is identified by its ``MetadataIndex``. When the same index
    appears in more than one plan, it is kept only in the first plan that
    contains it and dropped from all later plans, so each item is written
    exactly once.
    """
    all_plans = list(all_plans)
    key_to_plan: Dict[MetadataIndex, List[int]] = {}
    for plan_idx, plan in enumerate(all_plans):
        for write_item in plan.items:
            key_to_plan.setdefault(write_item.index, []).append(plan_idx)

    replicated_items = {k: v for k, v in key_to_plan.items() if len(v) > 1}
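
    # Worked example (hypothetical index): if index I is written by plans
    # [0, 2, 3], it is replicated; plan_to_keys below becomes {2: [I], 3: [I]},
    # so I survives only in plan 0.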

    # Remove duplicates by always keeping the first entry.
    # Compute the per-rank remove set.
    plan_to_keys: Dict[int, List[MetadataIndex]] = {}
    for key, plans in replicated_items.items():
        for plan_idx in plans[1:]:
            plan_to_keys.setdefault(plan_idx, []).append(key)
    logger.info("Duplicate keys to remove: %s", plan_to_keys)

    for plan_idx, keys in plan_to_keys.items():
        key_set = set(keys)
        # Rewrite the plan, dropping the items already owned by an earlier plan.
        new_items = [
            write_item
            for write_item in all_plans[plan_idx].items
            if write_item.index not in key_set
        ]
        all_plans[plan_idx] = dataclasses.replace(all_plans[plan_idx], items=new_items)

    return all_plans
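

# Minimal usage sketch. The stand-in dataclasses below are hypothetical and
# only mimic the attributes dedup_tensors actually touches (``plan.items``,
# ``write_item.index``, and rebuilding plans via ``dataclasses.replace``);
# real callers pass ``SavePlan``s whose items carry ``MetadataIndex`` keys.
if __name__ == "__main__":

    @dataclasses.dataclass
    class _FakeItem:  # stands in for WriteItem
        index: str  # stands in for MetadataIndex

    @dataclasses.dataclass
    class _FakePlan:  # stands in for SavePlan
        items: List[_FakeItem]

    # "w" is planned by both ranks; "b0" and "b1" are unique to one rank each.
    plans = [
        _FakePlan(items=[_FakeItem("w"), _FakeItem("b0")]),
        _FakePlan(items=[_FakeItem("w"), _FakeItem("b1")]),
    ]
    deduped = dedup_tensors(plans)  # type: ignore[arg-type]

    # The first plan keeps the shared item; the second plan's copy is dropped.
    assert [item.index for item in deduped[0].items] == ["w", "b0"]
    assert [item.index for item in deduped[1].items] == ["b1"]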