# sharding_prop.py

from typing import Callable, Dict, Optional, Tuple

import torch
import torch.distributed._tensor.api as dtensor
from torch._ops import OpOverload
from torch.distributed._tensor.op_schema import OpSchema, OutputSharding
from torch.utils._pytree import tree_map

"""
Print information on ops input shape and sharding for debugging purposes.
"""
_DEBUG_VERBOSE = False


def unwrap_schema(e: object) -> object:
    # Replace a DTensor with its DTensorSpec; leave everything else untouched.
    return e._spec if isinstance(e, dtensor.DTensor) else e
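

# --- Illustrative sketch (not part of the original module) -----------------
# Shows how tree_map(unwrap_schema, ...) only swaps DTensor leaves for their
# DTensorSpec and passes every other argument through unchanged. The DTensor
# argument `dt` is assumed to be built elsewhere (constructing one requires an
# initialized DeviceMesh), so this helper is never called at import time.
def _example_unwrap_args(dt: "dtensor.DTensor") -> Tuple[object, ...]:
    args = (dt, 2.0, "mean")
    # Result: (dt._spec, 2.0, "mean")
    return tree_map(unwrap_schema, args)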


class ShardingPropagator:
    def __init__(self) -> None:
        self.op_to_rules: Dict[OpOverload, Callable[[OpSchema], OutputSharding]] = {}

    def register_sharding_prop_rule(
        self, op_overload: OpOverload, rule_func: Callable[[OpSchema], OutputSharding]
    ) -> None:
        """
        Register a sharding propagation rule for an operator.
        """
        self.op_to_rules[op_overload] = rule_func

    def prepare_op_schema(
        self,
        op_call: OpOverload,
        args: Tuple[object, ...],
        kwargs: Dict[str, object],
    ) -> OpSchema:
        """
        Unwrap the DTensor args/kwargs into DTensorSpecs and pack them
        into an OpSchema for sharding propagation.
        """
        args_schema = tree_map(unwrap_schema, args)
        kwargs_schema = tree_map(unwrap_schema, kwargs)

        op_schema = OpSchema(op_call._schema, args_schema, kwargs_schema)

        if _DEBUG_VERBOSE and torch.distributed.get_rank() == 0:
            print(f"OpSchema({op_schema})")
            local_shapes = tree_map(
                lambda t: t.to_local().shape if isinstance(t, dtensor.DTensor) else None,
                args,
            )
            print(f"    local shapes: {local_shapes}")

        return op_schema

    def propagate_op_sharding(
        self, op_overload: OpOverload, op_schema: OpSchema
    ) -> OutputSharding:
        """
        Propagate the sharding for an operator given the op_schema.
        """
        sharding_prop_func = self.op_to_rules.get(op_overload, None)

        if sharding_prop_func is None:
            # step 1. if there is no sharding rule registered for the
            # operator at all, we error out.
            raise NotImplementedError(
                f"Operator {op_overload} does not have a DistributedTensor rule registered."
            )

        # step 2. there is a sharding propagation rule, run it to get
        # the output sharding
        try:
            output_sharding = sharding_prop_func(op_schema)
        except Exception as e:
            raise RuntimeError(
                f"Sharding propagation failed on op {op_overload}.\n"
                f"Input schema: {op_schema}.\n"
                f"Error: {e}"
            ) from e

        # step 3. if we can't get an output_spec from sharding propagation
        # (i.e. no rule applies for the given input placements), we return
        # the output sharding with schema suggestions, which can be used to
        # decide how to redistribute the inputs
        if output_sharding.output_spec is None:
            if output_sharding.schema_suggestions is None:
                raise RuntimeError(
                    f"Sharding propagation failed on op {op_overload}!\n"
                    f"Input schema: {op_schema}.\n"
                    f"Failed reason: {output_sharding.failed_reason}"
                )
            else:
                # auto-redistribute the inputs if necessary to obtain eligible
                # inputs, picking a schema suggestion based on the redistribute
                # cost. For now we simply pick the first suggestion.
                # TODO: implement full auto distribute with a simple cost
                # estimation model
                suggested_input_schema = output_sharding.schema_suggestions[0]
                # run sharding propagation again with the suggested schema
                propagation_res = sharding_prop_func(suggested_input_schema)
                # set the output sharding with the new propagation result so
                # that dispatching knows both output_spec and
                # schema_suggestions exist, which indicates a reshard is needed
                output_sharding.output_spec = propagation_res.output_spec
        else:
            # if sharding propagation succeeds, set the schema suggestion to
            # the default op_schema, which indicates no reshard is needed
            output_sharding.schema_suggestions = [op_schema]

        return output_sharding
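

# --- Illustrative sketch (not part of the original module) -----------------
# A minimal example of how a rule could be registered and exercised, assuming
# a trivial "output follows the first input" rule of the kind a pointwise op
# might use. The real rules live in torch.distributed._tensor.ops; the rule
# body and the choice of aten.relu here are assumptions for illustration only.
# `op_schema` is assumed to come from prepare_op_schema during dispatch.
def _example_register_and_propagate(op_schema: OpSchema) -> OutputSharding:
    def follow_first_input_rule(schema: OpSchema) -> OutputSharding:
        # reuse the first input's DTensorSpec as the output spec
        return OutputSharding(output_spec=schema.args_schema[0])

    propagator = ShardingPropagator()
    propagator.register_sharding_prop_rule(
        torch.ops.aten.relu.default, follow_first_input_rule
    )
    return propagator.propagate_op_sharding(torch.ops.aten.relu.default, op_schema)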


class _CachingPropagator(ShardingPropagator):
    """
    A sharding propagator that caches the propagation results.
    This is currently experimental for Tensor Parallel usage.
    """

    def __init__(
        self,
        op_to_rules: Optional[Dict[OpOverload, Callable[[OpSchema], OutputSharding]]] = None,
    ) -> None:
        super().__init__()
        if op_to_rules is not None:
            self.op_to_rules = op_to_rules

        # cache table for sharding propagation results; we might need to
        # limit the size of the cache table in the future
        self.cached_prop_results: Dict[OpSchema, OutputSharding] = {}

    def propagate_op_sharding(
        self, op_overload: OpOverload, op_schema: OpSchema
    ) -> OutputSharding:
        """
        Propagate the sharding for an operator given the op_schema.
        Cache the propagation results to avoid re-running propagation.
        """
        if op_schema in self.cached_prop_results:
            return self.cached_prop_results[op_schema]
        else:
            # call the base ShardingPropagator's propagate_op_sharding
            # to get the propagation result
            output_sharding = super().propagate_op_sharding(op_overload, op_schema)
            # update the cache table
            self.cached_prop_results[op_schema] = output_sharding
            return output_sharding
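

# --- Illustrative sketch (not part of the original module) -----------------
# Demonstrates the caching behaviour: the first call runs the registered rule,
# and a second call with an equal OpSchema returns the stored OutputSharding
# without re-running propagation. `op_overload` and `op_schema` are assumed to
# come from an actual dispatch (e.g. via prepare_op_schema); they are not
# constructed here.
def _example_cached_propagation(
    propagator: _CachingPropagator, op_overload: OpOverload, op_schema: OpSchema
) -> OutputSharding:
    first = propagator.propagate_op_sharding(op_overload, op_schema)
    second = propagator.propagate_op_sharding(op_overload, op_schema)
    # cache hit: the exact same OutputSharding object is returned
    assert first is second
    return first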