reference.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557
  1. # Copyright (c) Facebook, Inc. and its affiliates.
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the BSD-style license found in the
  5. # LICENSE file in the root directory of this source tree.
  6. # reference python implementations for C ops
import torch
from .tree_map import tree_flatten, tree_map
from .batch_tensor import _enable_layers
from . import op_properties
from functorch._C import dim as _C
# Alias the C implementation's DimList type so isinstance checks in this
# module (e.g. in `positional`) can refer to it directly.
DimList = _C.DimList
from functools import reduce
import operator
# Ops that __torch_function__ handles on its pointwise path (levels of all
# tensor args are aligned, no batching layers). A `set` makes the
# `orig in pointwise` membership test O(1).
# NOTE(review): a previous comment here said "use dict to avoid writing C++
# bindings for set"; this Python reference simply builds a set.
pointwise = set(op_properties.pointwise)
  17. def prod(x):
  18. return reduce(operator.mul, x, 1)
  19. def _wrap_dim(d, N, keepdim):
  20. from . import Dim
  21. if isinstance(d, Dim):
  22. assert not keepdim, "cannot preserve first-class dimensions with keepdim=True"
  23. return d
  24. elif d >= 0:
  25. return d - N
  26. else:
  27. return d
  28. def _dims(d, N, keepdim, single_dim):
  29. from . import Dim
  30. if isinstance(d, (Dim, int)):
  31. return ltuple((_wrap_dim(d, N, keepdim),))
  32. assert not single_dim, f"expected a single dimension or int but found: {d}"
  33. return ltuple(_wrap_dim(x, N, keepdim) for x in d)
  34. def _bind_dims_to_size(lhs_size, rhs, lhs_debug):
  35. from . import DimensionMismatchError
  36. not_bound = tuple((i, r) for i, r in enumerate(rhs) if not r.is_bound)
  37. if len(not_bound) == 1:
  38. idx, d = not_bound[0]
  39. rhs_so_far = prod(r.size for r in rhs if r.is_bound)
  40. if lhs_size % rhs_so_far != 0:
  41. rhs_s = tuple('?' if not r.is_bound else str(r.size) for r in rhs)
  42. raise DimensionMismatchError(f"inferred dimension does not evenly fit into larger dimension: {lhs_size} vs {rhs_s}")
  43. new_size = lhs_size // rhs_so_far
  44. d.size = new_size
  45. elif len(not_bound) > 1:
  46. rhs_s = tuple('?' if not r.is_bound else str(r.size) for r in rhs)
  47. raise DimensionMismatchError(f"cannot infer the size of two dimensions at once: {rhs} with sizes {rhs_s}")
  48. else:
  49. rhs_size = prod(r.size for r in rhs)
  50. if lhs_size != rhs_size:
  51. raise DimensionMismatchError(
  52. f"Dimension sizes to do not match ({lhs_size} != {rhs_size}) when matching {lhs_debug} to {rhs}")
  53. def _tensor_levels(inp):
  54. from . import _Tensor
  55. if isinstance(inp, _Tensor):
  56. return inp._tensor, llist(inp._levels), inp._has_device
  57. else:
  58. return inp, llist(range(-inp.ndim, 0)), True
  59. def _match_levels(v, from_levels, to_levels):
  60. view = []
  61. permute = []
  62. requires_view = False
  63. size = v.size()
  64. for t in to_levels:
  65. try:
  66. idx = from_levels.index(t)
  67. permute.append(idx)
  68. view.append(size[idx])
  69. except ValueError:
  70. view.append(1)
  71. requires_view = True
  72. if permute != list(range(len(permute))):
  73. v = v.permute(*permute)
  74. if requires_view:
  75. v = v.view(*view)
  76. return v
  77. # make a single dimension positional but do not permute it,
  78. # used to do multi-tensor operators where the dim being acted on
  79. # should not physically move if possible
  80. def _positional_no_permute(self, dim, expand_dim=False):
  81. from . import Tensor
  82. ptensor, levels = self._tensor, llist(self._levels)
  83. try:
  84. idx = levels.index(dim)
  85. except ValueError:
  86. if not expand_dim:
  87. raise
  88. idx = 0
  89. ptensor = ptensor.expand(dim.size, *ptensor.size())
  90. levels.insert(0, 0)
  91. idx_batched = 0
  92. for i in range(idx):
  93. if isinstance(levels[i], int):
  94. levels[i] -= 1
  95. idx_batched += 1
  96. levels[idx] = -idx_batched - 1
  97. return Tensor.from_positional(ptensor, levels, self._has_device), idx_batched
  98. def seq(a, b):
  99. from . import Dim
  100. if isinstance(a, Dim) != isinstance(b, Dim):
  101. return False
  102. if isinstance(a, Dim):
  103. return a is b
  104. else:
  105. return a == b
  106. class isin:
  107. def __contains__(self, item):
  108. for x in self:
  109. if seq(item, x):
  110. return True
  111. return False
  112. def index(self, item):
  113. for i, x in enumerate(self):
  114. if seq(item, x):
  115. return i
  116. raise ValueError
  117. class llist(isin, list):
  118. pass
  119. class ltuple(isin, tuple):
  120. pass
