__future__.py 815 B

123456789101112131415161718192021
  1. """
  2. This global flag controls whether to assign new tensors to the parameters
  3. instead of changing the existing parameters in-place when converting an `nn.Module`
  4. using the following methods:
  5. 1. `module.cuda()` / `.cpu()` (for moving `module` between devices)
  6. 2. `module.float()` / `.double()` / `.half()` (for converting `module` to a different dtype)
  7. 3. `module.to()` / `.type()` (for changing `module`'s device or dtype)
  8. 4. `module._apply(fn)` (for generic functions applied to `module`)
  9. Default: False
  10. """
  11. _overwrite_module_params_on_conversion = False
  12. def set_overwrite_module_params_on_conversion(value):
  13. global _overwrite_module_params_on_conversion
  14. _overwrite_module_params_on_conversion = value
  15. def get_overwrite_module_params_on_conversion():
  16. return _overwrite_module_params_on_conversion