# Copyright (c) Meta Platforms, Inc. and affiliates
import copy
import warnings
from typing import Callable, cast, Dict, Optional, Sequence

import torch
import torch.nn as nn

import torch.distributed._tensor.dispatch as op_dispatch
from torch.distributed._tensor.device_mesh import DeviceMesh, get_global_device_mesh
from torch.distributed._tensor.placement_types import (
    _Partial,
    DTensorSpec,
    Placement,
    Replicate,
    Shard,
)
from torch.distributed._tensor.sharding_prop import ShardingPropagator
from torch.distributed._tensor.redistribute import Redistribute
from torch.utils._pytree import tree_flatten

__all__ = ["DTensor", "distribute_tensor", "distribute_module"]

# NOTE [Autograd interaction between torch.Tensor and DTensor]
#
# The autograd functions defined below are used by the public-facing
# APIs (i.e. from_local, to_local) to ensure our DTensor works together
# with torch.Tensor within the autograd engine. This allows
# DistributedTensor to exist on part of the module hierarchy and still
# be able to calculate gradients across the torch.Tensor and
# DistributedTensor boundary.
#
# As an example, suppose we have a module that consists of submodules
# A, B, and C; the execution flow would be:
#  input(torch.Tensor) -> Module A -> Module B -> Module C -> output (torch.Tensor)
#
# Suppose we only want to make Module B a sharded module with
# DistributedTensor params; we would need to make the following
# flow work:
#
#  input(torch.Tensor) -> Module A
#       -> DTensor input -> Sharded Module B -> DTensor output
#           -> output (torch.Tensor) -> Module C -> output (torch.Tensor)
#
# We need the conversion from Module A's output to the DTensor input, which is
# `from_local`, and the conversion from the DTensor output back to a plain
# output, which is `to_local`, thus these two functions must be autograd
# functions.
#
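# As a concrete, purely illustrative sketch of that flow (the module and
# variable names below are hypothetical and assume a 1-D device mesh):
#
#   dt_input = DTensor.from_local(a_output, device_mesh, [Shard(0)])  # enter DTensor land
#   dt_output = sharded_module_b(dt_input)       # Module B holds DTensor params
#   b_output = dt_output.to_local()              # back to plain torch.Tensor for Module C
#   ...
#   loss.backward()  # gradients flow across both boundaries through the
#                    # autograd functions defined below
#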
class _ToTorchTensor(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input: "DTensor"):  # type: ignore[override]
        ctx.dtensor_device_mesh = input.device_mesh
        ctx.dtensor_placements = input.placements
        ctx.dtensor_shape = input.shape
        ctx.dtensor_requires_grad = input.requires_grad
        return input._local_tensor.detach()

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):  # type: ignore[override]
        device_mesh = ctx.dtensor_device_mesh
        placements = ctx.dtensor_placements
        return DTensor(
            grad_output,
            device_mesh,
            placements,
            size=ctx.dtensor_shape,
            requires_grad=grad_output.requires_grad,
        )


class _FromTorchTensor(torch.autograd.Function):
    @staticmethod
    def forward(  # type: ignore[override]
        ctx,  # pyre-ignore[2]: Parameter must be annotated.
        input: torch.Tensor,
        device_mesh: DeviceMesh,
        placements: Sequence[Placement],
        run_check: bool,
    ) -> "DTensor":
        ctx.previous_placement = placements
        ctx.previous_device_mesh = device_mesh

        if run_check:
            # TODO: by default check tensor metas across ranks
            # TODO: See if we need to make this run_check logic
            # have a corresponding backward.
            for idx, placement in enumerate(placements):
                if placement.is_replicate():
                    # broadcast rank 0 tensor to all ranks
                    # only broadcast if run_check is True
                    input = input.contiguous()
                    device_mesh.broadcast(input, mesh_dim=idx)

        # if run_check is disabled, we assume the user is certain that each
        # rank has the same tensor shape, and we simply use the local shape
        # to calculate the global shape
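        # (e.g. a (4, 8) local shard with placements [Shard(0)] on a mesh
        # dimension of size 4 yields a global DTensor shape of (16, 8))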
        tensor_shape = list(input.size())
        for idx, placement in enumerate(placements):
            if placement.is_shard():
                shard_dim = cast(Shard, placement).dim
                local_dim_size = tensor_shape[shard_dim]
                tensor_shape[shard_dim] = local_dim_size * device_mesh.size(idx)

        dist_tensor = DTensor(
            input,
            device_mesh,
            placements,
            size=torch.Size(tensor_shape),
            # requires_grad of the dist tensor depends on if input
            # requires_grad or not
            requires_grad=input.requires_grad,
        )
        return dist_tensor

    @staticmethod
    def backward(ctx, grad_output: "DTensor"):  # type: ignore[override]
        previous_placement = ctx.previous_placement
        previous_device_mesh = ctx.previous_device_mesh

        # reshard to the placement used when creating the DistributedTensor
        # so that the gradient layout matches, and we can return
        # local gradients directly
        if grad_output.placements != previous_placement:
            # pyre-fixme[16]: `Redistribute` has no attribute `apply`.
            grad_output = Redistribute.apply(
                grad_output, previous_device_mesh, previous_placement
            )

        # TODO: backward is also differentiable now, add a test
        # to test higher level gradients.
        return grad_output.to_local(), None, None, None


class DTensor(torch.Tensor):  # pyre-ignore[13]: pyre is bad at __new__
    _local_tensor: torch.Tensor
    _spec: DTensorSpec
    __slots__ = ["_local_tensor", "_spec"]

    # class attribute that handles operator placements propagation
    # rules, keyed by aten op name, value is the propagation func
    _propagator: ShardingPropagator = ShardingPropagator()

    # class attribute that handles custom registered ops; all handled
    # custom ops should appear in this table, and they override the default
    # operators covered by _op_to_rules or the fallbacks
    # (custom operators have the highest priority when dispatching).
    # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
    _custom_dispatch_ops: Dict[str, Callable] = {}

    @staticmethod
    def __new__(
        cls,
        local_tensor: torch.Tensor,
        device_mesh: DeviceMesh,
        placements: Sequence[Placement],
        *,
        size: torch.Size,
        requires_grad: bool = False,
    ) -> "DTensor":
        """
        Construct a DTensor from a local tensor, device mesh, placements, and
        other tensor properties (i.e. shape, requires_grad, strides, etc).

        Note: This is not a public API and it's only supposed to be used by the
        operator implementations and internals. If you want to construct a
        DTensor from a local tensor, consider using `DTensor.from_local`; if
        you want to construct a DTensor from a "global" tensor (where you
        already have a tensor initialized and want to shard this tensor),
        consider using `distribute_tensor`.
        """
        # recover tensor strides from local tensor strides and global size info
        # in the case of sharding
        # TODO: we should try to use meta tensor for shape and stride calculation
        tensor_stride = list(local_tensor.stride())
        local_size = list(local_tensor.size())
        for placement in placements:
            if isinstance(placement, Shard):
                shard_dim = placement.dim
                # recover the tensor stride by rescaling every stride that is
                # larger than (or equal to) the current stride on the shard_dim
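                # (e.g. global size (8, 16) sharded on dim 1 over 4 ranks gives a
                # local shard of size (8, 4) with stride (4, 1); rescaling the
                # dim-0 stride from 4 by size[1] // local_size[1] recovers the
                # global stride (16, 1))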
                for i in range(len(tensor_stride)):
                    if i != shard_dim and tensor_stride[i] >= tensor_stride[shard_dim]:
                        # rescale the stride by the shard size
                        tensor_stride[i] = (
                            tensor_stride[i] // local_size[shard_dim]
                        ) * size[shard_dim]
            elif not isinstance(placement, (Replicate, _Partial)):
                raise RuntimeError(f"placement type {type(placement)} not supported!")

        if requires_grad != local_tensor.requires_grad:
            warnings.warn(
                "To construct DTensor from torch.Tensor, it's recommended to "
                "use local_tensor.detach() and make requires_grad consistent."
            )

        # _make_wrapper_subclass constructs the wrapper tensor from local_tensor
        # and the placement spec; it does not do any actual distribution
        r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
            cls,
            size,
            strides=tensor_stride,
            dtype=local_tensor.dtype,
            device=local_tensor.device,
            layout=local_tensor.layout,
            requires_grad=requires_grad,
        )
        # deepcopy the placements and set the spec
        r._spec = DTensorSpec(device_mesh, copy.deepcopy(placements), shape=r.size())
        # detach the local tensor from the autograd graph as we initialize the
        # distributed tensor; autograd will work on top of the wrapper tensor
        # directly instead of the local torch.Tensor
        r._local_tensor = local_tensor.detach()
        return r

    # pyre-fixme[14]: `__repr__` overrides method defined in `DTensor` inconsistently.
    # pyre-fixme[3]: Return type must be annotated.
    def __repr__(self):
        # TODO: consider all_gather the local tensors for better debugging
        return f"DTensor(local_tensor={self._local_tensor}, device_mesh={self._spec.mesh}, placements={self._spec.placements})"

    @classmethod
    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # check that we are not getting mixed vanilla and Distributed tensors
        arg_list, _ = tree_flatten(args)
        for arg in arg_list:
            if isinstance(arg, torch.Tensor) and not isinstance(arg, DTensor):
                raise RuntimeError(
                    f"{func}: got mixed distributed and non-distributed tensors."
                )

        if kwargs is None:
            kwargs = {}

        return op_dispatch.operator_dispatch(
            func,
            args,
            kwargs,
            DTensor._propagator,
            DTensor._custom_dispatch_ops,
        )

    @classmethod
    def from_local(
        cls,
        local_tensor: torch.Tensor,
        device_mesh: Optional[DeviceMesh] = None,
        placements: Optional[Sequence[Placement]] = None,
        run_check: bool = True,
    ) -> "DTensor":
        """
        Create a :class:`DTensor` from a local torch.Tensor on each rank
        according to the `device_mesh` and `placements` specified.

        Args:
            local_tensor (torch.Tensor): local torch.Tensor on each rank.
            device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to place the
                tensor on; if not specified, must be called under a DeviceMesh
                context manager, default: None
            placements (List[:class:`Placement`], optional): the placements that
                describe how to place the local torch.Tensor on the DeviceMesh; must
                have the same number of elements as `device_mesh.ndim`. If not
                specified, we will by default replicate the tensor across the
                `device_mesh` from the first rank of each dimension of the `device_mesh`.
            run_check (bool, optional): indicate whether to run checks across ranks
                to verify meta information and data. If there is a :class:`Replicate` in
                `placements`, the data on the first rank of that device mesh dimension
                will be broadcast to the other ranks.

        Returns:
            A :class:`DTensor` object

        .. note:: `from_local` is differentiable; the `requires_grad` of the created
            `DTensor` object will depend on whether `local_tensor` requires_grad or not.
        """
        # if same shape/dtype, no need to run_check; if not, we must allgather
        # the metadata to check the size/dtype across ranks.
        # There should be no data communication unless there's a replication
        # strategy, where we broadcast the replicated data from the first rank
        # in the mesh dimension
        device_mesh = get_global_device_mesh() if device_mesh is None else device_mesh
        # convert the local tensor to the desired device based on the device mesh's device_type
        if not local_tensor.is_meta:
            local_tensor = local_tensor.to(device_mesh.device_type)

        # set default placements to replicated if not specified
        if placements is None:
            placements = [Replicate() for _ in range(device_mesh.ndim)]

        # `from_local` is differentiable, and the gradient of the dist tensor this function
        # creates should flow back to the local_tensor, so we call an autograd
        # function to construct the dist tensor instead.
        return _FromTorchTensor.apply(  # pyre-ignore[16]: autograd func
            local_tensor, device_mesh, placements, run_check
        )

    def to_local(self) -> torch.Tensor:
        """
        Get the local tensor of this DTensor on its current rank. For sharding it returns
        a local shard of the logical tensor view; for replication it returns the replica on
        its current rank.

        Returns:
            A :class:`torch.Tensor` object that represents the local tensor on its current rank.

        .. note:: `to_local` is differentiable; the `requires_grad` of the local tensor returned
            will depend on whether the `DTensor` requires_grad or not.
        """
        return _ToTorchTensor.apply(self)  # pyre-ignore[16]: autograd func

    def redistribute(
        self,
        device_mesh: Optional[DeviceMesh] = None,
        placements: Optional[Sequence[Placement]] = None,
    ) -> "DTensor":
        """
        `redistribute` performs the necessary collective operations to redistribute the current
        DTensor from its current placements to new placements, or from its current DeviceMesh
        to a new DeviceMesh, i.e. we can turn a sharded DTensor into a replicated DTensor by
        specifying a Replicate placement for each dimension of the DeviceMesh.

        Args:
            device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to place the
                DTensor on; if not specified, must be called under a DeviceMesh
                context manager, default: None
            placements (List[:class:`Placement`], optional): the new placements that
                describe how to place the DTensor on the DeviceMesh; must
                have the same number of elements as `device_mesh.ndim`.

        Returns:
            A :class:`DTensor` object

        .. note:: `redistribute` is differentiable.
        """
        # This API performs the necessary transformations and returns
        # a new DTensor with the new spec, i.e. for sharding it's a
        # reshard behavior.
        # Note that redistribute currently only supports out-of-place
        # redistribution, i.e. it always creates a new DTensor object
        # and leaves the original one unchanged.
        device_mesh = get_global_device_mesh() if device_mesh is None else device_mesh
        # raise error if new placements are not specified
        if placements is None:
            raise RuntimeError("placements is needed for redistribute!")

        for placement in placements:
            if placement.is_partial():
                raise RuntimeError(
                    "Can not redistribute to _Partial, _Partial is for internal use only!"
                )

        # pyre-fixme[16]: `Redistribute` has no attribute `apply`.
        return Redistribute.apply(self, device_mesh, placements)

    @property
    def device_mesh(self) -> DeviceMesh:
        """
        The :class:`DeviceMesh` attribute that is associated with this DTensor object.

        .. note:: device_mesh is a read-only property, it can not be set.
        """
        return self._spec.mesh

    @property
    def placements(self) -> Sequence[Placement]:
        """
        The placements attribute of this DTensor that describes the layout of this
        DTensor on its DeviceMesh.

        .. note:: placements is a read-only property, it can not be set.
        """
        return self._spec.placements
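

# A minimal usage sketch of the DTensor API defined above. It is illustrative only
# and never invoked by this module; the function name, the assumption of a 1-D
# device mesh, and the tensor shapes are hypothetical, not part of the API.
def _example_dtensor_round_trip(mesh: DeviceMesh) -> torch.Tensor:
    # Each rank contributes a 4x8 shard along dim 0, so the logical DTensor
    # shape becomes (4 * mesh.size(0), 8).
    local = torch.randn(4, 8, requires_grad=True)
    dt = DTensor.from_local(local, mesh, [Shard(0)])
    # Collectively turn the sharded DTensor into a replicated one.
    replicated = dt.redistribute(mesh, [Replicate()])
    # Leave the DTensor region; gradients still flow back to `local`.
    return replicated.to_local()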


def distribute_tensor(
    tensor: torch.Tensor,
    device_mesh: Optional[DeviceMesh] = None,
    placements: Optional[Sequence[Placement]] = None,
) -> DTensor:
    """
    Distribute a torch.Tensor to the `device_mesh` according to the `placements`
    specified. The number of elements in `placements` must match `device_mesh.ndim`.

    Args:
        tensor (torch.Tensor): torch.Tensor to be distributed. Note that if you
            want to shard a tensor on a dimension that is not evenly divisible by
            the number of devices in that mesh dimension, we use `torch.tensor_split`
            semantics to shard the tensor and scatter the shards.
        device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to distribute the
            tensor on; if not specified, must be called under a DeviceMesh context
            manager, default: None
        placements (List[:class:`Placement`], optional): the placements that
            describe how to place the tensor on the DeviceMesh; must have the same
            number of elements as `device_mesh.ndim`. If not specified, we will
            by default replicate the tensor across the `device_mesh` from the
            first rank of each dimension of the `device_mesh`.

    Returns:
        A :class:`DTensor` object
    """
    # get the default device mesh if there's nothing specified
    device_mesh = get_global_device_mesh() if device_mesh is None else device_mesh
    # convert the tensor to the corresponding device type if it's not already there
    if not tensor.is_meta:
        tensor = tensor.to(device_mesh.device_type)
    # set default placements to replicated if not specified
    if placements is None:
        placements = [Replicate() for _ in range(device_mesh.ndim)]

    if len(placements) != device_mesh.ndim:
        raise ValueError(
            f"`placements` must have the same length as `device_mesh.ndim`! "
            f"Found placements length: {len(placements)}, and device_mesh.ndim: {device_mesh.ndim}."
        )

    if isinstance(tensor, DTensor):
        # if the tensor is already a DTensor, we just need to check if the
        # device mesh and placements are the same
        if tensor.device_mesh != device_mesh:
            raise ValueError(
                f"Cannot distribute a DTensor with device mesh {tensor.device_mesh} "
                f"to a different device mesh {device_mesh}."
            )
        if tensor.placements != placements:
            raise ValueError(
                f"Cannot distribute a DTensor with placements {tensor.placements} "
                f"to different placements {placements}. Do you want to call "
                f"`redistribute` instead?"
            )
        return tensor

    local_tensor = tensor

    # distribute the tensor according to the placements.
    for idx, placement in enumerate(placements):
        if placement.is_shard():
            placement = cast(Shard, placement)
            output = placement._shard_tensor(local_tensor, device_mesh, idx)
            # the scatter call cannot return a tensor with the correct requires_grad
            # field, as ProcessGroupNCCL refuses to take a tensor with requires_grad
            # for in-place updates! So we manually set it here
            output.requires_grad_(tensor.requires_grad)
            local_tensor = output
        elif placement.is_replicate():
            local_tensor = local_tensor.contiguous()
            device_mesh.broadcast(local_tensor, mesh_dim=idx)
        else:
            raise RuntimeError(
                f"Trying to distribute tensor with unsupported placements {placement} on device mesh dimension {idx}!"
            )

    assert local_tensor is not None, "distributing a tensor should not be None"
    return DTensor(
        local_tensor,
        device_mesh,
        placements,
        size=tensor.size(),
        requires_grad=tensor.requires_grad,
    )
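

# A minimal sketch of `distribute_tensor` usage (illustrative only, never invoked
# by this module); the function name, the 1-D mesh assumption, and the shapes are
# hypothetical.
def _example_distribute_tensor(mesh: DeviceMesh) -> DTensor:
    # A full (16, 8) tensor exists on every rank before distribution.
    big = torch.randn(16, 8)
    # Shard dim 0 across a 1-D mesh: each rank keeps a (16 // mesh.size(0), 8)
    # piece (uneven splits follow `torch.tensor_split` semantics, as noted above).
    sharded = distribute_tensor(big, mesh, [Shard(0)])
    # Alternatively, [Replicate()] would broadcast the tensor from the first rank
    # of the mesh dimension instead of sharding it.
    return sharded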


def distribute_module(
    module: nn.Module,
    device_mesh: Optional[DeviceMesh] = None,
    partition_fn: Optional[Callable[[str, nn.Module, DeviceMesh], None]] = None,
    input_fn: Optional[Callable[..., None]] = None,
    output_fn: Optional[Callable[..., None]] = None,
) -> nn.Module:
    """
    This function converts all module parameters to :class:`DTensor` parameters
    according to the `partition_fn` specified. It can also control the input and
    output of the module by specifying the `input_fn` and `output_fn` (i.e. convert
    the input to :class:`DTensor`, convert the output back to torch.Tensor).

    Args:
        module (:class:`nn.Module`): user module to be partitioned.
        device_mesh (:class:`DeviceMesh`, optional): the device mesh to place the
            module on; if not specified, must be called under a DeviceMesh context
            manager, default: None
        partition_fn (Callable): the function to partition parameters (i.e. shard certain
            parameters across the `device_mesh`). If `partition_fn` is not specified,
            by default we replicate all module parameters of `module` across the mesh.
        input_fn (Callable): specifies the input distribution, i.e. it can control how the
            input of the module is sharded. `input_fn` will be installed as a module
            `forward_pre_hook` (pre forward hook).
        output_fn (Callable): specifies the output distribution, i.e. it can control how the
            output is sharded, or convert it back to torch.Tensor. `output_fn` will be
            installed as a module `forward_hook` (post forward hook).

    Returns:
        A module that contains parameters/buffers that are all `DTensor`s.
    """
    if device_mesh is None:
        device_mesh = get_global_device_mesh()

    def replicate_module_params_buffers(m: nn.Module, mesh: DeviceMesh) -> None:
        # This function loops over the immediate module parameters and
        # buffers, and replicates all non-DTensor params/buffers to DTensor
        # parameters/buffers, if they have not been partitioned in the
        # partition_fn. We can't easily use `module._apply` here
        # because we don't know what happened inside partition_fn as
        # the user could do anything, e.g. install hooks, and we want to
        # preserve those.
        full_replicate = [Replicate()] * mesh.ndim
        for key, param in m._parameters.items():
            if param is not None and not isinstance(param, DTensor):
                m.register_parameter(
                    key,
                    nn.Parameter(distribute_tensor(param.data, mesh, full_replicate)),
                )
        for key, buffer in m._buffers.items():
            if buffer is not None and not isinstance(buffer, DTensor):
                m._buffers[key] = distribute_tensor(buffer, mesh, full_replicate)

    if partition_fn is None:
        # if partition_fn is not specified, we by default replicate
        # all module params/buffers
        for name, submod in module.named_modules():
            replicate_module_params_buffers(submod, device_mesh)
    else:
        # apply partition_fn to submodules
        for name, submod in module.named_modules():
            partition_fn(name, submod, device_mesh)
            replicate_module_params_buffers(submod, device_mesh)

    # register input_fn as a module forward pre hook
    if input_fn is not None:
        module.register_forward_pre_hook(lambda _, inputs: input_fn(inputs, device_mesh))  # type: ignore[misc]
    # register output_fn as a module forward (post) hook
    if output_fn is not None:
        module.register_forward_hook(
            lambda mod, inputs, outputs: output_fn(outputs, device_mesh)  # type: ignore[misc]
        )

    return module
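

# A minimal sketch of `distribute_module` usage (illustrative only, never invoked
# by this module). The partition policy, the example model, and the 1-D mesh
# assumption are hypothetical; only the hook signatures mirror the API above.
def _example_distribute_module(mesh: DeviceMesh) -> nn.Module:
    def shard_linear_params(name: str, submod: nn.Module, submod_mesh: DeviceMesh) -> None:
        # Hypothetical policy: shard Linear parameters on dim 0 and leave everything
        # else to the default replication performed by `distribute_module`.
        if isinstance(submod, nn.Linear):
            for p_name, param in list(submod._parameters.items()):
                if param is not None and not isinstance(param, DTensor):
                    submod.register_parameter(
                        p_name,
                        nn.Parameter(distribute_tensor(param.data, submod_mesh, [Shard(0)])),
                    )

    def to_local_output(outputs, submod_mesh):
        # Installed as a forward (post) hook: convert a DTensor output back to a
        # plain torch.Tensor so downstream modules see a regular tensor.
        return outputs.to_local() if isinstance(outputs, DTensor) else outputs

    model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
    return distribute_module(
        model, mesh, partition_fn=shard_linear_params, output_fn=to_local_output
    )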