utils.py

# Copyright (c) Meta Platforms, Inc. and affiliates
import functools
import operator
from typing import Iterable, List, Sequence, Union

import torch

from torch.distributed._tensor.api import DTensor


# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def unwrap_single_placement(e):
    if not isinstance(e, DTensor):
        return None
    assert len(e.placements) == 1, "more than one placement!"
    return e.placements[0]
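
# Example (illustrative sketch, not part of the original file): for a DTensor
# placed along a single mesh dimension, `unwrap_single_placement` returns its
# lone placement; anything that is not a DTensor yields None. The mesh and
# placement values below are assumptions for demonstration only.
#
#   from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor
#
#   mesh = DeviceMesh("cuda", [0, 1, 2, 3])
#   dt = distribute_tensor(torch.randn(8, 8), mesh, [Shard(0)])
#   unwrap_single_placement(dt)                 # -> Shard(dim=0)
#   unwrap_single_placement(torch.randn(2, 2))  # -> None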


# convenient wrapper to register custom operator impls
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def register_impl(func):
    # pyre-fixme[53]: Captured variable `func` is not annotated.
    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
    def wrapper(impl):
        DTensor._custom_dispatch_ops[func] = impl
        return impl

    return wrapper
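
# Example (illustrative sketch): `register_impl` is used as a decorator
# factory. The op key and the impl signature below are assumptions, not taken
# from the original file.
#
#   @register_impl("aten.view.default")
#   def dist_view(types, args=(), kwargs=None):
#       ...  # custom DTensor implementation of aten.view
#
# Afterwards DTensor._custom_dispatch_ops["aten.view.default"] is dist_view,
# and the decorated function itself is returned unchanged, so it can still be
# called directly.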


# convenient wrapper to register sharding propagation rules
# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def register_prop_rule(op):
    # pyre-fixme[53]: Captured variable `op` is not annotated.
    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
    def wrapper(impl):
        overloads = op if isinstance(op, list) else [op]
        for overload in overloads:
            DTensor._propagator.register_sharding_prop_rule(overload, impl)
        return impl

    return wrapper
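
# Example (illustrative sketch): a single rule can be registered for one
# overload or a list of overloads; the overloads and the rule's signature here
# are assumptions for demonstration.
#
#   @register_prop_rule([torch.ops.aten.squeeze.default, torch.ops.aten.squeeze.dim])
#   def squeeze_rule(op_schema):
#       ...  # map the input sharding to an output sharding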


def as_list(
    x: Union[List[object], object]
    # pyre-fixme[11]: Annotation `immutable_list` is not defined as a type.
) -> Union[List[object], torch.fx.immutable_collections.immutable_list]:
    # During tracing, `aten.sum.dim_IntList` uses `immutable_list` for its args,
    # which is an object but treated as a list by the tracer. Therefore, keep
    # `immutable_list` intact here as well.
    if type(x) is list or isinstance(
        x, torch.fx.immutable_collections.immutable_list
    ):
        return x
    else:
        return [x]
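
# Example (illustrative): lists and fx immutable_lists pass through unchanged,
# while scalars (and, notably, tuples) get wrapped in a fresh list.
#
#   as_list([1, 2])  # -> [1, 2]
#   as_list(3)       # -> [3]
#   as_list((1, 2))  # -> [(1, 2)]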


def normalize_dim(dim: int, ndim: int) -> int:
    return dim if dim >= 0 else dim + ndim


def normalize_dims(dims: Union[int, Sequence[int]], ndim: int) -> Sequence[int]:
    """
    Normalize a dim or a sequence of dims so that they are all non-negative.
    """
    if isinstance(dims, int):
        dims = (normalize_dim(dims, ndim),)
    elif isinstance(dims, list):
        dims = [normalize_dim(dim, ndim) for dim in dims]
    elif isinstance(dims, tuple):
        dims = tuple(normalize_dim(dim, ndim) for dim in dims)
    return dims
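
# Example (illustrative): negative dims are shifted by ndim, and an int input
# comes back as a one-element tuple; list and tuple inputs keep their type.
#
#   normalize_dims(-1, 4)       # -> (3,)
#   normalize_dims([0, -2], 4)  # -> [0, 2]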


def prod(xs: Iterable[int]) -> int:
    return functools.reduce(operator.mul, xs, 1)
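
# Example (illustrative): the initializer of 1 makes the empty product well
# defined, which is convenient when multiplying shape dimensions.
#
#   prod([2, 3, 4])  # -> 24
#   prod([])         # -> 1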