# Shared default for `kwargs` below. __torch_function__ only reads it
# (flattens/unpacks it) and never mutates it directly.
empty_dict = {}


@classmethod
def __torch_function__(self, orig, cls, args, kwargs=empty_dict):
    """Dispatch torch op ``orig`` whose arguments may be first-class-dim tensors.

    Defined at module level as a ``classmethod``; presumably installed on
    ``_Tensor`` elsewhere (TODO confirm).

    Paths:
    * ``torch.Tensor.__mul__`` on two 0-dim ``_Tensor`` operands returns a
      ``DelayedMulTensor`` instead of multiplying immediately.
    * ops in ``pointwise``: each tensor argument is permuted/viewed so its
      levels line up with the union of all argument levels. The op then runs
      on plain tensors and results are re-wrapped with those levels.
    * anything else: the op runs on each argument's ``_batchtensor`` inside
      ``_enable_layers`` for all first-class dims. Results are wrapped with
      ``Tensor.from_batched``.
    """
    from . import _Tensor, TensorLike, Tensor
    from .delayed_mul_tensor import DelayedMulTensor

    if orig is torch.Tensor.__mul__:
        lhs, rhs = args
        if isinstance(lhs, _Tensor) and isinstance(rhs, _Tensor) and lhs.ndim == 0 and rhs.ndim == 0:
            return DelayedMulTensor(lhs, rhs)
        # otherwise fall through to the generic handling below
    # Union of first-class dims over all arguments, in first-seen order.
    all_dims = llist()
    flat_args, unflatten = tree_flatten((args, kwargs))
    # The last argument that carries a device. Device-less tensors are moved
    # onto its device before the op runs.
    device_holding_tensor = None
    for f in flat_args:
        if isinstance(f, _Tensor):
            if f._has_device:
                device_holding_tensor = f._batchtensor
            for d in f.dims:
                if d not in all_dims:
                    all_dims.append(d)

    def unwrap(t):
        # Replace a first-class-dim tensor with its batched form (on the
        # shared device if needed); pass anything else through.
        if isinstance(t, _Tensor):
            r = t._batchtensor
            if device_holding_tensor is not None and not t._has_device:
                r = r.to(device=device_holding_tensor.device)
            return r
        return t

    if orig in pointwise:
        result_levels = llist()
        arg_levels = llist()  # NOTE(review): unused
        to_expand = []
        for i, f in enumerate(flat_args):
            if isinstance(f, TensorLike):
                ptensor, levels, _ = _tensor_levels(f)
                if isinstance(f, _Tensor) and not f._has_device and device_holding_tensor is not None:
                    ptensor = ptensor.to(device=device_holding_tensor.device)
                flat_args[i] = ptensor
                for l in levels:
                    if l not in result_levels:
                        result_levels.append(l)
                to_expand.append((i, levels))

        # Second pass: result_levels is only complete after all args are seen.
        for i, levels in to_expand:
            flat_args[i] = _match_levels(flat_args[i], levels, result_levels)
        args, kwargs = unflatten(flat_args)
        result = orig(*args, **kwargs)

        def wrap(t):
            if isinstance(t, TensorLike):
                return Tensor.from_positional(t, result_levels, device_holding_tensor is not None)
            return t
        return tree_map(wrap, result)
    else:
        def wrap(t):
            if isinstance(t, TensorLike):
                return Tensor.from_batched(t, device_holding_tensor is not None)
            return t
        with _enable_layers(all_dims):
            # NOTE(review): unconditional debug print on every non-pointwise
            # op; consider removing or gating it behind a flag.
            print(f"batch_tensor for {orig}")
            args, kwargs = unflatten(unwrap(f) for f in flat_args)
            result = orig(*args, **kwargs)
            # print("END", orig)
            return tree_map(wrap, result)
  181. def positional(self, *dims):
  182. from . import Dim, Tensor
  183. ptensor, levels = self._tensor, llist(self._levels)
  184. flat_dims = llist()
  185. view = []
  186. needs_view = False
  187. ndim = self.ndim
  188. for d in dims:
  189. if isinstance(d, DimList):
  190. flat_dims.extend(d)
  191. view.extend(e.size for e in d)
  192. elif isinstance(d, Dim):
  193. flat_dims.append(d)
  194. view.append(d.size)
  195. elif isinstance(d, int):
  196. d = _wrap_dim(d, ndim, False)
  197. flat_dims.append(d)
  198. view.append(ptensor.size(d))
  199. else:
  200. flat_dims.extend(d)
  201. view.append(prod(e.size for e in d))
  202. needs_view = True
  203. permute = list(range(len(levels)))
  204. nflat = len(flat_dims)
  205. for i, d in enumerate(flat_dims):
  206. try:
  207. idx = levels.index(d)
  208. except ValueError as e:
  209. raise DimensionBindError(f'tensor of dimensions {self.dims} does not contain dim {d}') from e
  210. p = permute[idx]
  211. del levels[idx]
  212. del permute[idx]
  213. levels.insert(i, 0)
  214. permute.insert(i, p)
  215. ptensor = ptensor.permute(*permute)
  216. seen = 0
  217. for i in range(len(levels) - 1, -1, -1):
  218. if isinstance(levels[i], int):
  219. seen += 1
  220. levels[i] = -seen
  221. result = Tensor.from_positional(ptensor, levels, self._has_device)
  222. if needs_view:
  223. result = result.reshape(*view, *result.size()[len(flat_dims):])
  224. return result
  225. def _contains_dim(input):
  226. from . import Dim
  227. for i in input:
  228. if isinstance(i, Dim):
  229. return True
  230. def expand(self, *sizes):
  231. if not _contains_dim(sizes):
  232. return self.__torch_function__(torch.Tensor.expand, None, (self, *sizes))
  233. dims = sizes
  234. sizes = [d.size for d in dims] + [-1] * self.ndim
  235. self = self.expand(*sizes)
  236. return self[dims]
  237. _not_present = object()
  238. def _getarg(name, offset, args, kwargs, default):
  239. if len(args) > offset:
  240. return args[offset]
  241. return kwargs.get(name, default)
  242. def _patcharg(name, offset, args, kwargs, value):
  243. if len(args) > offset:
  244. args[offset] = value
  245. else:
  246. kwargs[name] = value
