from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Tuple, Union

import torch
from torch.distributed._tensor.placement_types import DTensorSpec


# Common type aliases
ArgsType = Tuple[object, ...]
KwargsType = Dict[str, object]
# ATen op schemas could have Tensor, Tuple[Tensor] and List[Tensor], so the
# output type should be the same set of possibilities.
OutputSpecType = Optional[Union[DTensorSpec, Sequence[Optional[DTensorSpec]]]]


@dataclass
class OpSchema:
    """
    OpSchema is a data class that describes an operator's input schemas. It
    includes DTensorSpecs and non-tensor args/kwargs (positional order
    preserved). It is mainly used by the dispatching logic to run things like
    sharding propagation.

    Registered sharding propagation rules can use this data class and update
    some fields in place (when necessary, i.e. for shape related ops) to make
    sure the args/kwargs are legit before passing them to the local tensor
    operator. This is the main reason that this dataclass is not frozen.

    NOTE: greater access to the operator inputs comes with greater
    responsibility. Here are some basic rules about what can be used and what
    can be changed.

    Args:
        func_schema: the function schema of the operator
        args_schema: contains args except that the DTensor args have been
            replaced with their DTensorSpec
        kwargs_schema: contains kwargs except that the DTensor kwargs have
            been replaced with their DTensorSpec

    What can be used:
        - every attribute within this class can be read to conduct
          sharding propagation.
    What can be changed:
        - only the args_schema and kwargs_schema can be changed.
        - every non-tensor arg can be changed to accommodate local tensor
          operations (i.e. for ops like view/reshape/...)
        - every "DTensorSpec" attribute inside `args_schema`,
          `kwargs_schema` and `args_spec` SHOULD NOT be updated!
          DTensorSpecs are read only and sharding propagation shouldn't
          update them in place, otherwise the input DTensor placements
          would get implicitly changed, which is error-prone.

    Hashing and equality:
        - `__eq__` compares `func_schema`, `args_schema` and `kwargs_schema`.
        - `__hash__` only folds in `func_schema`, `args_spec` and the
          `kwargs_schema` items, so two schemas that differ only in their
          non-DTensorSpec positional args hash alike but compare unequal.
    """

    func_schema: torch._C.FunctionSchema
    args_schema: ArgsType
    kwargs_schema: KwargsType

    # Both flags are recomputed from `func_schema` in `__post_init__`, so any
    # value passed to the constructor is overwritten.
    is_inplace: bool = False
    is_out_variant: bool = False

    def __post_init__(self) -> None:
        """Derive `is_inplace` / `is_out_variant` from the function schema."""
        # simple analysis of function schema to determine
        # if this is an inplace/out variant, it might not
        # be entirely correct, but it's good enough for now.
        # `endswith` (rather than indexing `name[-1]`) keeps an empty schema
        # name from raising IndexError; it is simply not an inplace op.
        self.is_inplace = self.func_schema.name.endswith("_")
        self.is_out_variant = "out" in self.func_schema.overload_name

    @property
    def args_spec(self) -> Tuple[DTensorSpec, ...]:
        """
        args_spec: Tuple[DTensorSpec, ...]: contains a clean list of args spec
            with NO non-DTensor positional arguments (i.e. int/float/tuple,
            etc.); mainly used by sharding propagation to propagate the
            output spec
        """
        # filter out non-relevant values from args schema to get a clean spec
        # list; this would mainly be used by sharding propagation rules
        return tuple(item for item in self.args_schema if isinstance(item, DTensorSpec))

    def __repr__(self) -> str:
        """Return a compact description that omits the derived flags."""
        return (
            f"OpSchema(func_schema={self.func_schema},"
            f" args_schema={self.args_schema},"
            f" kwargs_schema={self.kwargs_schema})"
        )

    def __hash__(self) -> int:
        """Hash on the function schema, the DTensorSpec args and the kwargs.

        Raises:
            TypeError: if `func_schema`, a spec or a kwargs value is unhashable.
        """
        # NOTE: we turn kwargs_schema into a frozenset to hash as it would not
        # be a nested dict
        frozen_set_kwargs_schema = frozenset(self.kwargs_schema.items())
        return hash((self.func_schema, self.args_spec, frozen_set_kwargs_schema))

    def __eq__(self, other: object) -> bool:
        """Return True when `other` is an OpSchema with equal schema fields."""
        if not isinstance(other, OpSchema):
            return False
        return (
            self.func_schema == other.func_schema
            and self.args_schema == other.args_schema
            and self.kwargs_schema == other.kwargs_schema
        )


@dataclass
class OutputSharding:
    """
    OutputSharding is a data class that is used by the sharding propagation
    rules. A rule can set the output_spec upon successful propagation; if it
    failed, output_spec would become None and the sharding propagation rule
    could give a list of suggestions for inputs to reshard.

    NOTE: the schema_suggestion generated by sharding propagation should be
    exactly the same as the operator OpSchema, except the DTensorSpecs
    """

    output_spec: OutputSpecType
    schema_suggestions: Optional[List[OpSchema]] = None
    failed_reason: Optional[str] = None