init.py 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. import inspect
  2. import torch
  3. def skip_init(module_cls, *args, **kwargs):
  4. r"""
  5. Given a module class object and args / kwargs, instantiates the module without initializing
  6. parameters / buffers. This can be useful if initialization is slow or if custom initialization will
  7. be performed, making the default initialization unnecessary. There are some caveats to this, due to
  8. the way this function is implemented:
  9. 1. The module must accept a `device` arg in its constructor that is passed to any parameters
  10. or buffers created during construction.
  11. 2. The module must not perform any computation on parameters in its constructor except
  12. initialization (i.e. functions from :mod:`torch.nn.init`).
  13. If these conditions are satisfied, the module can be instantiated with parameter / buffer values
  14. uninitialized, as if having been created using :func:`torch.empty`.
  15. Args:
  16. module_cls: Class object; should be a subclass of :class:`torch.nn.Module`
  17. args: args to pass to the module's constructor
  18. kwargs: kwargs to pass to the module's constructor
  19. Returns:
  20. Instantiated module with uninitialized parameters / buffers
  21. Example::
  22. >>> # xdoctest: +IGNORE_WANT("non-deterministic")
  23. >>> import torch
  24. >>> m = torch.nn.utils.skip_init(torch.nn.Linear, 5, 1)
  25. >>> m.weight
  26. Parameter containing:
  27. tensor([[0.0000e+00, 1.5846e+29, 7.8307e+00, 2.5250e-29, 1.1210e-44]],
  28. requires_grad=True)
  29. >>> m2 = torch.nn.utils.skip_init(torch.nn.Linear, in_features=6, out_features=1)
  30. >>> m2.weight
  31. Parameter containing:
  32. tensor([[-1.4677e+24, 4.5915e-41, 1.4013e-45, 0.0000e+00, -1.4677e+24,
  33. 4.5915e-41]], requires_grad=True)
  34. """
  35. if not issubclass(module_cls, torch.nn.Module):
  36. raise RuntimeError('Expected a Module; got {}'.format(module_cls))
  37. if 'device' not in inspect.signature(module_cls).parameters:
  38. raise RuntimeError('Module must support a \'device\' arg to skip initialization')
  39. final_device = kwargs.pop('device', 'cpu')
  40. kwargs['device'] = 'meta'
  41. return module_cls(*args, **kwargs).to_empty(device=final_device)