from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Tuple, Union

import torch
from torch.distributed._tensor.placement_types import DTensorSpec

# Common type aliases
ArgsType = Tuple[object, ...]
KwargsType = Dict[str, object]
# ATen op schemas could have Tensor, Tuple[Tensor] and List[Tensor], so the output type
# should cover the same set of possibilities.
OutputSpecType = Optional[Union[DTensorSpec, Sequence[Optional[DTensorSpec]]]]


@dataclass
class OpSchema:
    """
    OpSchema is a data class that describes an operator's input schemas. It
    includes the DTensorSpecs of the DTensor args/kwargs and the non-tensor
    args/kwargs (positional order preserved). It is mainly used by the
    dispatching logic below to run things like sharding propagation.

    Registered sharding propagation rules can use this data class and update
    some fields in place (when necessary, e.g. for shape related ops) to make
    sure the args/kwargs are legit before passing them to the local tensor
    operator. This is the main reason this dataclass is not frozen.

    NOTE: greater access to the operator inputs comes with greater
    responsibility. Here are some basic rules about what can be used and what
    can be changed.

    Args:
        func_schema: the function schema of the operator
        args_schema: contains the args, except that DTensor args have been
            replaced with their DTensorSpec
        kwargs_schema: contains the kwargs, except that DTensor kwargs have
            been replaced with their DTensorSpec

    What can be used:
        - every attribute within this class can be read to conduct
          sharding propagation.
    What can be changed:
        - only args_schema and kwargs_schema can be changed.
        - every non-tensor arg can be changed to accommodate local tensor
          operations (i.e. for ops like view/reshape/...)
        - every DTensorSpec inside `args_schema`, `kwargs_schema` and
          `args_spec` SHOULD NOT be updated! DTensorSpecs are read only, and
          sharding propagation should not update them in place; otherwise the
          input DTensor placements would change implicitly, which is
          error-prone.
    """

    func_schema: torch._C.FunctionSchema
    args_schema: ArgsType
    kwargs_schema: KwargsType

    is_inplace: bool = False
    is_out_variant: bool = False

    def __post_init__(self) -> None:
        # simple analysis of the function schema to determine whether this is
        # an inplace/out variant; it might not be entirely correct, but it's
        # good enough for now.
        self.is_inplace = self.func_schema.name[-1] == "_"
        self.is_out_variant = "out" in self.func_schema.overload_name

    @property
    def args_spec(self) -> Tuple[DTensorSpec, ...]:
        """
        args_spec: Tuple[DTensorSpec, ...]: a clean tuple of args specs that
        contains NO non-DTensor positional arguments (i.e. int/float/tuple, etc),
        mainly used by sharding propagation to propagate the output spec
        """
        # filter out non-DTensorSpec values from the args schema to get a clean
        # spec list; this is mainly used by sharding propagation rules
        return tuple(item for item in self.args_schema if isinstance(item, DTensorSpec))

    def __repr__(self) -> str:
        return (
            f"OpSchema(func_schema={self.func_schema},"
            f" args_schema={self.args_schema},"
            f" kwargs_schema={self.kwargs_schema})"
        )

    def __hash__(self) -> int:
        # NOTE: we turn kwargs_schema into a frozenset to hash it, as it is not a nested dict
        frozen_set_kwargs_schema = frozenset(self.kwargs_schema.items())
        return hash((self.func_schema, self.args_spec, frozen_set_kwargs_schema))

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, OpSchema):
            return False
        return (
            self.func_schema == other.func_schema
            and self.args_schema == other.args_schema
            and self.kwargs_schema == other.kwargs_schema
        )
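

# Illustrative sketch (not part of the original module): how dispatching code
# might assemble an OpSchema by hand. It assumes an OpOverload exposes its
# FunctionSchema via `_schema`, and `dtensor_spec` stands in for the
# DTensorSpec that replaced a DTensor argument; building that spec is out of
# scope here.
def _example_build_op_schema(dtensor_spec: DTensorSpec) -> OpSchema:
    func_schema = torch.ops.aten.sum.default._schema
    op_schema = OpSchema(func_schema, args_schema=(dtensor_spec,), kwargs_schema={})
    # args_spec keeps only the DTensorSpec positional args: (dtensor_spec,)
    assert op_schema.args_spec == (dtensor_spec,)
    # __post_init__ classified the op: "aten::sum" has no trailing "_" and the
    # "default" overload name contains no "out".
    assert not op_schema.is_inplace and not op_schema.is_out_variant
    return op_schema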


@dataclass
class OutputSharding:
    """
    OutputSharding is a data class used by the sharding propagation rules. A
    rule sets output_spec upon successful propagation; if propagation failed,
    output_spec stays None and the rule can instead give a list of suggested
    input schemas to reshard to.

    NOTE: the schema_suggestions generated by sharding propagation should be
    exactly the same as the operator's OpSchema, except for the DTensorSpecs
    of the DTensor args/kwargs.
    """

    output_spec: OutputSpecType
    schema_suggestions: Optional[List[OpSchema]] = None
    failed_reason: Optional[str] = None
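

# Illustrative sketch (hypothetical rule, not registered anywhere): a sharding
# propagation rule for an elementwise-style op could pass the first input's
# DTensorSpec through on success, and report a failure reason otherwise. It
# assumes DTensorSpec carries a `placements` field describing the sharding.
def _example_elementwise_rule(op_schema: OpSchema) -> OutputSharding:
    specs = op_schema.args_spec
    if specs and all(spec.placements == specs[0].placements for spec in specs):
        # all tensor inputs agree, so the output follows the same sharding
        return OutputSharding(output_spec=specs[0])
    # propagation failed; the caller could reshard the inputs and retry
    return OutputSharding(
        output_spec=None,
        schema_suggestions=None,
        failed_reason="expected all tensor inputs to share the same placements",
    )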