utils.py

import contextlib

import torch
from torch._C._functorch import (
    set_single_level_autograd_function_allowed,
    get_single_level_autograd_function_allowed,
    unwrap_if_dead,
)


# Temporarily allow the use of single-level autograd.Function, restoring the
# previous setting on exit (even if the body raises).
@contextlib.contextmanager
def enable_single_level_autograd_function():
    try:
        prev_state = get_single_level_autograd_function_allowed()
        set_single_level_autograd_function_allowed(True)
        yield
    finally:
        set_single_level_autograd_function_allowed(prev_state)
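
# Illustrative usage sketch (the call site below is an assumption for
# illustration, not part of this file):
#
#   with enable_single_level_autograd_function():
#       # a single-level autograd.Function may be applied here; the previous
#       # allowed/disallowed state is restored when the block exits
#       ...
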
def unwrap_dead_wrappers(args):
    # NB: doesn't use tree_map_only for performance reasons
    result = tuple(
        unwrap_if_dead(arg) if isinstance(arg, torch.Tensor) else arg
        for arg in args
    )
    return result

# Allows one to expose an API defined in a private submodule as public, per the
# definition in PyTorch's public API policy.
#
# It is a temporary solution while we figure out whether it should be the
# long-term solution or whether we should amend PyTorch's public API policy.
# The concern is that this approach may not be very robust, because it's not
# clear what __module__ is used for. However, both numpy and jax overwrite the
# __module__ attribute of their APIs without problems, so it seems fine.
def exposed_in(module):
    def wrapper(fn):
        fn.__module__ = module
        return fn
    return wrapper
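
# Illustrative usage sketch (the decorated function and the target module name
# below are assumptions for illustration, not part of this file):
#
#   @exposed_in("torch.func")
#   def some_public_api(x):
#       ...
#
# After decoration, some_public_api.__module__ == "torch.func", so tooling that
# inspects __module__ (documentation, public-API checks) treats the function as
# part of the public torch.func namespace rather than the private submodule
# where it is actually defined.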