def _wrap(orig, dim_offset=0, keepdim_offset=1, dim_name='dim', single_dim=False, reduce=True):
    """Build a first-class-dim aware replacement for a ``torch.Tensor`` method with a dim argument.

    Args:
        orig: the original ``torch.Tensor`` method.
        dim_offset: positional index (after ``self``) of the dim argument.
        keepdim_offset: positional index of ``keepdim``; only read when ``reduce``.
        dim_name: keyword name of the dim argument.
        single_dim: the op takes exactly one dim; a non-Dim value falls back to
            the batched path.
        reduce: the op removes the named dims from its result (unless keepdim).
            Note that this shadows ``functools.reduce`` inside this function.

    Returns:
        ``fn(self, *args, **kwargs)``, intended to be installed on ``_Tensor``
        (see ``_def``).
    """
    from . import TensorLike, Dim, Tensor

    def fn(self, *args, **kwargs):
        dim = _getarg(dim_name, dim_offset, args, kwargs, _not_present)
        if dim is _not_present or (single_dim and not isinstance(dim, Dim)):
            # No dim to translate: run on the batched representation instead.
            with _enable_layers(self.dims):
                # NOTE(review): unconditional debug print; consider removing/gating.
                print(f"dim fallback batch_tensor for {orig}")
                return Tensor.from_batched(orig(self._batchtensor, *args, **kwargs), self._has_device)

        keepdim = _getarg('keepdim', keepdim_offset, args, kwargs, False) if reduce else False
        t, levels = self._tensor, llist(self._levels)
        dims = _dims(dim, self._batchtensor.ndim, keepdim, single_dim)
        # Translate each requested dim/level into its physical index in `t`.
        dim_indices = tuple(levels.index(d) for d in dims)
        if reduce and not keepdim:
            new_levels = [l for i, l in enumerate(levels) if i not in dim_indices]
        else:
            new_levels = levels

        if len(dim_indices) == 1:
            dim_indices = dim_indices[0]  # so that dims that really only take a single argument work...
        args = list(args)
        # Rewrite the dim argument to physical indices before calling `orig` on `t`.
        _patcharg(dim_name, dim_offset, args, kwargs, dim_indices)

        def wrap(t):
            if isinstance(t, TensorLike):
                return Tensor.from_positional(t, new_levels, self._has_device)
            return t
        with _enable_layers(new_levels):
            # NOTE(review): unconditional debug print; consider removing/gating.
            print(f"dim used batch_tensor for {orig}")
            r = orig(t, *args, **kwargs)
            return tree_map(wrap, r)
    return fn
  276. def _def(name, *args, **kwargs):
  277. from . import _Tensor
  278. orig = getattr(torch.Tensor, name)
  279. setattr(_Tensor, name, _wrap(orig, *args, **kwargs))
