op_schema.py

from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Tuple, Union

import torch

from torch.distributed._tensor.placement_types import DTensorSpec

# Common type aliases
ArgsType = Tuple[object, ...]
KwargsType = Dict[str, object]
# ATen op schemas could have Tensor, Tuple[Tensor] and List[Tensor], so the
# output type should be the same set of possibilities.
OutputSpecType = Optional[Union[DTensorSpec, Sequence[Optional[DTensorSpec]]]]


@dataclass
class OpSchema:
    """
    OpSchema is a data class that describes an operator's input schema. It
    includes the DTensorSpecs of the DTensor args and the non-tensor
    args/kwargs (positional order preserved). It is mainly used by the
    dispatching logic to run things like sharding propagation.

    Registered sharding propagation rules could utilize this data class and
    update some fields in place (when necessary, e.g. for shape-related ops)
    to make sure the args/kwargs are legit before passing to the local tensor
    operator. This is the main reason we don't freeze this dataclass.

    NOTE: greater access to the operator inputs comes with greater
    responsibility. Here are some basic rules about what can be used and what
    can be changed.

    Args:
        func_schema: the function schema of the operator
        args_schema: contains the args, except that the DTensor args have
            been replaced with their DTensorSpecs
        kwargs_schema: contains the kwargs, except that the DTensor kwargs
            have been replaced with their DTensorSpecs

    What can be used:
        - every attribute within this class could be read to conduct
          sharding propagation.

    What can be changed:
        - only args_schema and kwargs_schema could be changed.
        - every non-tensor arg could be changed to accommodate the local
          tensor operation (e.g. for ops like view/reshape/...)
        - every DTensorSpec attribute inside `args_schema`, `kwargs_schema`
          and `args_spec` SHOULD NOT be updated! DTensorSpecs are read only,
          and sharding propagation shouldn't update them in place; otherwise
          the input DTensor placements would get implicitly changed, which is
          error-prone.
    """

    func_schema: torch._C.FunctionSchema
    args_schema: ArgsType
    kwargs_schema: KwargsType

    is_inplace: bool = False
    is_out_variant: bool = False
    def __post_init__(self) -> None:
        # simple analysis of the function schema to determine whether this is
        # an inplace/out variant; it might not be entirely correct, but it's
        # good enough for now.
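        # e.g. a schema name ending in "_" (like "aten::add_") marks the
        # inplace variant, and an overload_name containing "out" (like the
        # "out" overload of aten::add) marks the out variant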
        self.is_inplace = self.func_schema.name[-1] == "_"
        self.is_out_variant = "out" in self.func_schema.overload_name
    @property
    def args_spec(self) -> Tuple[DTensorSpec, ...]:
        """
        args_spec: Tuple[DTensorSpec, ...]: a clean tuple of args specs with
        NO non-DTensor positional arguments (i.e. int/float/tuple, etc.),
        mainly used by sharding propagation to propagate the output spec
        """
        # filter out non-relevant values from the args schema to get a clean
        # spec list; this is mainly used by sharding propagation rules
        return tuple(item for item in self.args_schema if isinstance(item, DTensorSpec))
    def __repr__(self) -> str:
        return (
            f"OpSchema(func_schema={self.func_schema},"
            f" args_schema={self.args_schema},"
            f" kwargs_schema={self.kwargs_schema})"
        )

    def __hash__(self) -> int:
        # NOTE: we turn kwargs_schema into a frozenset to hash it, as it
        # would not be a nested dict
        frozen_set_kwargs_schema = frozenset(self.kwargs_schema.items())
        return hash((self.func_schema, self.args_spec, frozen_set_kwargs_schema))

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, OpSchema):
            return False
        return (
            self.func_schema == other.func_schema
            and self.args_schema == other.args_schema
            and self.kwargs_schema == other.kwargs_schema
        )
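

# A minimal usage sketch (illustration only, not part of the module),
# assuming `spec` is an already-constructed DTensorSpec for some DTensor arg:
#
#   schema = OpSchema(
#       func_schema=torch.ops.aten.add.Tensor._schema,
#       args_schema=(spec, spec),
#       kwargs_schema={"alpha": 2},
#   )
#   schema.args_spec       # -> (spec, spec): only DTensorSpec entries kept
#   schema.is_inplace      # -> False: "aten::add" does not end with "_"
#   schema.is_out_variant  # -> False: overload "Tensor" does not contain "out"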


@dataclass
class OutputSharding:
    """
    OutputSharding is a data class used by the sharding propagation rules. A
    rule could set the output_spec upon successful propagation; if propagation
    failed, output_spec would become None and the rule could give a list of
    suggested input schemas to reshard to.

    NOTE: a schema_suggestion generated by sharding propagation should be
    exactly the same as the operator's OpSchema, except for the DTensorSpecs
    of the inputs, which describe how the inputs should be resharded.
    """

    output_spec: OutputSpecType
    schema_suggestions: Optional[List[OpSchema]] = None
    failed_reason: Optional[str] = None
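

# A minimal sketch of how a (hypothetical) sharding propagation rule might
# use these data classes; it assumes DTensorSpec exposes a `placements`
# attribute, and is an illustration rather than an actual registered rule:
#
#   def pointwise_rule(op_schema: OpSchema) -> OutputSharding:
#       specs = op_schema.args_spec
#       if all(s.placements == specs[0].placements for s in specs):
#           # all inputs agree on placements, so the output follows them
#           return OutputSharding(output_spec=specs[0])
#       # propagation failed: report why (no resharding suggestion here)
#       return OutputSharding(
#           output_spec=None,
#           failed_reason="pointwise inputs have mismatched placements",
#       )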