creation.py 527 B

123456789101112131415161718192021
  1. # Copyright (c) Meta Platforms, Inc. and affiliates
  2. from .core import MaskedTensor
  3. __all__ = [
  4. "as_masked_tensor",
  5. "masked_tensor",
  6. ]
  7. """"
  8. These two factory functions are intended to mirror
  9. torch.tensor - guaranteed to be a leaf node
  10. torch.as_tensor - differentiable constructor that preserves the autograd history
  11. """
  12. def masked_tensor(data, mask, requires_grad=False):
  13. return MaskedTensor(data, mask, requires_grad)
  14. def as_masked_tensor(data, mask):
  15. return MaskedTensor._from_values(data, mask)