container.py 34 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918
  1. import warnings
  2. from collections import OrderedDict, abc as container_abcs
  3. from itertools import chain, islice
  4. import operator
  5. import torch
  6. from .module import Module
  7. from ..parameter import Parameter
  8. from torch._jit_internal import _copy_to_script_wrapper
  9. from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, overload, Tuple, TypeVar, Union
  10. __all__ = ['Container', 'Sequential', 'ModuleList', 'ModuleDict', 'ParameterList', 'ParameterDict']
  11. T = TypeVar('T', bound=Module)
  12. # Copied from torch.nn.modules.module, required for a cusom __repr__ for ModuleList
  13. def _addindent(s_, numSpaces):
  14. s = s_.split('\n')
  15. # don't do anything for single-line stuff
  16. if len(s) == 1:
  17. return s_
  18. first = s.pop(0)
  19. s = [(numSpaces * ' ') + line for line in s]
  20. s = '\n'.join(s)
  21. s = first + '\n' + s
  22. return s
  23. class Container(Module):
  24. def __init__(self, **kwargs: Any) -> None:
  25. super().__init__()
  26. # DeprecationWarning is ignored by default <sigh>
  27. warnings.warn("nn.Container is deprecated. All of it's functionality "
  28. "is now implemented in nn.Module. Subclass that instead.")
  29. for key, value in kwargs.items():
  30. self.add_module(key, value)
  31. class Sequential(Module):
  32. r"""A sequential container.
  33. Modules will be added to it in the order they are passed in the
  34. constructor. Alternatively, an ``OrderedDict`` of modules can be
  35. passed in. The ``forward()`` method of ``Sequential`` accepts any
  36. input and forwards it to the first module it contains. It then
  37. "chains" outputs to inputs sequentially for each subsequent module,
  38. finally returning the output of the last module.
  39. The value a ``Sequential`` provides over manually calling a sequence
  40. of modules is that it allows treating the whole container as a
  41. single module, such that performing a transformation on the
  42. ``Sequential`` applies to each of the modules it stores (which are
  43. each a registered submodule of the ``Sequential``).
  44. What's the difference between a ``Sequential`` and a
  45. :class:`torch.nn.ModuleList`? A ``ModuleList`` is exactly what it
  46. sounds like--a list for storing ``Module`` s! On the other hand,
  47. the layers in a ``Sequential`` are connected in a cascading way.
  48. Example::
  49. # Using Sequential to create a small model. When `model` is run,
  50. # input will first be passed to `Conv2d(1,20,5)`. The output of
  51. # `Conv2d(1,20,5)` will be used as the input to the first
  52. # `ReLU`; the output of the first `ReLU` will become the input
  53. # for `Conv2d(20,64,5)`. Finally, the output of
  54. # `Conv2d(20,64,5)` will be used as input to the second `ReLU`
  55. model = nn.Sequential(
  56. nn.Conv2d(1,20,5),
  57. nn.ReLU(),
  58. nn.Conv2d(20,64,5),
  59. nn.ReLU()
  60. )
  61. # Using Sequential with OrderedDict. This is functionally the
  62. # same as the above code
  63. model = nn.Sequential(OrderedDict([
  64. ('conv1', nn.Conv2d(1,20,5)),
  65. ('relu1', nn.ReLU()),
  66. ('conv2', nn.Conv2d(20,64,5)),
  67. ('relu2', nn.ReLU())
  68. ]))
  69. """
  70. _modules: Dict[str, Module] # type: ignore[assignment]
  71. @overload
  72. def __init__(self, *args: Module) -> None:
  73. ...
  74. @overload
  75. def __init__(self, arg: 'OrderedDict[str, Module]') -> None:
  76. ...
  77. def __init__(self, *args):
  78. super().__init__()
  79. if len(args) == 1 and isinstance(args[0], OrderedDict):
  80. for key, module in args[0].items():
  81. self.add_module(key, module)
  82. else:
  83. for idx, module in enumerate(args):
  84. self.add_module(str(idx), module)
  85. def _get_item_by_idx(self, iterator, idx) -> T:
  86. """Get the idx-th item of the iterator"""
  87. size = len(self)
  88. idx = operator.index(idx)
  89. if not -size <= idx < size:
  90. raise IndexError('index {} is out of range'.format(idx))
  91. idx %= size
  92. return next(islice(iterator, idx, None))
  93. @_copy_to_script_wrapper
  94. def __getitem__(self, idx: Union[slice, int]) -> Union['Sequential', T]:
  95. if isinstance(idx, slice):
  96. return self.__class__(OrderedDict(list(self._modules.items())[idx]))
  97. else:
  98. return self._get_item_by_idx(self._modules.values(), idx)
  99. def __setitem__(self, idx: int, module: Module) -> None:
  100. key: str = self._get_item_by_idx(self._modules.keys(), idx)
  101. return setattr(self, key, module)
  102. def __delitem__(self, idx: Union[slice, int]) -> None:
  103. if isinstance(idx, slice):
  104. for key in list(self._modules.keys())[idx]:
  105. delattr(self, key)
  106. else:
  107. key = self._get_item_by_idx(self._modules.keys(), idx)
  108. delattr(self, key)
  109. # To preserve numbering
  110. str_indices = [str(i) for i in range(len(self._modules))]
  111. self._modules = OrderedDict(list(zip(str_indices, self._modules.values())))
  112. @_copy_to_script_wrapper
  113. def __len__(self) -> int:
  114. return len(self._modules)
  115. def __add__(self, other) -> 'Sequential':
  116. if isinstance(other, Sequential):
  117. ret = Sequential()
  118. for layer in self:
  119. ret.append(layer)
  120. for layer in other:
  121. ret.append(layer)
  122. return ret
  123. else:
  124. raise ValueError('add operator supports only objects '
  125. 'of Sequential class, but {} is given.'.format(
  126. str(type(other))))
  127. def pop(self, key: Union[int, slice]) -> Module:
  128. v = self[key]
  129. del self[key]
  130. return v
  131. def __iadd__(self, other) -> 'Sequential':
  132. if isinstance(other, Sequential):
  133. offset = len(self)
  134. for i, module in enumerate(other):
  135. self.add_module(str(i + offset), module)
  136. return self
  137. else:
  138. raise ValueError('add operator supports only objects '
  139. 'of Sequential class, but {} is given.'.format(
  140. str(type(other))))
  141. def __mul__(self, other: int) -> 'Sequential':
  142. if not isinstance(other, int):
  143. raise TypeError(f"unsupported operand type(s) for *: {type(self)} and {type(other)}")
  144. elif (other <= 0):
  145. raise ValueError(f"Non-positive multiplication factor {other} for {type(self)}")
  146. else:
  147. combined = Sequential()
  148. offset = 0
  149. for _ in range(other):
  150. for module in self:
  151. combined.add_module(str(offset), module)
  152. offset += 1
  153. return combined
  154. def __rmul__(self, other: int) -> 'Sequential':
  155. return self.__mul__(other)
  156. def __imul__(self, other: int) -> 'Sequential':
  157. if not isinstance(other, int):
  158. raise TypeError(f"unsupported operand type(s) for *: {type(self)} and {type(other)}")
  159. elif (other <= 0):
  160. raise ValueError(f"Non-positive multiplication factor {other} for {type(self)}")
  161. else:
  162. len_original = len(self)
  163. offset = len(self)
  164. for _ in range(other - 1):
  165. for i in range(len_original):
  166. self.add_module(str(i + offset), self._modules[str(i)])
  167. offset += len_original
  168. return self
  169. @_copy_to_script_wrapper
  170. def __dir__(self):
  171. keys = super().__dir__()
  172. keys = [key for key in keys if not key.isdigit()]
  173. return keys
  174. @_copy_to_script_wrapper
  175. def __iter__(self) -> Iterator[Module]:
  176. return iter(self._modules.values())
  177. # NB: We can't really type check this function as the type of input
  178. # may change dynamically (as is tested in
  179. # TestScript.test_sequential_intermediary_types). Cannot annotate
  180. # with Any as TorchScript expects a more precise type
  181. def forward(self, input):
  182. for module in self:
  183. input = module(input)
  184. return input
  185. def append(self, module: Module) -> 'Sequential':
  186. r"""Appends a given module to the end.
  187. Args:
  188. module (nn.Module): module to append
  189. """
  190. self.add_module(str(len(self)), module)
  191. return self
  192. def insert(self, index: int, module: Module) -> 'Sequential':
  193. if not isinstance(module, Module):
  194. raise AssertionError(
  195. 'module should be of type: {}'.format(Module))
  196. n = len(self._modules)
  197. if not (-n <= index <= n):
  198. raise IndexError(
  199. 'Index out of range: {}'.format(index))
  200. if index < 0:
  201. index += n
  202. for i in range(n, index, -1):
  203. self._modules[str(i)] = self._modules[str(i - 1)]
  204. self._modules[str(index)] = module
  205. return self
  206. def extend(self, sequential) -> 'Sequential':
  207. for layer in sequential:
  208. self.append(layer)
  209. return self
  210. class ModuleList(Module):
  211. r"""Holds submodules in a list.
  212. :class:`~torch.nn.ModuleList` can be indexed like a regular Python list, but
  213. modules it contains are properly registered, and will be visible by all
  214. :class:`~torch.nn.Module` methods.
  215. Args:
  216. modules (iterable, optional): an iterable of modules to add
  217. Example::
  218. class MyModule(nn.Module):
  219. def __init__(self):
  220. super().__init__()
  221. self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])
  222. def forward(self, x):
  223. # ModuleList can act as an iterable, or be indexed using ints
  224. for i, l in enumerate(self.linears):
  225. x = self.linears[i // 2](x) + l(x)
  226. return x
  227. """
  228. _modules: Dict[str, Module] # type: ignore[assignment]
  229. def __init__(self, modules: Optional[Iterable[Module]] = None) -> None:
  230. super().__init__()
  231. if modules is not None:
  232. self += modules
  233. def _get_abs_string_index(self, idx):
  234. """Get the absolute index for the list of modules"""
  235. idx = operator.index(idx)
  236. if not (-len(self) <= idx < len(self)):
  237. raise IndexError('index {} is out of range'.format(idx))
  238. if idx < 0:
  239. idx += len(self)
  240. return str(idx)
  241. @_copy_to_script_wrapper
  242. def __getitem__(self, idx: Union[int, slice]) -> Union[Module, 'ModuleList']:
  243. if isinstance(idx, slice):
  244. return self.__class__(list(self._modules.values())[idx])
  245. else:
  246. return self._modules[self._get_abs_string_index(idx)]
  247. def __setitem__(self, idx: int, module: Module) -> None:
  248. idx = self._get_abs_string_index(idx)
  249. return setattr(self, str(idx), module)
  250. def __delitem__(self, idx: Union[int, slice]) -> None:
  251. if isinstance(idx, slice):
  252. for k in range(len(self._modules))[idx]:
  253. delattr(self, str(k))
  254. else:
  255. delattr(self, self._get_abs_string_index(idx))
  256. # To preserve numbering, self._modules is being reconstructed with modules after deletion
  257. str_indices = [str(i) for i in range(len(self._modules))]
  258. self._modules = OrderedDict(list(zip(str_indices, self._modules.values())))
  259. @_copy_to_script_wrapper
  260. def __len__(self) -> int:
  261. return len(self._modules)
  262. @_copy_to_script_wrapper
  263. def __iter__(self) -> Iterator[Module]:
  264. return iter(self._modules.values())
  265. def __iadd__(self, modules: Iterable[Module]) -> 'ModuleList':
  266. return self.extend(modules)
  267. def __add__(self, other: Iterable[Module]) -> 'ModuleList':
  268. combined = ModuleList()
  269. for i, module in enumerate(chain(self, other)):
  270. combined.add_module(str(i), module)
  271. return combined
  272. def __repr__(self):
  273. """A custom repr for ModuleList that compresses repeated module representations"""
  274. list_of_reprs = [repr(item) for item in self]
  275. if len(list_of_reprs) == 0:
  276. return self._get_name() + '()'
  277. start_end_indices = [[0, 0]]
  278. repeated_blocks = [list_of_reprs[0]]
  279. for i, r in enumerate(list_of_reprs[1:], 1):
  280. if r == repeated_blocks[-1]:
  281. start_end_indices[-1][1] += 1
  282. continue
  283. start_end_indices.append([i, i])
  284. repeated_blocks.append(r)
  285. lines = []
  286. main_str = self._get_name() + '('
  287. for (start_id, end_id), b in zip(start_end_indices, repeated_blocks):
  288. local_repr = f"({start_id}): {b}" # default repr
  289. if start_id != end_id:
  290. n = end_id - start_id + 1
  291. local_repr = f"({start_id}-{end_id}): {n} x {b}"
  292. local_repr = _addindent(local_repr, 2)
  293. lines.append(local_repr)
  294. main_str += '\n ' + '\n '.join(lines) + '\n'
  295. main_str += ')'
  296. return main_str
  297. @_copy_to_script_wrapper
  298. def __dir__(self):
  299. keys = super().__dir__()
  300. keys = [key for key in keys if not key.isdigit()]
  301. return keys
  302. def insert(self, index: int, module: Module) -> None:
  303. r"""Insert a given module before a given index in the list.
  304. Args:
  305. index (int): index to insert.
  306. module (nn.Module): module to insert
  307. """
  308. for i in range(len(self._modules), index, -1):
  309. self._modules[str(i)] = self._modules[str(i - 1)]
  310. self._modules[str(index)] = module
  311. def append(self, module: Module) -> 'ModuleList':
  312. r"""Appends a given module to the end of the list.
  313. Args:
  314. module (nn.Module): module to append
  315. """
  316. self.add_module(str(len(self)), module)
  317. return self
  318. def pop(self, key: Union[int, slice]) -> Module:
  319. v = self[key]
  320. del self[key]
  321. return v
  322. def extend(self, modules: Iterable[Module]) -> 'ModuleList':
  323. r"""Appends modules from a Python iterable to the end of the list.
  324. Args:
  325. modules (iterable): iterable of modules to append
  326. """
  327. if not isinstance(modules, container_abcs.Iterable):
  328. raise TypeError("ModuleList.extend should be called with an "
  329. "iterable, but got " + type(modules).__name__)
  330. offset = len(self)
  331. for i, module in enumerate(modules):
  332. self.add_module(str(offset + i), module)
  333. return self
  334. # remove forward alltogether to fallback on Module's _forward_unimplemented
  335. class ModuleDict(Module):
  336. r"""Holds submodules in a dictionary.
  337. :class:`~torch.nn.ModuleDict` can be indexed like a regular Python dictionary,
  338. but modules it contains are properly registered, and will be visible by all
  339. :class:`~torch.nn.Module` methods.
  340. :class:`~torch.nn.ModuleDict` is an **ordered** dictionary that respects
  341. * the order of insertion, and
  342. * in :meth:`~torch.nn.ModuleDict.update`, the order of the merged
  343. ``OrderedDict``, ``dict`` (started from Python 3.6) or another
  344. :class:`~torch.nn.ModuleDict` (the argument to
  345. :meth:`~torch.nn.ModuleDict.update`).
  346. Note that :meth:`~torch.nn.ModuleDict.update` with other unordered mapping
  347. types (e.g., Python's plain ``dict`` before Python version 3.6) does not
  348. preserve the order of the merged mapping.
  349. Args:
  350. modules (iterable, optional): a mapping (dictionary) of (string: module)
  351. or an iterable of key-value pairs of type (string, module)
  352. Example::
  353. class MyModule(nn.Module):
  354. def __init__(self):
  355. super().__init__()
  356. self.choices = nn.ModuleDict({
  357. 'conv': nn.Conv2d(10, 10, 3),
  358. 'pool': nn.MaxPool2d(3)
  359. })
  360. self.activations = nn.ModuleDict([
  361. ['lrelu', nn.LeakyReLU()],
  362. ['prelu', nn.PReLU()]
  363. ])
  364. def forward(self, x, choice, act):
  365. x = self.choices[choice](x)
  366. x = self.activations[act](x)
  367. return x
  368. """
  369. _modules: Dict[str, Module] # type: ignore[assignment]
  370. def __init__(self, modules: Optional[Mapping[str, Module]] = None) -> None:
  371. super().__init__()
  372. if modules is not None:
  373. self.update(modules)
  374. @_copy_to_script_wrapper
  375. def __getitem__(self, key: str) -> Module:
  376. return self._modules[key]
  377. def __setitem__(self, key: str, module: Module) -> None:
  378. self.add_module(key, module)
  379. def __delitem__(self, key: str) -> None:
  380. del self._modules[key]
  381. @_copy_to_script_wrapper
  382. def __len__(self) -> int:
  383. return len(self._modules)
  384. @_copy_to_script_wrapper
  385. def __iter__(self) -> Iterator[str]:
  386. return iter(self._modules)
  387. @_copy_to_script_wrapper
  388. def __contains__(self, key: str) -> bool:
  389. return key in self._modules
  390. def clear(self) -> None:
  391. """Remove all items from the ModuleDict.
  392. """
  393. self._modules.clear()
  394. def pop(self, key: str) -> Module:
  395. r"""Remove key from the ModuleDict and return its module.
  396. Args:
  397. key (str): key to pop from the ModuleDict
  398. """
  399. v = self[key]
  400. del self[key]
  401. return v
  402. @_copy_to_script_wrapper
  403. def keys(self) -> Iterable[str]:
  404. r"""Return an iterable of the ModuleDict keys.
  405. """
  406. return self._modules.keys()
  407. @_copy_to_script_wrapper
  408. def items(self) -> Iterable[Tuple[str, Module]]:
  409. r"""Return an iterable of the ModuleDict key/value pairs.
  410. """
  411. return self._modules.items()
  412. @_copy_to_script_wrapper
  413. def values(self) -> Iterable[Module]:
  414. r"""Return an iterable of the ModuleDict values.
  415. """
  416. return self._modules.values()
  417. def update(self, modules: Mapping[str, Module]) -> None:
  418. r"""Update the :class:`~torch.nn.ModuleDict` with the key-value pairs from a
  419. mapping or an iterable, overwriting existing keys.
  420. .. note::
  421. If :attr:`modules` is an ``OrderedDict``, a :class:`~torch.nn.ModuleDict`, or
  422. an iterable of key-value pairs, the order of new elements in it is preserved.
  423. Args:
  424. modules (iterable): a mapping (dictionary) from string to :class:`~torch.nn.Module`,
  425. or an iterable of key-value pairs of type (string, :class:`~torch.nn.Module`)
  426. """
  427. if not isinstance(modules, container_abcs.Iterable):
  428. raise TypeError("ModuleDict.update should be called with an "
  429. "iterable of key/value pairs, but got " +
  430. type(modules).__name__)
  431. if isinstance(modules, (OrderedDict, ModuleDict, container_abcs.Mapping)):
  432. for key, module in modules.items():
  433. self[key] = module
  434. else:
  435. # modules here can be a list with two items
  436. for j, m in enumerate(modules):
  437. if not isinstance(m, container_abcs.Iterable):
  438. raise TypeError("ModuleDict update sequence element "
  439. "#" + str(j) + " should be Iterable; is" +
  440. type(m).__name__)
  441. if not len(m) == 2:
  442. raise ValueError("ModuleDict update sequence element "
  443. "#" + str(j) + " has length " + str(len(m)) +
  444. "; 2 is required")
  445. # modules can be Mapping (what it's typed at), or a list: [(name1, module1), (name2, module2)]
  446. # that's too cumbersome to type correctly with overloads, so we add an ignore here
  447. self[m[0]] = m[1] # type: ignore[assignment]
  448. # remove forward alltogether to fallback on Module's _forward_unimplemented
  449. class ParameterList(Module):
  450. r"""Holds parameters in a list.
  451. :class:`~torch.nn.ParameterList` can be used like a regular Python
  452. list, but Tensors that are :class:`~torch.nn.Parameter` are properly registered,
  453. and will be visible by all :class:`~torch.nn.Module` methods.
  454. Note that the constructor, assigning an element of the list, the
  455. :meth:`~torch.nn.ParameterDict.append` method and the :meth:`~torch.nn.ParameterDict.extend`
  456. method will convert any :class:`~torch.Tensor` into :class:`~torch.nn.Parameter`.
  457. Args:
  458. parameters (iterable, optional): an iterable of elements to add to the list.
  459. Example::
  460. class MyModule(nn.Module):
  461. def __init__(self):
  462. super().__init__()
  463. self.params = nn.ParameterList([nn.Parameter(torch.randn(10, 10)) for i in range(10)])
  464. def forward(self, x):
  465. # ParameterList can act as an iterable, or be indexed using ints
  466. for i, p in enumerate(self.params):
  467. x = self.params[i // 2].mm(x) + p.mm(x)
  468. return x
  469. """
  470. def __init__(self, values: Optional[Iterable[Any]] = None) -> None:
  471. super().__init__()
  472. self._size = 0
  473. if values is not None:
  474. self += values
  475. def _get_abs_string_index(self, idx):
  476. """Get the absolute index for the list of modules"""
  477. idx = operator.index(idx)
  478. if not (-len(self) <= idx < len(self)):
  479. raise IndexError('index {} is out of range'.format(idx))
  480. if idx < 0:
  481. idx += len(self)
  482. return str(idx)
  483. @overload
  484. def __getitem__(self, idx: int) -> Any:
  485. ...
  486. @overload
  487. def __getitem__(self: T, idx: slice) -> T:
  488. ...
  489. def __getitem__(self, idx):
  490. if isinstance(idx, slice):
  491. start, stop, step = idx.indices(len(self))
  492. out = self.__class__()
  493. for i in range(start, stop, step):
  494. out.append(self[i])
  495. return out
  496. else:
  497. idx = self._get_abs_string_index(idx)
  498. return getattr(self, str(idx))
  499. def __setitem__(self, idx: int, param: Any) -> None:
  500. # Note that all other function that add an entry to the list part of
  501. # the ParameterList end up here. So this is the only place where we need
  502. # to wrap things into Parameter if needed.
  503. # Objects added via setattr() are not in the list part and thus won't
  504. # call into this function.
  505. idx = self._get_abs_string_index(idx)
  506. if isinstance(param, torch.Tensor) and not isinstance(param, Parameter):
  507. param = Parameter(param)
  508. return setattr(self, str(idx), param)
  509. def __len__(self) -> int:
  510. return self._size
  511. def __iter__(self) -> Iterator[Any]:
  512. return iter(self[i] for i in range(len(self)))
  513. def __iadd__(self, parameters: Iterable[Any]) -> 'ParameterList':
  514. return self.extend(parameters)
  515. def __dir__(self):
  516. keys = super().__dir__()
  517. keys = [key for key in keys if not key.isdigit()]
  518. return keys
  519. def append(self, value: Any) -> 'ParameterList':
  520. """Appends a given value at the end of the list.
  521. Args:
  522. value (Any): value to append
  523. """
  524. new_idx = len(self)
  525. self._size += 1
  526. self[new_idx] = value
  527. return self
  528. def extend(self, values: Iterable[Any]) -> 'ParameterList':
  529. """Appends values from a Python iterable to the end of the list.
  530. Args:
  531. values (iterable): iterable of values to append
  532. """
  533. # Tensor is an iterable but we never want to unpack it here
  534. if not isinstance(values, container_abcs.Iterable) or isinstance(values, torch.Tensor):
  535. raise TypeError("ParameterList.extend should be called with an "
  536. "iterable, but got " + type(values).__name__)
  537. for value in values:
  538. self.append(value)
  539. return self
  540. def extra_repr(self) -> str:
  541. child_lines = []
  542. for k, p in enumerate(self):
  543. if isinstance(p, torch.Tensor):
  544. size_str = 'x'.join(str(size) for size in p.size())
  545. device_str = '' if not p.is_cuda else ' (GPU {})'.format(p.get_device())
  546. parastr = '{} containing: [{} of size {}{}]'.format(
  547. "Parameter" if isinstance(p, Parameter) else "Tensor",
  548. p.dtype, size_str, device_str)
  549. child_lines.append(' (' + str(k) + '): ' + parastr)
  550. else:
  551. child_lines.append(' (' + str(k) + '): Object of type: ' + type(p).__name__)
  552. tmpstr = '\n'.join(child_lines)
  553. return tmpstr
  554. def __call__(self, *args, **kwargs):
  555. raise RuntimeError('ParameterList should not be called.')
  556. class ParameterDict(Module):
  557. r"""Holds parameters in a dictionary.
  558. ParameterDict can be indexed like a regular Python dictionary, but Parameters it
  559. contains are properly registered, and will be visible by all Module methods.
  560. Other objects are treated as would be done by a regular Python dictionary
  561. :class:`~torch.nn.ParameterDict` is an **ordered** dictionary.
  562. :meth:`~torch.nn.ParameterDict.update` with other unordered mapping
  563. types (e.g., Python's plain ``dict``) does not preserve the order of the
  564. merged mapping. On the other hand, ``OrderedDict`` or another :class:`~torch.nn.ParameterDict`
  565. will preserve their ordering.
  566. Note that the constructor, assigning an element of the dictionary and the
  567. :meth:`~torch.nn.ParameterDict.update` method will convert any :class:`~torch.Tensor` into
  568. :class:`~torch.nn.Parameter`.
  569. Args:
  570. values (iterable, optional): a mapping (dictionary) of
  571. (string : Any) or an iterable of key-value pairs
  572. of type (string, Any)
  573. Example::
  574. class MyModule(nn.Module):
  575. def __init__(self):
  576. super().__init__()
  577. self.params = nn.ParameterDict({
  578. 'left': nn.Parameter(torch.randn(5, 10)),
  579. 'right': nn.Parameter(torch.randn(5, 10))
  580. })
  581. def forward(self, x, choice):
  582. x = self.params[choice].mm(x)
  583. return x
  584. """
  585. def __init__(self, parameters: Any = None) -> None:
  586. super().__init__()
  587. self._keys: Dict[str, None] = {}
  588. if parameters is not None:
  589. self.update(parameters)
  590. def _key_to_attr(self, key: str) -> str:
  591. if not isinstance(key, str):
  592. raise TypeError("Index given to ParameterDict cannot be used as a key as it is "
  593. f"not a string (type is '{type(key).__name__}'). Open an issue on "
  594. "github if you need non-string keys.")
  595. else:
  596. # Use the key as-is so that `.named_parameters()` returns the right thing
  597. return key
  598. def __getitem__(self, key: str) -> Any:
  599. attr = self._key_to_attr(key)
  600. return getattr(self, attr)
  601. def __setitem__(self, key: str, value: Any) -> None:
  602. # Note that all other function that add an entry to the dictionary part of
  603. # the ParameterDict end up here. So this is the only place where we need
  604. # to wrap things into Parameter if needed.
  605. # Objects added via setattr() are not in the dictionary part and thus won't
  606. # call into this function.
  607. self._keys[key] = None
  608. attr = self._key_to_attr(key)
  609. if isinstance(value, torch.Tensor) and not isinstance(value, Parameter):
  610. value = Parameter(value)
  611. setattr(self, attr, value)
  612. def __delitem__(self, key: str) -> None:
  613. del self._keys[key]
  614. attr = self._key_to_attr(key)
  615. delattr(self, attr)
  616. def __len__(self) -> int:
  617. return len(self._keys)
  618. def __iter__(self) -> Iterator[str]:
  619. return iter(self._keys)
  620. def __reversed__(self) -> Iterator[str]:
  621. return reversed(list(self._keys))
  622. def copy(self) -> 'ParameterDict':
  623. """Returns a copy of this :class:`~torch.nn.ParameterDict` instance.
  624. """
  625. # We have to use an OrderedDict because the ParameterDict constructor
  626. # behaves differently on plain dict vs OrderedDict
  627. return ParameterDict(OrderedDict((k, self[k]) for k in self._keys))
  628. def __contains__(self, key: str) -> bool:
  629. return key in self._keys
  630. def setdefault(self, key: str, default: Optional[Any] = None) -> Any:
  631. """If key is in the ParameterDict, return its value.
  632. If not, insert `key` with a parameter `default` and return `default`.
  633. `default` defaults to `None`.
  634. Args:
  635. key (str): key to set default for
  636. default (Any): the parameter set to the key
  637. """
  638. if key not in self:
  639. self[key] = default
  640. return self[key]
  641. def clear(self) -> None:
  642. """Remove all items from the ParameterDict.
  643. """
  644. for k in self._keys.copy():
  645. del self[k]
  646. def pop(self, key: str) -> Any:
  647. r"""Remove key from the ParameterDict and return its parameter.
  648. Args:
  649. key (str): key to pop from the ParameterDict
  650. """
  651. v = self[key]
  652. del self[key]
  653. return v
  654. def popitem(self) -> Tuple[str, Any]:
  655. """Remove and return the last inserted `(key, parameter)` pair
  656. from the ParameterDict
  657. """
  658. k, _ = self._keys.popitem()
  659. # We need the key in the _keys to be able to access/del
  660. self._keys[k] = None
  661. val = self[k]
  662. del self[k]
  663. return k, val
  664. def get(self, key: str, default: Optional[Any] = None) -> Any:
  665. r"""Return the parameter associated with key if present.
  666. Otherwise return default if provided, None if not.
  667. Args:
  668. key (str): key to get from the ParameterDict
  669. default (Parameter, optional): value to return if key not present
  670. """
  671. return self[key] if key in self else default
  672. def fromkeys(self, keys: Iterable[str], default: Optional[Any] = None) -> 'ParameterDict':
  673. r"""Return a new ParameterDict with the keys provided
  674. Args:
  675. keys (iterable, string): keys to make the new ParameterDict from
  676. default (Parameter, optional): value to set for all keys
  677. """
  678. return ParameterDict(((k, default) for k in keys))
  679. def keys(self) -> Iterable[str]:
  680. r"""Return an iterable of the ParameterDict keys.
  681. """
  682. return self._keys.keys()
  683. def items(self) -> Iterable[Tuple[str, Any]]:
  684. r"""Return an iterable of the ParameterDict key/value pairs.
  685. """
  686. return ((k, self[k]) for k in self._keys)
  687. def values(self) -> Iterable[Any]:
  688. r"""Return an iterable of the ParameterDict values.
  689. """
  690. return (self[k] for k in self._keys)
  691. def update(self, parameters: Union[Mapping[str, Any], 'ParameterDict']) -> None:
  692. r"""Update the :class:`~torch.nn.ParameterDict` with the key-value pairs from a
  693. mapping or an iterable, overwriting existing keys.
  694. .. note::
  695. If :attr:`parameters` is an ``OrderedDict``, a :class:`~torch.nn.ParameterDict`, or
  696. an iterable of key-value pairs, the order of new elements in it is preserved.
  697. Args:
  698. parameters (iterable): a mapping (dictionary) from string to
  699. :class:`~torch.nn.Parameter`, or an iterable of
  700. key-value pairs of type (string, :class:`~torch.nn.Parameter`)
  701. """
  702. if not isinstance(parameters, container_abcs.Iterable):
  703. raise TypeError("ParametersDict.update should be called with an "
  704. "iterable of key/value pairs, but got " +
  705. type(parameters).__name__)
  706. if isinstance(parameters, (OrderedDict, ParameterDict)):
  707. for key, parameter in parameters.items():
  708. self[key] = parameter
  709. elif isinstance(parameters, container_abcs.Mapping):
  710. for key, parameter in sorted(parameters.items()):
  711. self[key] = parameter
  712. else:
  713. for j, p in enumerate(parameters):
  714. if not isinstance(p, container_abcs.Iterable):
  715. raise TypeError("ParameterDict update sequence element "
  716. "#" + str(j) + " should be Iterable; is" +
  717. type(p).__name__)
  718. if not len(p) == 2:
  719. raise ValueError("ParameterDict update sequence element "
  720. "#" + str(j) + " has length " + str(len(p)) +
  721. "; 2 is required")
  722. # parameters as length-2 list too cumbersome to type, see ModuleDict.update comment
  723. self[p[0]] = p[1] # type: ignore[assignment]
  724. def extra_repr(self) -> str:
  725. child_lines = []
  726. for k, p in self.items():
  727. if isinstance(p, torch.Tensor):
  728. size_str = 'x'.join(str(size) for size in p.size())
  729. device_str = '' if not p.is_cuda else ' (GPU {})'.format(p.get_device())
  730. parastr = '{} containing: [{} of size {}{}]'.format(
  731. "Parameter" if isinstance(p, Parameter) else "Tensor",
  732. torch.typename(p), size_str, device_str)
  733. child_lines.append(' (' + str(k) + '): ' + parastr)
  734. else:
  735. child_lines.append(' (' + str(k) + '): Object of type: ' + type(p).__name__)
  736. tmpstr = '\n'.join(child_lines)
  737. return tmpstr
  738. def __call__(self, input):
  739. raise RuntimeError('ParameterDict should not be called.')
  740. def __or__(self, other: 'ParameterDict') -> 'ParameterDict':
  741. copy = self.copy()
  742. copy.update(other)
  743. return copy
  744. def __ror__(self, other: 'ParameterDict') -> 'ParameterDict':
  745. copy = other.copy()
  746. copy.update(self)
  747. return copy
  748. def __ior__(self, other : 'ParameterDict') -> 'ParameterDict':
  749. self.update(other)
  750. return self