dropout.py 743 B

123456789101112131415161718192021222324252627
  1. import torch
  __all__ = ['Dropout']


  class Dropout(torch.nn.Dropout):
      r"""Quantized counterpart of :class:`~torch.nn.Dropout`.

      A pass-through placeholder: ``forward`` hands its input straight back,
      whether the module is in training or evaluation mode. It lets models
      whose fp32 version contained a dropout layer keep that layer when they
      are run on quantized tensors.

      Args:
          p: probability of an element to be zeroed (forwarded to
              :class:`~torch.nn.Dropout`; ``forward`` here does not apply it)
          inplace: can optionally do the operation in-place (likewise
              forwarded but not used by ``forward``). Default: ``False``
      """

      def forward(self, input):
          """Return ``input`` unchanged -- no elements are zeroed or rescaled."""
          return input

      def _get_name(self):
          """Return the display name of this module."""
          return 'QuantizedDropout'

      @classmethod
      def _from_dropout_module(cls, mod):
          """Build an instance carrying over ``mod.p`` and ``mod.inplace``."""
          return cls(p=mod.p, inplace=mod.inplace)

      @classmethod
      def from_float(cls, mod):
          """Create a quantized Dropout from a float dropout module ``mod``.

          Only the dropout probability and in-place flag are copied.
          """
          return cls._from_dropout_module(mod)

      @classmethod
      def from_reference(cls, mod, scale, zero_point):
          """Create a quantized Dropout from a reference dropout module ``mod``.

          ``scale`` and ``zero_point`` are accepted for signature parity with
          other ``from_reference`` constructors but are not used: only
          ``mod.p`` and ``mod.inplace`` are copied.
          """
          return cls._from_dropout_module(mod)