# The full slice `:`; t__getitem__ substitutes it for None, `...` and
# single-use Dims when rewriting an indexing expression.
no_slice = slice(None)
# torch's own __getitem__, captured at import time; t__getitem__ falls back to it.
_orig_getitem = torch.Tensor.__getitem__
  282. class dim_tracker:
  283. def __init__(self):
  284. self.dims = llist()
  285. self.count = []
  286. def record(self, d):
  287. if d not in self.dims:
  288. self.dims.append(d)
  289. self.count.append(1)
  290. def __getitem__(self, d):
  291. return self.count[self.dims.index(d)]
  292. def t__getitem__(self, input):
  293. from . import Dim, DimensionBindError, _Tensor, TensorLike, DimList, Tensor
  294. # * bail to original example if we have a single non-Dim tensor, or a non-tensor
  295. # * locate ... or an unbound tensor list, and determine its size, bind dim list
  296. # (remember that None does not count to the total dim count)
  297. # * bind simple dims and dim-packs to their sizes, count the number of uses of each dim,
  298. # produce the re-view if needed
  299. # * for each single-use dim index, replace with no_slice and mark that it will be added
  300. # (keep track of whether we have to call super)
  301. # * call super if needed
  302. # * if we have dims to bind, bind them (it will help if we eliminated ... and None before)
  303. # this handles bool indexing handling, as well as some other simple cases.
  304. is_simple = (not isinstance(input, Dim) and
  305. not isinstance(input, (tuple, list)) and
  306. # WAR for functorch bug where zero time tensors in getitem are not handled correctly.
  307. not (isinstance(input, TensorLike) and input.ndim == 0))
  308. if is_simple:
  309. if isinstance(self, _Tensor):
  310. return _Tensor.__torch_function__(_orig_getitem, None, (self, input))
  311. else:
  312. return _orig_getitem(self, input)
  313. # can further optimize this case
  314. if not isinstance(input, tuple):
  315. input = [input]
  316. else:
  317. input = list(input)
  318. dims_indexed = 0
  319. expanding_object = None
  320. dimlists = []
  321. for i, s in enumerate(input):
  322. if s is ... or isinstance(s, DimList) and not s.is_bound:
  323. if expanding_object is not None:
  324. msg = 'at most one ... or unbound dimension list can exist in indexing list but' \
  325. f' found 2 at offsets {i} and {expanding_object}'
  326. raise DimensionBindError(msg)
  327. expanding_object = i
  328. if isinstance(s, DimList):
  329. dims_indexed += len(s) if s.is_bound else 0
  330. dimlists.append(i)
  331. elif s is not None and s is not ...:
  332. dims_indexed += 1
  333. ndim = self.ndim
  334. if dims_indexed > ndim:
  335. raise IndexError(f'at least {dims_indexed} indices were supplied but the tensor only has {ndim} dimensions.')
  336. if expanding_object is not None:
  337. expanding_ndims = ndim - dims_indexed
  338. obj = input[expanding_object]
  339. if obj is ...:
  340. input[expanding_object:expanding_object + 1] = [no_slice] * expanding_ndims
  341. else:
  342. obj.bind_len(expanding_ndims)
  343. # flatten the dimslists into the indexing
  344. for i in reversed(dimlists):
  345. input[i:i + 1] = input[i]
  346. dims_indexed = 0
  347. requires_view = False
  348. size = self.size()
  349. view_sizes = []
  350. dims_seen = dim_tracker()
  351. def add_dims(t):
  352. if not isinstance(t, _Tensor):
  353. return
  354. for d in t.dims:
  355. dims_seen.record(d)
  356. add_dims(self)
  357. dim_packs = []
  358. for i, idx in enumerate(input):
  359. if idx is None:
  360. input[i] = no_slice
  361. view_sizes.append(1)
  362. requires_view = True
  363. else:
  364. sz = size[dims_indexed]
  365. if isinstance(idx, Dim):
  366. idx.size = sz
  367. dims_seen.record(idx)
  368. view_sizes.append(sz)
  369. elif isinstance(idx, (tuple, list)) and idx and isinstance(idx[0], Dim):
  370. for d in idx:
  371. dims_seen.record(idx)
  372. _bind_dims_to_size(sz, idx, f'offset {i}')
  373. view_sizes.extend(d.size for d in idx)
  374. requires_view = True
  375. dim_packs.append(i)
  376. else:
  377. add_dims(idx)
  378. view_sizes.append(sz)
  379. dims_indexed += 1
  380. if requires_view:
  381. self = self.view(*view_sizes)
  382. for i in reversed(dim_packs):
  383. input[i:i + 1] = input[i]
  384. # currenty:
  385. # input is flat, containing either Dim, or Tensor, or something valid for standard indexing
  386. # self may have first-class dims as well.
  387. # to index:
  388. # drop the first class dims from self, they just become direct indices of their positions
  389. # figure out the dimensions of the indexing tensors: union of all the dims in the tensors in the index.
  390. # these dimensions will appear and need to be bound at the first place tensor occures
  391. if isinstance(self, _Tensor):
  392. ptensor_self, levels = self._tensor, list(self._levels)
  393. # indices to ptensor rather than self which has first-class dimensions
  394. input_it = iter(input)
  395. flat_inputs = [next(input_it) if isinstance(l, int) else l for l in levels]
  396. has_device = self._has_device
  397. to_pad = 0
  398. else:
  399. ptensor_self, flat_inputs = self, input
  400. to_pad = ptensor_self.ndim - len(flat_inputs)
  401. has_device = True
  402. result_levels = []
  403. index_levels = []
  404. tensor_insert_point = None
  405. to_expand = {}
  406. requires_getindex = False
  407. for i, inp in enumerate(flat_inputs):
  408. if isinstance(inp, Dim) and dims_seen[inp] == 1:
  409. flat_inputs[i] = no_slice
  410. result_levels.append(inp)
  411. elif isinstance(inp, TensorLike):
  412. requires_getindex = True
  413. if tensor_insert_point is None:
  414. tensor_insert_point = len(result_levels)
  415. ptensor, levels, _ = _tensor_levels(inp)
  416. to_expand[i] = levels
  417. flat_inputs[i] = ptensor
  418. for l in levels:
  419. if l not in index_levels:
  420. index_levels.append(l)
  421. else:
  422. requires_getindex = True
  423. result_levels.append(0)
  424. if tensor_insert_point is not None:
  425. result_levels[tensor_insert_point:tensor_insert_point] = index_levels
  426. for i, levels in to_expand.items():
  427. flat_inputs[i] = _match_levels(flat_inputs[i], levels, index_levels)
  428. if requires_getindex:
  429. result = _orig_getitem(ptensor_self, flat_inputs)
  430. else:
  431. result = ptensor_self
  432. next_positional = -1
  433. if to_pad > 0:
  434. result_levels.extend([0] * to_pad)
  435. for i, r in enumerate(reversed(result_levels)):
  436. if isinstance(r, int):
  437. result_levels[-1 - i] = next_positional
  438. next_positional -= 1
  439. return Tensor.from_positional(result, result_levels, has_device)
  440. # XXX - dim is optional and can be the outer-most dimension...
  441. def stack(tensors, new_dim, dim=0, out=None):
  442. if isinstance(dim, int):
  443. return torch.stack(tensors, dim, out).index(dim, new_dim)
  444. index = None
  445. if out is not None:
  446. out, index = _positional_no_permute(out, dim, expand_dim=True)
  447. ptensors = []
  448. for t in tensors:
  449. pt, pi = _positional_no_permute(t, dim, expand_dim=True)
  450. if index is not None and pi != index:
  451. pt = pt.move_dim(pi, index)
  452. else:
  453. index = pi
  454. ptensors.append(pt)
  455. pr = torch.stack(ptensors, index, out=out)
  456. return pr.index((index, index + 1), (new_dim, dim))
  457. _orig_split = torch.Tensor.split
  458. def split(self, split_size_or_sections, dim=0):
  459. from . import Dim, _Tensor
  460. if isinstance(split_size_or_sections, int) or any(isinstance(t, int) for t in split_size_or_sections):
  461. if isinstance(dim, Dim):
  462. raise ValueError('when dim is specified as a Dim object, split sizes must also be dimensions.')
  463. return _orig_split(self, split_size_or_sections, dim=dim)
  464. if isinstance(dim, Dim):
  465. assert isinstance(self, _Tensor), f"Tensor does not have dimension {dim}"
  466. self, dim = _positional_no_permute(self, dim)
  467. size = self.size(dim)
  468. total_bound_size = 0
  469. unbound = []
  470. sizes = []
  471. for i, d in enumerate(split_size_or_sections):
  472. if d.is_bound:
  473. sizes.append(d.size)
  474. total_bound_size += d.size
  475. else:
  476. sizes.append(0)
  477. unbound.append(i)
  478. if unbound:
  479. assert total_bound_size <= size, \
  480. f"result dimensions are larger than original: {total_bound_size} vs {size} ({split_size_or_sections})"
  481. remaining_size = size - total_bound_size
  482. chunk_size = -(-remaining_size // len(unbound))
  483. for u in unbound:
  484. sz = min(chunk_size, remaining_size)
  485. split_size_or_sections[u].size = sz
  486. sizes[u] = sz
  487. remaining_size -= sz
  488. else:
  489. assert total_bound_size == size, \
  490. f"result dimensions do not match original: {total_bound_size} vs {size} ({split_size_or_sections})"
  491. return tuple(t.index(dim, d) for d, t in zip(split_size_or_sections, _orig_split(self, sizes, dim=dim)))