# autocast_mode.py (1.2 KB)
  1. import torch
  2. from typing import Any
  3. __all__ = ["autocast"]
  4. class autocast(torch.amp.autocast_mode.autocast):
  5. r"""
  6. See :class:`torch.autocast`.
  7. ``torch.cpu.amp.autocast(args...)`` is equivalent to ``torch.autocast("cpu", args...)``
  8. """
  9. def __init__(self, enabled : bool = True, dtype : torch.dtype = torch.bfloat16, cache_enabled : bool = True):
  10. if torch._jit_internal.is_scripting():
  11. self._enabled = enabled
  12. self.device = "cpu"
  13. self.fast_dtype = dtype
  14. return
  15. super().__init__("cpu", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled)
  16. def __enter__(self):
  17. if torch._jit_internal.is_scripting():
  18. return self
  19. return super().__enter__()
  20. # TODO: discuss a unified TorchScript-friendly API for autocast
  21. def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any): # type: ignore[override]
  22. if torch._jit_internal.is_scripting():
  23. return
  24. return super().__exit__(exc_type, exc_val, exc_tb)
  25. def __call__(self, func):
  26. if torch._jit_internal.is_scripting():
  27. return func
  28. return super().__call__(func)