# dispatch.py
# Copyright (c) Meta Platforms, Inc. and affiliates
from typing import Callable, cast, Dict, Tuple, Union, Optional

import torch
import torch.distributed._tensor.api as dtensor
from torch.distributed._tensor.op_schema import (
    ArgsType,
    KwargsType,
    OutputSpecType,
)
from torch.distributed._tensor.placement_types import DTensorSpec
from torch.distributed._tensor.sharding_prop import ShardingPropagator
from torch.distributed._tensor.redistribute import redistribute_dtensor
from torch.utils._pytree import tree_flatten, tree_unflatten

"""
If _ENABLE_FALLBACK is set to False, dispatch will fail when an op doesn't
have a sharding rule registered.
"""
_ENABLE_FALLBACK = False


def wrap(res: object, spec: OutputSpecType) -> object:
    if isinstance(res, torch.Tensor):
        assert spec is not None and isinstance(
            spec, DTensorSpec
        ), f"output spec does not match with output! Expected DTensorSpec, got {spec}."
        return dtensor.DTensor(
            res,
            spec.mesh,
            spec.placements,
            size=spec.shape,
            requires_grad=res.requires_grad,
        )
    elif isinstance(res, list):
        assert spec is not None and isinstance(
            spec, list
        ), f"output spec does not match with output! Expected list, got {spec}."
        return [
            dtensor.DTensor(e, s.mesh, s.placements, size=s.shape)
            for e, s in zip(res, spec)
        ]
    elif isinstance(res, tuple):
        assert spec is not None and isinstance(
            spec, tuple
        ), f"output spec does not match with output! Expected tuple, got {spec}"
        # NOTE: an ATen op might return Optional[Tensor] entries in its local
        # results, so we need to handle that case and make sure we don't wrap
        # None with DTensor (e.g. native_layer_norm.backward).
        return tuple(
            dtensor.DTensor(e, s.mesh, s.placements, size=s.shape)
            if e is not None and s is not None
            else None
            for e, s in zip(res, spec)
        )
    else:
        # if res contains only non-tensor values, simply return it without rewrapping
        return res
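
# Illustrative note (assumed example, not part of the dispatch logic): for an
# op like native_layer_norm.backward the local results can look like
# (grad_input, None, None) when the output mask skips the weight/bias grads;
# wrap() mirrors that structure instead of trying to wrap None:
#
#   wrap((local_grad, None, None), (grad_spec, None, None))
#   # -> (DTensor(...), None, None)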


def pack_args_kwargs_with_local_tensor(
    args: Union[ArgsType, KwargsType],
    args_schema: Union[ArgsType, KwargsType],
    redistribute_with_schema: bool = False,
) -> Union[ArgsType, KwargsType]:
    flatten_args, args_tree_spec = tree_flatten(args)
    flatten_args_schema, _ = tree_flatten(args_schema)

    for i, arg in enumerate(flatten_args):
        if isinstance(arg, dtensor.DTensor):
            if redistribute_with_schema:
                # reshard the DTensor to the mesh/placements suggested by sharding prop
                target_spec = flatten_args_schema[i]
                arg = redistribute_dtensor(
                    arg, target_spec.mesh, target_spec.placements
                )

            # reuse the schema list and update it with the local tensor
            flatten_args_schema[i] = arg._local_tensor

    return tree_unflatten(flatten_args_schema, args_tree_spec)
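
# Illustrative sketch (assumed example, not used by the dispatcher): the pytree
# round-trip above is what lets DTensor leaves be swapped for their local
# shards no matter how deeply they are nested inside args/kwargs:
#
#   flat, spec = tree_flatten(((a, b), {"k": c}))  # flat == [a, b, c]
#   tree_unflatten(flat, spec)                     # ((a, b), {"k": c})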


def _reshape_alias(
    x: torch.Tensor, shape: Tuple[int, ...], strides: Tuple[int, ...]
) -> torch.Tensor:
    return torch.ops.aten.view(x, shape)


_CURRENT_DECOMPOSITION_TABLE: Dict[Callable[..., object], Callable[..., object]] = {
    torch.ops.aten._reshape_alias.default: _reshape_alias,
}
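
# NOTE (illustrative, based on the "lift private aten aliases" comment in
# operator_dispatch below): this table maps a private OpOverload to a Python
# decomposition, so a call such as
#
#   torch.ops.aten._reshape_alias.default(dt, new_shape, new_strides)
#
# is rewritten into a plain aten.view on the DTensor argument (the strides are
# intentionally dropped), and that public view op is then expected to re-enter
# the dispatcher through DTensor's __torch_dispatch__ as usual.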


def operator_dispatch(
    op_call: torch._ops.OpOverload,
    args: Tuple[object, ...],
    kwargs: Dict[str, object],
    sharding_propagator: ShardingPropagator,
    custom_dispatch_ops: Optional[Dict[str, Callable[..., object]]] = None,
) -> object:
    # first we need to lift some private aten aliases to public calls
    if op_call in _CURRENT_DECOMPOSITION_TABLE:
        return _CURRENT_DECOMPOSITION_TABLE[op_call](*args, **kwargs)

    # STEP 0. See if there are user-defined custom aten operator
    # implementations. Custom operators take the highest priority.
    if custom_dispatch_ops is not None and str(op_call) in custom_dispatch_ops:
        # dispatch to user-defined custom distributed tensor ops
        return custom_dispatch_ops[str(op_call)](*args, **kwargs)

    # unwrap the args/kwargs schema
    op_schema = sharding_propagator.prepare_op_schema(op_call, args, kwargs)

    output_sharding = sharding_propagator.propagate_op_sharding(op_call, op_schema)

    # If the schema suggestion from sharding prop is not the same instance as the
    # input op_schema, it indicates a reshard is needed, so we redistribute the
    # input tensors before calling the local op.
    assert output_sharding.schema_suggestions is not None
    needs_redistribute = output_sharding.schema_suggestions[0] is not op_schema
    suggested_input_schema = output_sharding.schema_suggestions[0]

    local_tensor_args = pack_args_kwargs_with_local_tensor(
        args,
        suggested_input_schema.args_schema,
        redistribute_with_schema=needs_redistribute,
    )
    local_tensor_kwargs = pack_args_kwargs_with_local_tensor(
        kwargs,
        suggested_input_schema.kwargs_schema,
        redistribute_with_schema=needs_redistribute,
    )

    # run the local op computation with potentially modified args/kwargs
    local_tensor_args = cast(Tuple[object, ...], local_tensor_args)
    local_tensor_kwargs = cast(Dict[str, object], local_tensor_kwargs)
    local_results = op_call(*local_tensor_args, **local_tensor_kwargs)
    if suggested_input_schema.is_inplace:
        # an inplace op should return self instead of re-wrapping
        self = cast(dtensor.DTensor, args[0])
        self._spec = cast(DTensorSpec, output_sharding.output_spec)
        return self
    elif suggested_input_schema.is_out_variant:
        # an out variant could possibly have multiple out args (e.g. lu_unpack.out)
        output_specs = (
            (output_sharding.output_spec,)
            if not isinstance(output_sharding.output_spec, tuple)
            else output_sharding.output_spec
        )
        out_dts = []
        spec_idx = 0
        for arg in suggested_input_schema.func_schema.arguments:
            if arg.is_out:
                # out args were passed in as DTensors; update their specs in place
                out_dt = cast(dtensor.DTensor, kwargs[arg.name])
                out_dt._spec = cast(DTensorSpec, output_specs[spec_idx])
                out_dts.append(out_dt)
                spec_idx += 1
        assert len(out_dts) >= 1, "out variant should have at least one out arg"
        return tuple(out_dts) if len(out_dts) > 1 else out_dts[0]
    else:
        return wrap(local_results, output_sharding.output_spec)
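
# A minimal sketch of how this entrypoint is typically wired up (assumed names,
# for illustration only; the real wiring lives in the DTensor class in
# torch.distributed._tensor.api): DTensor's __torch_dispatch__ forwards every
# ATen call here together with the class-level sharding propagator and the
# optional custom-op table.
#
#   class DTensor(torch.Tensor):
#       @classmethod
#       def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
#           return operator_dispatch(
#               func,
#               args,
#               kwargs or {},
#               cls._propagator,           # a ShardingPropagator instance (assumed name)
#               cls._custom_dispatch_ops,  # optional user overrides (assumed name)
#           )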