make_functional.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615
  1. # Copyright (c) Facebook, Inc. and its affiliates.
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the BSD-style license found in the
  5. # LICENSE file in the root directory of this source tree.
  6. import copy
  7. from typing import (
  8. Any,
  9. Callable,
  10. Dict,
  11. Iterable,
  12. List,
  13. NoReturn,
  14. Sequence,
  15. Tuple,
  16. Type,
  17. Union,
  18. )
  19. import torch
  20. import torch.nn as nn
  21. from torch import Tensor
  22. from torch.nn.utils._named_member_accessor import NamedMemberAccessor
  23. # Utilities to make nn.Module "functional"
  24. # In particular the goal is to be able to provide a function that takes as input
  25. # the parameters and evaluate the nn.Module using fixed inputs.
  26. def raise_parameter_tying_error() -> NoReturn:
  27. raise RuntimeError(
  28. "make_functional(module): we don't yet support models that "
  29. "do parameter tying (also sometimes known as weight sharing). "
  30. "Please try to rewrite your model by replacing all instances of the "
  31. "tied parameter with another and/or comment your support in "
  32. "https://github.com/pytorch/functorch/issues/446"
  33. )
  34. def create_names_map(
  35. named_params: Union[Dict[str, Tensor], Iterable[Tuple[str, Tensor]]],
  36. tied_named_params: Union[Dict[str, Tensor], Iterable[Tuple[str, Tensor]]],
  37. ) -> Dict[str, List[str]]:
  38. """
  39. named_params is a dictionary of tensors: {'A': A, 'B': B}
  40. tied_named_params is another dictionary of tensors {'A': A, 'B': B, 'B_tied': B}
  41. with potentially tied (or 'duplicated') tensors
  42. This function creates a mapping from the names in named_params to the
  43. names in tied_named_params: {'A': ['A'], 'B': ['B', 'B_tied']}.
  44. """
  45. named_params = dict(named_params)
  46. tied_named_params = dict(tied_named_params)
  47. tensors_dict_keys = set(named_params.keys())
  48. tied_tensors_dict_keys = set(tied_named_params.keys())
  49. assert tensors_dict_keys.issubset(tied_tensors_dict_keys)
  50. tensor_to_mapping: Dict[Tensor, Tuple[str, List[str]]] = {}
  51. for key, tensor in named_params.items():
  52. tensor_to_mapping[tensor] = (key, [])
  53. for key, tensor in tied_named_params.items():
  54. assert tensor in tensor_to_mapping
  55. tensor_to_mapping[tensor][1].append(key)
  56. return dict(tensor_to_mapping.values())
  57. def _extract_members(
  58. mod: nn.Module,
  59. named_members: Callable[..., Iterable[Tuple[str, Tensor]]],
  60. subclass: Callable[[Tensor], Tensor],
  61. ) -> Tuple[Tuple[Tensor, ...], Tuple[str, ...], Dict[str, List[str]]]:
  62. all_named_members = tuple(named_members(remove_duplicate=False))
  63. unique_named_members = tuple(named_members(remove_duplicate=True))
  64. names_map = create_names_map(unique_named_members, all_named_members)
  65. # Remove all the members in the model
  66. memo = {}
  67. accessor = NamedMemberAccessor(mod)
  68. for name, p in all_named_members:
  69. if p not in memo:
  70. memo[p] = subclass(torch.empty_like(p, device="meta"))
  71. replacement = memo[p]
  72. accessor.set_tensor(name, replacement)
  73. if len(unique_named_members) == 0:
  74. names, params = (), ()
  75. else:
  76. names, params = zip(*unique_named_members) # type: ignore[assignment]
  77. return params, names, names_map
  78. def extract_weights(
  79. mod: nn.Module,
  80. ) -> Tuple[Tuple[Tensor, ...], Tuple[str, ...], Dict[str, List[str]]]:
  81. """
  82. This function removes all the Parameters from the model and
  83. return them as a tuple as well as their original attribute names.
  84. The weights must be re-loaded with `load_weights` before the model
  85. can be used again.
  86. Note that this function modifies the model in place and after this
  87. call, mod.parameters() will be empty.
  88. """
  89. return _extract_members(mod, mod.named_parameters, nn.Parameter)
  90. def extract_buffers(
  91. mod: nn.Module,
  92. ) -> Tuple[Tuple[Tensor, ...], Tuple[str, ...], Dict[str, List[str]]]:
  93. return _extract_members(mod, mod.named_buffers, lambda x: x)
  94. def load_weights(
  95. mod: nn.Module,
  96. names: Sequence[str],
  97. params: Sequence[Tensor],
  98. as_params: bool = False,
  99. ) -> None:
  100. """
  101. Reload a set of weights so that `mod` can be used again to perform a forward pass.
  102. Note that the `params` are regular Tensors (that can have history) and so are left
  103. as Tensors. This means that mod.parameters() will still be empty after this call.
  104. """
  105. accessor = NamedMemberAccessor(mod)
  106. if as_params:
  107. params = [nn.Parameter(p) for p in params]
  108. accessor.set_tensors(names, params)
  109. def _swap_state(
  110. mod: nn.Module, names_map: Dict[str, List[str]], elems: Iterable[Tensor]
  111. ) -> List[Tensor]:
  112. result: List[Tensor] = []
  113. accessor = NamedMemberAccessor(mod)
  114. for (_, attr_names), elem in zip(names_map.items(), elems):
  115. for i, attr_name in enumerate(attr_names):
  116. if i == 0:
  117. result.append(accessor.swap_tensor(attr_name, elem))
  118. else:
  119. accessor.set_tensor(attr_name, elem)
  120. return result
  121. def load_buffers(
  122. mod: nn.Module,
  123. names: Sequence[str],
  124. buffers: Sequence[Tensor],
  125. as_params: bool = False,
  126. ) -> None:
  127. accessor = NamedMemberAccessor(mod)
  128. accessor.set_tensors(names, buffers)
  129. def load_state(
  130. model: nn.Module,
  131. weights: Sequence[Tensor],
  132. weight_names: Sequence[str],
  133. buffers: Sequence[Tensor] = (),
  134. buffer_names: Sequence[str] = (),
  135. ) -> nn.Module:
  136. """load_state(model, weights, weight_names, buffers=(), buffer_names=()) -> model
  137. load_state takes `weights` and `buffers` and assigns them to the model.
  138. This is the inverse operation of `make_functional_deprecated_v1`.
  139. """
  140. assert len(weight_names) == len(weights)
  141. load_weights(model, weight_names, weights)
  142. if len(buffers) > 0:
  143. assert len(buffer_names) == len(buffers)
  144. load_buffers(model, buffer_names, buffers)
  145. return model
  146. def make_functional_deprecated_v1(model: nn.Module):
  147. """make_functional_deprecated_v1(model) -> weights, func, weight_names
  148. Given an nn.Module, make_functional_deprecated_v1 extracts the state (weights)
  149. and returns a functional version of the model, `func`. This makes
  150. it so that it is possible use transforms over the parameters of
  151. `model`.
  152. `func` can be invoked as follows:
  153. ```
  154. x = torch.randn(4, 3)
  155. model = nn.Linear(3, 3)
  156. weights, func, _ = make_functional_deprecated_v1(model)
  157. func(weights, (x,))
  158. ```
  159. And here is an example of applying the grad transform:
  160. ```
  161. x = torch.randn(4, 3)
  162. model = nn.Linear(3, 3)
  163. weights, _, func = make_functional_deprecated_v1(model)
  164. grad_weights = grad(func)(weights, (x,))
  165. ```
  166. To put the state back into a model, use `load_state`.
  167. """
  168. buffers = list(model.buffers())
  169. if len(buffers) > 0:
  170. raise RuntimeError(
  171. "make_functional_deprecated_v1(model): `model` has buffers. Please use "
  172. "make_functional_with_buffers_deprecated_v1(model) instead."
  173. )
  174. weights, descriptors, _ = extract_weights(model)
  175. def fun(weights, data):
  176. mutable_model = copy.deepcopy(model)
  177. load_weights(mutable_model, descriptors, weights)
  178. return mutable_model(*data)
  179. return weights, fun, descriptors
  180. def make_functional_with_buffers_deprecated_v1(model: nn.Module):
  181. """make_functional_with_buffers_deprecated_v1(model) -> weights, buffers, func, weight_names, buffer_names
  182. Given an nn.Module, make_functional_with_buffers_deprecated_v1 extracts the state (weights and buffers)
  183. and returns a functional version of the model, `func`.
  184. `func` can be invoked as follows:
  185. ```
  186. x = torch.randn(4, 3)
  187. model = nn.Linear(3, 3)
  188. weights, buffers, func, _, _ = make_functional_with_buffers_deprecated_v1(model)
  189. func(weights, buffers, (x,))
  190. ```
  191. And here is an example of applying the grad transform:
  192. ```
  193. x = torch.randn(4, 3)
  194. model = nn.Linear(3, 3)
  195. weights, buffers, func, _, _ = make_functional_with_buffers_deprecated_v1(model)
  196. func(weights, buffers, (x,))
  197. grad_weights = grad(func)(weights, buffers, (x,))
  198. ```
  199. To put the state back into a model, use `load_state`.
  200. """
  201. weights, weight_descriptors, _ = extract_weights(model)
  202. buffers, buf_descriptors, _ = extract_buffers(model)
  203. def fun(weights, buffers, data):
  204. mutable_model = copy.deepcopy(model)
  205. load_weights(mutable_model, weight_descriptors, weights)
  206. load_buffers(mutable_model, buf_descriptors, buffers)
  207. return mutable_model(*data)
  208. return weights, buffers, fun, weight_descriptors, buf_descriptors
  209. class FunctionalModuleWithBuffers(nn.Module):
  210. """
  211. This is the callable object returned by :func:`make_functional_with_buffers`.
  212. """
  213. def __init__(
  214. self,
  215. stateless_model: nn.Module,
  216. param_names: Tuple[str, ...],
  217. buffer_names: Tuple[str, ...],
  218. param_names_map: Dict[str, List[str]],
  219. buffer_names_map: Dict[str, List[str]],
  220. ) -> None:
  221. super().__init__()
  222. self.stateless_model = stateless_model
  223. self.param_names = param_names
  224. self.buffer_names = buffer_names
  225. self.all_names_map = dict(param_names_map)
  226. self.all_names_map.update(buffer_names_map)
  227. @staticmethod
  228. def _create_from(
  229. model: nn.Module, disable_autograd_tracking: bool = False
  230. ) -> Tuple["FunctionalModuleWithBuffers", Tuple[Tensor, ...], Tuple[Tensor, ...]]:
  231. # TODO: We don't need to copy the model to create a stateless copy
  232. model_copy = copy.deepcopy(model)
  233. params, param_names, param_names_map = extract_weights(model_copy)
  234. buffers, buffer_names, buffer_names_map = extract_buffers(model_copy)
  235. if disable_autograd_tracking:
  236. for param in params:
  237. param.requires_grad_(False)
  238. return (
  239. FunctionalModuleWithBuffers(
  240. model_copy, param_names, buffer_names, param_names_map, buffer_names_map
  241. ),
  242. params,
  243. buffers,
  244. )
  245. def forward(
  246. self, params: Iterable[Tensor], buffers: Iterable[Tensor], *args, **kwargs
  247. ) -> Any:
  248. # Temporarily load the state back onto self.stateless_model
  249. old_state = _swap_state(
  250. self.stateless_model,
  251. self.all_names_map,
  252. tuple(params) + tuple(buffers),
  253. )
  254. try:
  255. return self.stateless_model(*args, **kwargs)
  256. finally:
  257. # Remove the loaded state on self.stateless_model
  258. _swap_state(self.stateless_model, self.all_names_map, old_state)
  259. class FunctionalModule(nn.Module):
  260. """
  261. This is the callable object returned by :func:`make_functional`.
  262. """
  263. def __init__(
  264. self,
  265. stateless_model: nn.Module,
  266. param_names: Tuple[str, ...],
  267. names_map: Dict[str, List[str]],
  268. ) -> None:
  269. super().__init__()
  270. self.stateless_model = stateless_model
  271. self.param_names = param_names
  272. self.names_map = names_map
  273. @staticmethod
  274. def _create_from(
  275. model: nn.Module, disable_autograd_tracking: bool = False
  276. ) -> Tuple["FunctionalModule", Tuple[Tensor, ...]]:
  277. # TODO: We don't need to copy the model to create a stateless copy
  278. model_copy = copy.deepcopy(model)
  279. params, param_names, names_map = extract_weights(model_copy)
  280. if disable_autograd_tracking:
  281. for param in params:
  282. param.requires_grad_(False)
  283. return FunctionalModule(model_copy, param_names, names_map), params
  284. def forward(self, params: Iterable[Tensor], *args, **kwargs) -> Any:
  285. # Temporarily load the state back onto self.stateless_model
  286. old_state = _swap_state(self.stateless_model, self.names_map, params)
  287. try:
  288. return self.stateless_model(*args, **kwargs)
  289. finally:
  290. # Remove the loaded state on self.stateless_model
  291. _swap_state(self.stateless_model, self.names_map, old_state)
  292. def make_functional(
  293. model: nn.Module, disable_autograd_tracking: bool = False
  294. ) -> Tuple[FunctionalModule, Tuple[Tensor, ...]]:
  295. """make_functional(model, disable_autograd_tracking=False) -> func, params
  296. Given a ``torch.nn.Module``, :func:`make_functional` extracts the state
  297. (params) and returns a functional version of the model, ``func``. This
  298. makes it so that it is possible use transforms over the parameters of
  299. ``model``.
  300. ``func`` can be invoked as follows:
  301. .. code-block:: python
  302. import torch
  303. import torch.nn as nn
  304. from functorch import make_functional
  305. x = torch.randn(4, 3)
  306. model = nn.Linear(3, 3)
  307. func, params = make_functional(model)
  308. func(params, x)
  309. And here is an example of applying the grad transform over the parameters
  310. of a model.
  311. .. code-block:: python
  312. import torch
  313. import torch.nn as nn
  314. from functorch import make_functional, grad
  315. x = torch.randn(4, 3)
  316. t = torch.randn(4, 3)
  317. model = nn.Linear(3, 3)
  318. func, params = make_functional(model)
  319. def compute_loss(params, x, t):
  320. y = func(params, x)
  321. return nn.functional.mse_loss(y, t)
  322. grad_weights = grad(compute_loss)(params, x, t)
  323. If the model has any buffers, please use :func:`make_functional_with_buffers` instead.
  324. Args:
  325. model (torch.nn.Module): Input model.
  326. disable_autograd_tracking (bool): Flag to disable gradients tracking for output parameters.
  327. The returned params are unrelated to the set of params from the original model. If False (default),
  328. the params will have ``requires_grad=True`` on them (aka they will be trackable with regular
  329. PyTorch autograd), matching the requires_grad-ness of the params from the original model.
  330. Otherwise, the returned params will have ``requires_grad=False``. Default, False.
  331. If you plan on using regular PyTorch autograd (e.g., if you want to call ``.backward()`` or
  332. ``torch.autograd.grad()``, then set ``disable_autograd_tracking=False``.
  333. Otherwise, if you're only planning on using functorch's gradient transforms,
  334. then please set ``disable_autograd_tracking=True`` to avoid unnecessarily tracking
  335. history with PyTorch autograd.
  336. """
  337. buffers = list(model.buffers())
  338. if len(buffers) > 0:
  339. raise RuntimeError(
  340. "make_functional(model): `model` has buffers. Please use "
  341. "make_functional_with_buffers(model) instead."
  342. )
  343. return FunctionalModule._create_from(
  344. model, disable_autograd_tracking=disable_autograd_tracking
  345. )
  346. def make_functional_with_buffers(
  347. model: nn.Module, disable_autograd_tracking: bool = False
  348. ) -> Tuple[FunctionalModuleWithBuffers, Tuple[Tensor, ...], Tuple[Tensor, ...]]:
  349. """make_functional_with_buffers(model, disable_autograd_tracking=False) -> func, params, buffers
  350. Given a ``torch.nn.Module``, make_functional_with_buffers extracts the
  351. state (params and buffers) and returns a functional version of the model
  352. ``func`` that can be invoked like a function.
  353. ``func`` can be invoked as follows:
  354. .. code-block:: python
  355. import torch
  356. import torch.nn as nn
  357. from functorch import make_functional_with_buffers
  358. x = torch.randn(4, 3)
  359. model = nn.Linear(3, 3)
  360. func, params, buffers = make_functional_with_buffers(model)
  361. func(params, buffers, x)
  362. And here is an example of applying the grad transform over the parameters
  363. of a model:
  364. .. code-block:: python
  365. import torch
  366. import torch.nn as nn
  367. from functorch import make_functional_with_buffers, grad
  368. x = torch.randn(4, 3)
  369. t = torch.randn(4, 3)
  370. model = nn.Linear(3, 3)
  371. func, params, buffers = make_functional_with_buffers(model)
  372. def compute_loss(params, buffers, x, t):
  373. y = func(params, buffers, x)
  374. return nn.functional.mse_loss(y, t)
  375. grad_weights = grad(compute_loss)(params, buffers, x, t)
  376. Args:
  377. model (torch.nn.Module): Input model.
  378. disable_autograd_tracking (bool): Flag to disable gradients tracking for output parameters.
  379. The returned params are unrelated to the set of params from the original model. If False (default),
  380. the params will have ``requires_grad=True`` on them (aka they will be trackable with regular
  381. PyTorch autograd), matching the requires_grad-ness of the params from the original model.
  382. Otherwise, the returned params will have ``requires_grad=False``. Default, False.
  383. If you plan on using regular PyTorch autograd (e.g., if you want to call ``.backward()`` or
  384. ``torch.autograd.grad()``, then set ``disable_autograd_tracking=False``.
  385. Otherwise, if you're only planning on using functorch's gradient transforms,
  386. then please set ``disable_autograd_tracking=True`` to avoid unnecessarily tracking
  387. history with PyTorch autograd.
  388. """
  389. return FunctionalModuleWithBuffers._create_from(
  390. model, disable_autograd_tracking=disable_autograd_tracking
  391. )
  392. def transpose_stack(
  393. tuple_of_tuple_of_tensors: Tuple[Tuple[Tensor, ...], ...]
  394. ) -> Tuple[Tensor, ...]:
  395. tuple_of_tuple_of_tensors = tuple(zip(*tuple_of_tuple_of_tensors))
  396. results = tuple(
  397. torch.stack(shards).detach() for shards in tuple_of_tuple_of_tensors
  398. )
  399. return results
  400. def combine_state_for_ensemble(
  401. models: Sequence[nn.Module],
  402. ) -> Tuple[FunctionalModuleWithBuffers, Tuple[Tensor, ...], Tuple[Tensor, ...]]:
  403. """combine_state_for_ensemble(models) -> func, params, buffers
  404. Prepares a list of torch.nn.Modules for ensembling with :func:`vmap`.
  405. Given a list of ``M`` ``nn.Modules`` of the same class, stacks all of their
  406. parameters and buffers together to make ``params`` and ``buffers``.
  407. Each parameter and buffer in the result will have an additional dimension
  408. of size ``M``.
  409. :func:`combine_state_for_ensemble` also returns ``func``, a functional
  410. version of one of the models in :attr:`models`. One cannot directly run
  411. ``func(params, buffers, *args, **kwargs)`` directly, you probably want to
  412. use ``vmap(func, ...)(params, buffers, *args, **kwargs)``
  413. Here's an example of how to ensemble over a very simple model:
  414. .. code-block:: python
  415. num_models = 5
  416. batch_size = 64
  417. in_features, out_features = 3, 3
  418. models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
  419. data = torch.randn(batch_size, 3)
  420. fmodel, params, buffers = combine_state_for_ensemble(models)
  421. output = vmap(fmodel, (0, 0, None))(params, buffers, data)
  422. assert output.shape == (num_models, batch_size, out_features)
  423. .. warning::
  424. All of the modules being stacked together must be the same (except for
  425. the values of their parameters/buffers). For example, they should be in the
  426. same mode (training vs eval).
  427. This API is subject to change -- we're investigating better ways to
  428. create ensembles and would love your feedback how to improve this.
  429. """
  430. if len(models) == 0:
  431. raise RuntimeError(
  432. "combine_state_for_ensemble: Expected at least one model, got 0."
  433. )
  434. if not (all(m.training for m in models) or all(not m.training for m in models)):
  435. raise RuntimeError(
  436. "combine_state_for_ensemble: Expected all models to "
  437. "have the same training/eval mode."
  438. )
  439. model0_typ = type(models[0])
  440. if not all(type(m) == model0_typ for m in models):
  441. raise RuntimeError(
  442. "combine_state_for_ensemble: Expected all models to be of the same class."
  443. )
  444. funcs, params, buffers = zip(
  445. *[make_functional_with_buffers(model) for model in models]
  446. )
  447. params = transpose_stack(params)
  448. buffers = transpose_stack(buffers)
  449. return funcs[0], params, buffers
  450. def functional_init(
  451. model_class: Type[nn.Module],
  452. ensemble_shape: Union[Tuple[()], Tuple[int]] = (),
  453. device: torch.types.Device = "cpu",
  454. ):
  455. def wrapped(*args, **kwargs):
  456. if len(ensemble_shape) >= 2:
  457. raise ValueError("NYI: ensemble_shape with more than 1 element")
  458. if len(ensemble_shape) == 0:
  459. model = model_class(*args, **kwargs).to(device)
  460. return make_functional_deprecated_v1(model)
  461. num_models = ensemble_shape[0] # type: ignore[misc]
  462. if num_models <= 0:
  463. raise ValueError(f"num_models {num_models} should be > 0")
  464. # NB: Not very efficient, more of a POC
  465. models = tuple(
  466. model_class(*args, **kwargs).to(device) for _ in range(num_models)
  467. )
  468. _, fn, names = make_functional_deprecated_v1(model_class(*args, **kwargs))
  469. weights = tuple(make_functional_deprecated_v1(model)[0] for model in models)
  470. weights = tuple(zip(*weights))
  471. weights = tuple(torch.stack(shards).detach() for shards in weights)
  472. return weights, fn, names
  473. return wrapped
  474. def functional_init_with_buffers(
  475. model_class: Type[nn.Module],
  476. ensemble_shape: Union[Tuple[()], Tuple[int]] = (),
  477. device: torch.types.Device = "cpu",
  478. ):
  479. def wrapped(*args, **kwargs):
  480. if len(ensemble_shape) >= 2:
  481. raise ValueError("NYI: ensemble_shape with more than 1 element")
  482. if len(ensemble_shape) == 0:
  483. model = model_class(*args, **kwargs).to(device)
  484. return make_functional_deprecated_v1(model)
  485. num_models = ensemble_shape[0] # type: ignore[misc]
  486. if num_models <= 0:
  487. raise ValueError(f"num_models {num_models} should be > 0")
  488. # NB: Not very efficient, more of a POC
  489. models = tuple(
  490. model_class(*args, **kwargs).to(device) for _ in range(num_models)
  491. )
  492. (
  493. _,
  494. _,
  495. fn,
  496. weight_names,
  497. buffer_names,
  498. ) = make_functional_with_buffers_deprecated_v1(model_class(*args, **kwargs))
  499. weights, buffers = zip(
  500. *tuple(
  501. make_functional_with_buffers_deprecated_v1(model)[:2]
  502. for model in models
  503. )
  504. )
  505. weights = tuple(zip(*weights))
  506. weights = tuple(torch.stack(shards).detach() for shards in weights)
  507. buffers = tuple(zip(*buffers))
  508. buffers = tuple(torch.stack(shards).detach() for shards in buffers)
  509. return weights, buffers, fn, weight_names, buffer_names
  510. return wrapped