import torch
from .tree_map import tree_flatten, tree_map
from .batch_tensor import _enable_layers
from . import op_properties
from functorch._C import dim as _C
DimList = _C.DimList
from functools import reduce
import operator

pointwise = set(op_properties.pointwise)


def prod(x):
    return reduce(operator.mul, x, 1)


def _wrap_dim(d, N, keepdim):
    from . import Dim

    if isinstance(d, Dim):
        assert not keepdim, "cannot preserve first-class dimensions with keepdim=True"
        return d
    elif d >= 0:
        return d - N
    else:
        return d


def _dims(d, N, keepdim, single_dim):
    from . import Dim

    if isinstance(d, (Dim, int)):
        return ltuple((_wrap_dim(d, N, keepdim),))
    assert not single_dim, f"expected a single dimension or int but found: {d}"
    return ltuple(_wrap_dim(x, N, keepdim) for x in d)
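# For illustration: integer dimensions are normalized to negative indices counted from the
# end of the positional dims, e.g. _wrap_dim(1, 4, False) == -3 and
# _wrap_dim(-2, 4, False) == -2, while Dim objects pass through unchanged.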


def _bind_dims_to_size(lhs_size, rhs, lhs_debug):
    from . import DimensionMismatchError

    not_bound = tuple((i, r) for i, r in enumerate(rhs) if not r.is_bound)
    if len(not_bound) == 1:
        idx, d = not_bound[0]
        rhs_so_far = prod(r.size for r in rhs if r.is_bound)
        if lhs_size % rhs_so_far != 0:
            rhs_s = tuple('?' if not r.is_bound else str(r.size) for r in rhs)
            raise DimensionMismatchError(
                f"inferred dimension does not evenly fit into larger dimension: {lhs_size} vs {rhs_s}")
        new_size = lhs_size // rhs_so_far
        d.size = new_size
    elif len(not_bound) > 1:
        rhs_s = tuple('?' if not r.is_bound else str(r.size) for r in rhs)
        raise DimensionMismatchError(
            f"cannot infer the size of two dimensions at once: {rhs} with sizes {rhs_s}")
    else:
        rhs_size = prod(r.size for r in rhs)
        if lhs_size != rhs_size:
            raise DimensionMismatchError(
                f"dimension sizes do not match ({lhs_size} != {rhs_size}) when matching {lhs_debug} to {rhs}")


def _tensor_levels(inp):
    from . import _Tensor

    if isinstance(inp, _Tensor):
        return inp._tensor, llist(inp._levels), inp._has_device
    else:
        return inp, llist(range(-inp.ndim, 0)), True
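# A "levels" list describes the meaning of each dimension of the underlying tensor:
# negative integers are ordinary positional dimensions (counted from the end) and
# Dim objects are first-class dimensions. A plain tensor of ndim n therefore gets
# levels list(range(-n, 0)).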


def _match_levels(v, from_levels, to_levels):
    view = []
    permute = []
    requires_view = False
    size = v.size()
    for t in to_levels:
        try:
            idx = from_levels.index(t)
            permute.append(idx)
            view.append(size[idx])
        except ValueError:
            view.append(1)
            requires_view = True
    if permute != list(range(len(permute))):
        v = v.permute(*permute)
    if requires_view:
        v = v.view(*view)
    return v
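# Sketch of what _match_levels does: it permutes v so its levels appear in the order of
# to_levels and inserts size-1 dimensions for levels v does not carry, so the result
# broadcasts against other tensors laid out according to to_levels. For example, with
# from_levels == [-2, d] and to_levels == [d, -2, -1], a (3, 5) tensor becomes (5, 3, 1).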


def _positional_no_permute(self, dim, expand_dim=False):
    from . import Tensor

    ptensor, levels = self._tensor, llist(self._levels)
    try:
        idx = levels.index(dim)
    except ValueError:
        if not expand_dim:
            raise
        idx = 0
        ptensor = ptensor.expand(dim.size, *ptensor.size())
        levels.insert(0, 0)
    idx_batched = 0
    for i in range(idx):
        if isinstance(levels[i], int):
            levels[i] -= 1
            idx_batched += 1
    levels[idx] = -idx_batched - 1
    return Tensor.from_positional(ptensor, levels, self._has_device), idx_batched


def seq(a, b):
    from . import Dim

    if isinstance(a, Dim) != isinstance(b, Dim):
        return False
    if isinstance(a, Dim):
        return a is b
    else:
        return a == b


class isin:
    def __contains__(self, item):
        for x in self:
            if seq(item, x):
                return True
        return False

    def index(self, item):
        for i, x in enumerate(self):
            if seq(item, x):
                return i
        raise ValueError


class llist(isin, list):
    pass


class ltuple(isin, tuple):
    pass
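# llist/ltuple replace membership tests and .index with the identity-based comparison in
# seq: Dim objects are matched with `is` rather than `==`, presumably because comparing a
# Dim with `==` behaves like a tensor operation rather than returning a plain bool, so
# ordinary list/tuple lookup would not do the right thing here.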


empty_dict = {}


@classmethod
def __torch_function__(self, orig, cls, args, kwargs=empty_dict):
    from . import _Tensor, TensorLike, Tensor
    from .delayed_mul_tensor import DelayedMulTensor

    if orig is torch.Tensor.__mul__:
        lhs, rhs = args
        if isinstance(lhs, _Tensor) and isinstance(rhs, _Tensor) and lhs.ndim == 0 and rhs.ndim == 0:
            return DelayedMulTensor(lhs, rhs)

    all_dims = llist()
    flat_args, unflatten = tree_flatten((args, kwargs))
    device_holding_tensor = None
    for f in flat_args:
        if isinstance(f, _Tensor):
            if f._has_device:
                device_holding_tensor = f._batchtensor
            for d in f.dims:
                if d not in all_dims:
                    all_dims.append(d)

    def unwrap(t):
        if isinstance(t, _Tensor):
            r = t._batchtensor
            if device_holding_tensor is not None and not t._has_device:
                r = r.to(device=device_holding_tensor.device)
            return r
        return t

    if orig in pointwise:
        result_levels = llist()
        arg_levels = llist()
        to_expand = []
        for i, f in enumerate(flat_args):
            if isinstance(f, TensorLike):
                ptensor, levels, _ = _tensor_levels(f)
                if isinstance(f, _Tensor) and not f._has_device and device_holding_tensor is not None:
                    ptensor = ptensor.to(device=device_holding_tensor.device)
                flat_args[i] = ptensor
                for l in levels:
                    if l not in result_levels:
                        result_levels.append(l)
                to_expand.append((i, levels))
        for i, levels in to_expand:
            flat_args[i] = _match_levels(flat_args[i], levels, result_levels)
        args, kwargs = unflatten(flat_args)
        result = orig(*args, **kwargs)

        def wrap(t):
            if isinstance(t, TensorLike):
                return Tensor.from_positional(t, result_levels, device_holding_tensor is not None)
            return t

        return tree_map(wrap, result)
    else:
        def wrap(t):
            if isinstance(t, TensorLike):
                return Tensor.from_batched(t, device_holding_tensor is not None)
            return t

        with _enable_layers(all_dims):
            print(f"batch_tensor for {orig}")
            args, kwargs = unflatten(unwrap(f) for f in flat_args)
            result = orig(*args, **kwargs)
        return tree_map(wrap, result)
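# In short, the override has two paths: pointwise ops unwrap their arguments to plain
# tensors, align them to a common level layout with _match_levels, call the original op,
# and rewrap with Tensor.from_positional; every other op falls back to the vmap-style
# batch-tensor path under _enable_layers and rewraps with Tensor.from_batched. The one
# special case is multiplying two zero-dim first-class tensors, which returns a
# DelayedMulTensor, presumably so the eventual contraction can be handled lazily.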


def positional(self, *dims):
    from . import Dim, DimensionBindError, Tensor

    ptensor, levels = self._tensor, llist(self._levels)
    flat_dims = llist()
    view = []
    needs_view = False
    ndim = self.ndim
    for d in dims:
        if isinstance(d, DimList):
            flat_dims.extend(d)
            view.extend(e.size for e in d)
        elif isinstance(d, Dim):
            flat_dims.append(d)
            view.append(d.size)
        elif isinstance(d, int):
            d = _wrap_dim(d, ndim, False)
            flat_dims.append(d)
            view.append(ptensor.size(d))
        else:
            flat_dims.extend(d)
            view.append(prod(e.size for e in d))
            needs_view = True

    permute = list(range(len(levels)))
    nflat = len(flat_dims)
    for i, d in enumerate(flat_dims):
        try:
            idx = levels.index(d)
        except ValueError as e:
            raise DimensionBindError(f'tensor of dimensions {self.dims} does not contain dim {d}') from e
        p = permute[idx]
        del levels[idx]
        del permute[idx]
        levels.insert(i, 0)
        permute.insert(i, p)
    ptensor = ptensor.permute(*permute)
    seen = 0
    for i in range(len(levels) - 1, -1, -1):
        if isinstance(levels[i], int):
            seen += 1
            levels[i] = -seen
    result = Tensor.from_positional(ptensor, levels, self._has_device)
    if needs_view:
        result = result.reshape(*view, *result.size()[len(flat_dims):])
    return result
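# Illustrative use (a sketch, assuming the `dims()` constructor exported by this package):
#
#   i, j = dims(2)
#   t = torch.rand(3, 4)[i, j]     # t carries first-class dims i (size 3) and j (size 4)
#   t.positional(j, i)             # roughly: a positional tensor of shape (4, 3)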


def _contains_dim(input):
    from . import Dim

    for i in input:
        if isinstance(i, Dim):
            return True


def expand(self, *sizes):
    if not _contains_dim(sizes):
        return self.__torch_function__(torch.Tensor.expand, None, (self, *sizes))
    dims = sizes
    sizes = [d.size for d in dims] + [-1] * self.ndim
    self = self.expand(*sizes)
    return self[dims]
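# When any requested size is a Dim, expand prepends one positional dimension per Dim
# (keeping the existing dimensions via -1) and then indexes the result with those Dims,
# which binds them as first-class dimensions of the expanded tensor.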


_not_present = object()


def _getarg(name, offset, args, kwargs, default):
    if len(args) > offset:
        return args[offset]
    return kwargs.get(name, default)


def _patcharg(name, offset, args, kwargs, value):
    if len(args) > offset:
        args[offset] = value
    else:
        kwargs[name] = value


def _wrap(orig, dim_offset=0, keepdim_offset=1, dim_name='dim', single_dim=False, reduce=True):
    from . import TensorLike, Dim, Tensor

    def fn(self, *args, **kwargs):
        dim = _getarg(dim_name, dim_offset, args, kwargs, _not_present)
        if dim is _not_present or (single_dim and not isinstance(dim, Dim)):
            with _enable_layers(self.dims):
                print(f"dim fallback batch_tensor for {orig}")
                return Tensor.from_batched(orig(self._batchtensor, *args, **kwargs), self._has_device)
        keepdim = _getarg('keepdim', keepdim_offset, args, kwargs, False) if reduce else False
        t, levels = self._tensor, llist(self._levels)
        dims = _dims(dim, self._batchtensor.ndim, keepdim, single_dim)
        dim_indices = tuple(levels.index(d) for d in dims)
        if reduce and not keepdim:
            new_levels = [l for i, l in enumerate(levels) if i not in dim_indices]
        else:
            new_levels = levels
        if len(dim_indices) == 1:
            dim_indices = dim_indices[0]
        args = list(args)
        _patcharg(dim_name, dim_offset, args, kwargs, dim_indices)

        def wrap(t):
            if isinstance(t, TensorLike):
                return Tensor.from_positional(t, new_levels, self._has_device)
            return t

        with _enable_layers(new_levels):
            print(f"dim used batch_tensor for {orig}")
            r = orig(t, *args, **kwargs)
            return tree_map(wrap, r)

    return fn
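# _wrap adapts a torch.Tensor method that takes a `dim`/`keepdim` argument so that `dim`
# may also be a first-class Dim (or a tuple of them): the Dims are translated to positions
# of the underlying tensor, the original method is called, and the result is rewrapped
# with the surviving levels. _def installs such a wrapper on _Tensor, e.g. (hypothetical
# registration; the real ones live outside this excerpt): _def('sum').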


def _def(name, *args, **kwargs):
    from . import _Tensor

    orig = getattr(torch.Tensor, name)
    setattr(_Tensor, name, _wrap(orig, *args, **kwargs))


no_slice = slice(None)

_orig_getitem = torch.Tensor.__getitem__


class dim_tracker:
    def __init__(self):
        self.dims = llist()
        self.count = []

    def record(self, d):
        # count how many times each dim is used so that indexing can tell single-use
        # dims apart from repeated ones
        if d not in self.dims:
            self.dims.append(d)
            self.count.append(1)
        else:
            self.count[self.dims.index(d)] += 1

    def __getitem__(self, d):
        return self.count[self.dims.index(d)]


def t__getitem__(self, input):
    from . import Dim, DimensionBindError, _Tensor, TensorLike, DimList, Tensor

    # Outline:
    # * fall back to the original __getitem__ for simple indexing of a plain tensor
    # * locate a single `...` or unbound DimList, work out how many dimensions it
    #   covers, and expand it in place
    # * bind Dims and dim packs to the sizes they index, counting how often each dim
    #   is used and recording whether a view is needed
    # * single-use Dims become full slices and are carried as output levels; tensor
    #   indices are level-matched and handed to the original __getitem__

    is_simple = (not isinstance(input, Dim) and
                 not isinstance(input, (tuple, list)) and
                 not (isinstance(input, TensorLike) and input.ndim == 0))

    if is_simple:
        if isinstance(self, _Tensor):
            return _Tensor.__torch_function__(_orig_getitem, None, (self, input))
        else:
            return _orig_getitem(self, input)

    if not isinstance(input, tuple):
        input = [input]
    else:
        input = list(input)

    dims_indexed = 0
    expanding_object = None
    dimlists = []
    for i, s in enumerate(input):
        if s is ... or (isinstance(s, DimList) and not s.is_bound):
            if expanding_object is not None:
                msg = 'at most one ... or unbound dimension list can exist in indexing list but' \
                      f' found 2 at offsets {i} and {expanding_object}'
                raise DimensionBindError(msg)
            expanding_object = i

        if isinstance(s, DimList):
            dims_indexed += len(s) if s.is_bound else 0
            dimlists.append(i)
        elif s is not None and s is not ...:
            dims_indexed += 1

    ndim = self.ndim
    if dims_indexed > ndim:
        raise IndexError(
            f'at least {dims_indexed} indices were supplied but the tensor only has {ndim} dimensions.')

    if expanding_object is not None:
        expanding_ndims = ndim - dims_indexed
        obj = input[expanding_object]
        if obj is ...:
            input[expanding_object:expanding_object + 1] = [no_slice] * expanding_ndims
        else:
            obj.bind_len(expanding_ndims)

    # splice the (now bound) dim lists into the index list
    for i in reversed(dimlists):
        input[i:i + 1] = input[i]
    dims_indexed = 0
    requires_view = False
    size = self.size()
    view_sizes = []
    dims_seen = dim_tracker()

    def add_dims(t):
        if not isinstance(t, _Tensor):
            return
        for d in t.dims:
            dims_seen.record(d)

    add_dims(self)
    dim_packs = []
    for i, idx in enumerate(input):
        if idx is None:
            input[i] = no_slice
            view_sizes.append(1)
            requires_view = True
        else:
            sz = size[dims_indexed]
            if isinstance(idx, Dim):
                idx.size = sz
                dims_seen.record(idx)
                view_sizes.append(sz)
            elif isinstance(idx, (tuple, list)) and idx and isinstance(idx[0], Dim):
                for d in idx:
                    dims_seen.record(d)
                _bind_dims_to_size(sz, idx, f'offset {i}')
                view_sizes.extend(d.size for d in idx)
                requires_view = True
                dim_packs.append(i)
            else:
                add_dims(idx)
                view_sizes.append(sz)
            dims_indexed += 1
    if requires_view:
        self = self.view(*view_sizes)
    for i in reversed(dim_packs):
        input[i:i + 1] = input[i]

    # At this point dim packs and None entries have been eliminated: `input` holds Dims,
    # tensors, or plain indexing objects, and `self` may still carry first-class dims.

    if isinstance(self, _Tensor):
        ptensor_self, levels = self._tensor, list(self._levels)
        # match the remaining inputs to the positional levels of self; levels that are
        # already first-class Dims use the Dim itself as the entry
        input_it = iter(input)
        flat_inputs = [next(input_it) if isinstance(l, int) else l for l in levels]
        has_device = self._has_device
        to_pad = 0
    else:
        ptensor_self, flat_inputs = self, input
        to_pad = ptensor_self.ndim - len(flat_inputs)
        has_device = True

    result_levels = []
    index_levels = []
    tensor_insert_point = None
    to_expand = {}
    requires_getindex = False
    for i, inp in enumerate(flat_inputs):
        if isinstance(inp, Dim) and dims_seen[inp] == 1:
            flat_inputs[i] = no_slice
            result_levels.append(inp)
        elif isinstance(inp, TensorLike):
            requires_getindex = True
            if tensor_insert_point is None:
                tensor_insert_point = len(result_levels)
            ptensor, levels, _ = _tensor_levels(inp)
            to_expand[i] = levels
            flat_inputs[i] = ptensor
            for l in levels:
                if l not in index_levels:
                    index_levels.append(l)
        else:
            requires_getindex = True
            result_levels.append(0)

    if tensor_insert_point is not None:
        result_levels[tensor_insert_point:tensor_insert_point] = index_levels

    for i, levels in to_expand.items():
        flat_inputs[i] = _match_levels(flat_inputs[i], levels, index_levels)

    if requires_getindex:
        result = _orig_getitem(ptensor_self, flat_inputs)
    else:
        result = ptensor_self

    next_positional = -1
    if to_pad > 0:
        result_levels.extend([0] * to_pad)
    for i, r in enumerate(reversed(result_levels)):
        if isinstance(r, int):
            result_levels[-1 - i] = next_positional
            next_positional -= 1

    return Tensor.from_positional(result, result_levels, has_device)
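# Illustrative use (a sketch, assuming the `dims()` constructor exported by this package
# and that this function is installed as torch.Tensor.__getitem__, as _orig_getitem suggests):
#
#   i, j = dims(2)
#   x = torch.rand(3, 4)
#   xi = x[i]         # binds i to size 3; xi has first-class dim i plus one positional dim of size 4
#   xij = x[i, j]     # binds i to 3 and j to 4; xij has no positional dims left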


def stack(tensors, new_dim, dim=0, out=None):
    if isinstance(dim, int):
        return torch.stack(tensors, dim, out=out).index(dim, new_dim)
    index = None
    if out is not None:
        out, index = _positional_no_permute(out, dim, expand_dim=True)
    ptensors = []
    for t in tensors:
        pt, pi = _positional_no_permute(t, dim, expand_dim=True)
        if index is not None and pi != index:
            pt = pt.move_dim(pi, index)
        else:
            index = pi
        ptensors.append(pt)
    pr = torch.stack(ptensors, index, out=out)
    return pr.index((index, index + 1), (new_dim, dim))
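# When `dim` is a first-class Dim, each input is first converted so that dim occupies a
# positional slot (via _positional_no_permute), the tensors are stacked at that slot, and
# the two resulting positional dims are rebound to the new Dim and the original one with .index.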


_orig_split = torch.Tensor.split


def split(self, split_size_or_sections, dim=0):
    from . import Dim, _Tensor

    if isinstance(split_size_or_sections, int) or any(isinstance(t, int) for t in split_size_or_sections):
        if isinstance(dim, Dim):
            raise ValueError('when dim is specified as a Dim object, split sizes must also be dimensions.')
        return _orig_split(self, split_size_or_sections, dim=dim)

    if isinstance(dim, Dim):
        assert isinstance(self, _Tensor), f"Tensor does not have dimension {dim}"
        self, dim = _positional_no_permute(self, dim)

    size = self.size(dim)
    total_bound_size = 0
    unbound = []
    sizes = []
    for i, d in enumerate(split_size_or_sections):
        if d.is_bound:
            sizes.append(d.size)
            total_bound_size += d.size
        else:
            sizes.append(0)
            unbound.append(i)

    if unbound:
        assert total_bound_size <= size, \
            f"result dimensions are larger than original: {total_bound_size} vs {size} ({split_size_or_sections})"
        remaining_size = size - total_bound_size
        chunk_size = -(-remaining_size // len(unbound))
        for u in unbound:
            sz = min(chunk_size, remaining_size)
            split_size_or_sections[u].size = sz
            sizes[u] = sz
            remaining_size -= sz
    else:
        assert total_bound_size == size, \
            f"result dimensions do not match original: {total_bound_size} vs {size} ({split_size_or_sections})"

    return tuple(t.index(dim, d) for d, t in zip(split_size_or_sections, _orig_split(self, sizes, dim=dim)))
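# Illustrative use (a sketch, assuming the `dims()` constructor exported by this package
# and that this split is installed on torch.Tensor, as _orig_split suggests): unbound Dims
# passed as split sizes have their sizes inferred from the remaining length, and each chunk
# comes back indexed by the corresponding Dim, e.g.
#
#   a, b = dims(2)
#   x = torch.rand(10, 4)
#   left, right = x.split([a, b], dim=0)   # binds a.size == 5 and b.size == 5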
|