# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch

from . import _Tensor, Tensor
from .reference import _dims, _enable_layers, llist, ltuple


class DelayedMulTensor(_Tensor):
    """A lazy elementwise product ``lhs * rhs``.

    The multiply is deferred so that a following reduction (see ``sum``)
    can be lowered to a single ``torch.einsum`` call instead of
    materializing the full broadcasted product.
    """

    def __init__(self, lhs, rhs):
        self._lhs, self._rhs = lhs, rhs
        self._data = None
        self._levels_data = None
        self._has_device = lhs._has_device or rhs._has_device
        self._batchtensor_data = None
        self._tensor_data = None

    @property
    def _levels(self):
        # Union of the operands' levels: lhs levels first, followed by any
        # levels that appear only on the rhs, computed once and cached.
        if self._levels_data is None:
            levels = llist(self._lhs._levels)
            for l in self._rhs._levels:
                if l not in levels:
                    levels.append(l)
            self._levels_data = ltuple(levels)
        return self._levels_data

    @property
    def _batchtensor(self):
        # Fallback path: actually materialize the product as a batched
        # tensor when something other than a fused reduction needs it.
        if self._batchtensor_data is None:
            with _enable_layers(self._levels):
                print("bt multiply fallback")
                self._batchtensor_data = (
                    self._lhs._batchtensor * self._rhs._batchtensor
                )
        return self._batchtensor_data

    @property
    def _tensor(self):
        if self._tensor_data is None:
            self._tensor_data = Tensor.from_batched(
                self._batchtensor, self._has_device
            )._tensor
        return self._tensor_data

    @property
    def ndim(self):
        return self._batchtensor.ndim

    @property
    def dims(self):
        return ltuple(super().dims)
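
    # Worked example for the lowering in ``sum`` below, with hypothetical
    # levels: if lhs has levels (i, j), rhs has levels (j, k), and we sum
    # over j, the levels map to subscripts a, b, c and ``fmt`` becomes
    # "ab,bc->ac", so the multiply and the reduction fuse into one
    # matmul-like einsum.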
    def sum(self, dim):
        dims = _dims(dim, 0, False, False)
        n = ord("a")
        all_levels = self._levels

        def to_char(d):
            # Assign each level a distinct single-letter einsum subscript.
            return chr(n + all_levels.index(d))

        plhs, levelslhs = self._lhs._tensor, self._lhs._levels
        prhs, levelsrhs = self._rhs._tensor, self._rhs._levels
        new_levels = [l for l in self._levels if l not in dims]
        fmt = "".join(
            [
                *(to_char(d) for d in levelslhs),
                ",",
                *(to_char(d) for d in levelsrhs),
                "->",
                *(to_char(d) for d in new_levels),
            ]
        )
        result_data = torch.einsum(fmt, plhs, prhs)
        return Tensor.from_positional(result_data, new_levels, True)
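

# A minimal usage sketch (added for illustration, not part of the original
# module), assuming the functorch.dim first-class-dim frontend. In this
# Python reference implementation, multiplying two dim-bound tensors yields
# a DelayedMulTensor, and the following sum fuses into one einsum. The
# concrete sizes (3, 4, 5) and dim names (i, j, k) are illustrative only.
if __name__ == "__main__":
    from functorch.dim import dims

    i, j, k = dims(3)
    x = torch.randn(3, 4)[i, j]  # bind positional dims to first-class dims
    y = torch.randn(4, 5)[j, k]
    # x * y is a DelayedMulTensor; summing over the shared dim j lowers the
    # pair to a single "ab,bc->ac"-style einsum rather than a broadcasted
    # multiply followed by a separate reduction.
    z = (x * y).sum(j)
    print(z.dims)  # expected: (i, k)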