rnn.py 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. import torch.cuda
  2. try:
  3. from torch._C import _cudnn
  4. except ImportError:
  5. # Uses of all the functions below should be guarded by torch.backends.cudnn.is_available(),
  6. # so it's safe to not emit any checks here.
  7. _cudnn = None # type: ignore[assignment]
  8. def get_cudnn_mode(mode):
  9. if mode == 'RNN_RELU':
  10. return int(_cudnn.RNNMode.rnn_relu)
  11. elif mode == 'RNN_TANH':
  12. return int(_cudnn.RNNMode.rnn_tanh)
  13. elif mode == 'LSTM':
  14. return int(_cudnn.RNNMode.lstm)
  15. elif mode == 'GRU':
  16. return int(_cudnn.RNNMode.gru)
  17. else:
  18. raise Exception("Unknown mode: {}".format(mode))
  19. # NB: We don't actually need this class anymore (in fact, we could serialize the
  20. # dropout state for even better reproducibility), but it is kept for backwards
  21. # compatibility for old models.
  22. class Unserializable:
  23. def __init__(self, inner):
  24. self.inner = inner
  25. def get(self):
  26. return self.inner
  27. def __getstate__(self):
  28. # Note: can't return {}, because python2 won't call __setstate__
  29. # if the value evaluates to False
  30. return "<unserializable>"
  31. def __setstate__(self, state):
  32. self.inner = None
  33. def init_dropout_state(dropout, train, dropout_seed, dropout_state):
  34. dropout_desc_name = 'desc_' + str(torch.cuda.current_device())
  35. dropout_p = dropout if train else 0
  36. if (dropout_desc_name not in dropout_state) or (dropout_state[dropout_desc_name].get() is None):
  37. if dropout_p == 0:
  38. dropout_state[dropout_desc_name] = Unserializable(None)
  39. else:
  40. dropout_state[dropout_desc_name] = Unserializable(torch._cudnn_init_dropout_state( # type: ignore[call-arg]
  41. dropout_p,
  42. train,
  43. dropout_seed,
  44. self_ty=torch.uint8,
  45. device=torch.device('cuda')))
  46. dropout_ts = dropout_state[dropout_desc_name].get()
  47. return dropout_ts