delayed_mul_tensor.py

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch

from . import _Tensor, Tensor
from .reference import _dims, _enable_layers, llist, ltuple
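
# A DelayedMulTensor is produced when two first-class-dim tensors are
# multiplied elementwise. It records the two operands instead of
# materializing the product, so that a subsequent .sum(dim) can fuse the
# multiply and the reduction into a single torch.einsum call (see sum()
# below); any other use falls back to actually computing the product via
# the _batchtensor property.
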
class DelayedMulTensor(_Tensor):
    def __init__(self, lhs, rhs):
        self._lhs, self._rhs = lhs, rhs
        # Derived views of the product are computed lazily and cached.
        self._data = None
        self._levels_data = None
        self._has_device = lhs._has_device or rhs._has_device
        self._batchtensor_data = None
        self._tensor_data = None

    @property
    def _levels(self):
        if self._levels_data is None:
            # The product carries the union of both operands' levels (dims):
            # lhs levels in order, then any levels unique to rhs.
            levels = llist(self._lhs._levels)
            for l in self._rhs._levels:
                if l not in levels:
                    levels.append(l)
            self._levels_data = ltuple(levels)
        return self._levels_data

    @property
    def _batchtensor(self):
        if self._batchtensor_data is None:
            with _enable_layers(self._levels):
                # Fallback path: something other than sum() forced
                # evaluation, so materialize the full product.
                print("bt multiply fallback")
                self._batchtensor_data = self._lhs._batchtensor * self._rhs._batchtensor
        return self._batchtensor_data

    @property
    def _tensor(self):
        if self._tensor_data is None:
            self._tensor_data = Tensor.from_batched(
                self._batchtensor, self._has_device
            )._tensor
        return self._tensor_data

    @property
    def ndim(self):
        return self._batchtensor.ndim

    @property
    def dims(self):
        return ltuple(super().dims)

    def sum(self, dim):
        dims = _dims(dim, 0, False, False)
        n = ord("a")
        all_levels = self._levels

        def to_char(d):
            # Map each level to a distinct letter for the einsum spec.
            return chr(n + all_levels.index(d))

        plhs, levelslhs = self._lhs._tensor, self._lhs._levels
        prhs, levelsrhs = self._rhs._tensor, self._rhs._levels
        new_levels = [l for l in self._levels if l not in dims]
        # Fuse the deferred multiply and this reduction into one einsum:
        # every level kept by both inputs and absent from the output is
        # contracted, e.g. "ab,bc->ac" for a matmul-style contraction.
        fmt = "".join(
            [
                *(to_char(d) for d in levelslhs),
                ",",
                *(to_char(d) for d in levelsrhs),
                "->",
                *(to_char(d) for d in new_levels),
            ]
        )
        result_data = torch.einsum(fmt, (plhs, prhs))
        return Tensor.from_positional(result_data, new_levels, True)
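
# Usage sketch (an assumption for illustration: users reach this class
# through the public first-class-dims frontend, functorch.dim):
#
#     import torch
#     from functorch.dim import dims
#
#     i, j, k = dims(3)
#     A, B = torch.randn(4, 5), torch.randn(5, 6)
#     # Indexing with the shared dim j makes the product a DelayedMulTensor;
#     # no (4, 5, 6) intermediate is allocated. sum(j) fuses everything into
#     # torch.einsum("ab,bc->ac", ...), i.e. an ordinary matmul.
#     C = (A[i, j] * B[j, k]).sum(j)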