from typing import Callable, Dict, Tuple

import torch
import torch.distributed._tensor.api as dtensor
from torch._ops import OpOverload
from torch.distributed._tensor.op_schema import OpSchema, OutputSharding
from torch.utils._pytree import tree_map

"""
Print information on ops input shape and sharding for debugging purposes.
"""
_DEBUG_VERBOSE = False


def unwrap_schema(e: object) -> object:
    return e._spec if isinstance(e, dtensor.DTensor) else e
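
# Illustrative note (a sketch, not exercised by this module directly): given a
# DTensor `dt` (an assumed example input, not defined in this file),
#
#   unwrap_schema(dt)   # -> dt._spec, i.e. its DTensorSpec (mesh + placements)
#   unwrap_schema(3.0)  # -> 3.0, non-DTensor leaves pass through unchanged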


class ShardingPropagator:
    def __init__(self) -> None:
        self.op_to_rules: Dict[OpOverload, Callable[[OpSchema], OutputSharding]] = {}

    def register_sharding_prop_rule(
        self, op_overload: OpOverload, rule_func: Callable[[OpSchema], OutputSharding]
    ):
        """
        Register a sharding propagation rule for an operator.
        """
        self.op_to_rules[op_overload] = rule_func
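
    # Usage sketch (hypothetical rule function, not defined in this module):
    #
    #   def my_pointwise_rule(op_schema: OpSchema) -> OutputSharding:
    #       ...  # derive the output DTensorSpec from the input specs
    #
    #   propagator = ShardingPropagator()
    #   propagator.register_sharding_prop_rule(
    #       torch.ops.aten.add.Tensor, my_pointwise_rule
    #   )
    #
    # Once registered, propagate_op_sharding looks up aten.add.Tensor in
    # op_to_rules and calls my_pointwise_rule to compute the output sharding.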

    def prepare_op_schema(
        self,
        op_call: OpOverload,
        args: Tuple[object, ...],
        kwargs: Dict[str, object],
    ) -> OpSchema:
        """
        Unwrap any DTensor in args/kwargs to its DTensorSpec and pack the
        results into an OpSchema for sharding propagation.
        """
        args_schema = tree_map(unwrap_schema, args)
        kwargs_schema = tree_map(unwrap_schema, kwargs)

        op_schema = OpSchema(op_call._schema, args_schema, kwargs_schema)

        if _DEBUG_VERBOSE and torch.distributed.get_rank() == 0:
            print(f"OpSchema({op_schema})")
            local_shapes = tree_map(
                lambda t: t.to_local().shape if isinstance(t, dtensor.DTensor) else None,
                args,
            )
            print(f"    local shapes: {local_shapes}")

        return op_schema
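
    # For example (a sketch; `propagator`, `dt_a`, and `dt_b` are assumed
    # names, not defined in this module):
    #
    #   op_schema = propagator.prepare_op_schema(
    #       torch.ops.aten.mm.default, (dt_a, dt_b), {}
    #   )
    #   # op_schema.args_schema now holds DTensorSpecs rather than DTensors,
    #   # so sharding rules can reason about placements without touching data.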

    def propagate_op_sharding(
        self, op_overload: OpOverload, op_schema: OpSchema
    ) -> OutputSharding:
        """
        Propagate the sharding for an operator given the op_schema.
        """
        sharding_prop_func = self.op_to_rules.get(op_overload, None)

        if sharding_prop_func is None:
            # step 1. If there's not even one sharding rule
            # implemented for the operator, we error out.
            raise NotImplementedError(
                f"Operator {op_overload} does not have a DistributedTensor rule registered."
            )

        # step 2. there is a sharding propagation rule, run
        # sharding propagation to get the output sharding
        try:
            output_sharding = sharding_prop_func(op_schema)
        except Exception as e:
            raise RuntimeError(
                f"Sharding propagation failed on op {op_overload}.\n"
                f"Input schema: {op_schema}.\n"
                f"Error: {e}"
            ) from e

        # step 3. if we can't get an output_spec from sharding
        # propagation (i.e. no rule applies to the input
        # placements), we return the output sharding
        # with schema suggestions, which can be used to
        # decide how to redistribute the inputs
        if output_sharding.output_spec is None:
            if output_sharding.schema_suggestions is None:
                raise RuntimeError(
                    f"Sharding propagation failed on op {op_overload}!\n"
                    f"Input schema: {op_schema}.\n"
                    f"Failed reason: {output_sharding.failed_reason}"
                )
            else:
                # we auto-redistribute the inputs if necessary to get an
                # eligible input, picking a schema suggestion based on the
                # redistribute cost. For now we simply pick the first suggestion.
                # TODO: implement full auto distribute with a
                # simple cost estimation model
                suggested_input_schema = output_sharding.schema_suggestions[0]
                # run sharding propagation again with the suggested schema
                propagation_res = sharding_prop_func(suggested_input_schema)
                # set the output sharding with the new propagation result
                # so that dispatching knows both output_spec and schema_suggestions
                # exist, which indicates a reshard is needed
                output_sharding.output_spec = propagation_res.output_spec
        else:
            # if sharding propagation succeeds, set the schema suggestion to
            # the default op_schema, which indicates no reshard is needed
            output_sharding.schema_suggestions = [op_schema]

        return output_sharding
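
# Putting the pieces together (a sketch under assumed names; `propagator` and
# `op_schema` are the hypothetical objects from the comments above):
#
#   output_sharding = propagator.propagate_op_sharding(
#       torch.ops.aten.add.Tensor, op_schema
#   )
#   suggested = output_sharding.schema_suggestions[0]
#   # if `suggested` is the original op_schema, no reshard is needed; otherwise
#   # the caller redistributes the inputs to match `suggested` before dispatch.
#
# This mirrors the three steps above: error if no rule is registered, run the
# rule, and fall back to the rule's schema suggestion when the given input
# placements are not directly supported.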


class _CachingPropagator(ShardingPropagator):
    """
    A sharding propagator that caches the propagation results.
    This is currently experimental, for Tensor Parallel usage.
    """

    def __init__(self, op_to_rules=None) -> None:
        super().__init__()
        if op_to_rules is not None:
            self.op_to_rules = op_to_rules

        # cache table for sharding propagation results; we might need to
        # limit the size of the cache table in the future
        self.cached_prop_results: Dict[OpSchema, OutputSharding] = {}

    def propagate_op_sharding(
        self, op_overload: OpOverload, op_schema: OpSchema
    ) -> OutputSharding:
        """
        Propagate the sharding for an operator given the op_schema.
        Cache the propagation results to avoid re-running propagation.
        """
        if op_schema in self.cached_prop_results:
            return self.cached_prop_results[op_schema]
        else:
            # call ShardingPropagator's propagate_op_sharding to get the result
            output_sharding = super().propagate_op_sharding(op_overload, op_schema)
            # update the cache table
            self.cached_prop_results[op_schema] = output_sharding
            return output_sharding
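
# Caching sketch (illustrative only; `rules` and `op_schema` are assumed names,
# not part of this module):
#
#   cached = _CachingPropagator(op_to_rules=rules)
#   out1 = cached.propagate_op_sharding(torch.ops.aten.add.Tensor, op_schema)
#   out2 = cached.propagate_op_sharding(torch.ops.aten.add.Tensor, op_schema)
#   # out2 is served from cached_prop_results without re-running the rule,
#   # which relies on OpSchema being hashable and comparable for equality.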