dim.py 3.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. # Copyright (c) Facebook, Inc. and its affiliates.
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the BSD-style license found in the
  5. # LICENSE file in the root directory of this source tree.
  6. _vmap_levels = []
  7. @dataclass
  8. class LevelInfo:
  9. level: int
  10. alive: bool = True
  11. class Dim:
  12. def __init__(self, name: str, size: Union[None, int] = None):
  13. self.name = name
  14. self._size = None
  15. self._vmap_level = None
  16. if size is not None:
  17. self.size = size
  18. def __del__(self):
  19. if self._vmap_level is not None:
  20. _vmap_active_levels[self._vmap_stack].alive = False
  21. while not _vmap_levels[-1].alive and current_level() == _vmap_levels[-1].level:
  22. _vmap_decrement_nesting()
  23. _vmap_levels.pop()
  24. @property
  25. def size(self):
  26. assert self.is_bound
  27. return self._size
  28. @size.setter
  29. def size(self, size: int):
  30. if self._size is None:
  31. self._size = size
  32. self._vmap_level = _vmap_increment_nesting(size, 'same')
  33. self._vmap_stack = len(_vmap_levels)
  34. _vmap_levels.append(LevelInfo(self._vmap_level))
  35. elif self._size != size:
  36. raise DimensionBindError(
  37. f"Dim '{self}' previously bound to a dimension of size {self._size} cannot bind to a dimension of size {size}")
  38. @property
  39. def is_bound(self):
  40. return self._size is not None
  41. def __repr__(self):
  42. return self.name
  43. def extract_name(inst):
  44. assert inst.opname == 'STORE_FAST' or inst.opname == 'STORE_NAME'
  45. return inst.argval
  46. _cache = {}
  47. def dims(lists=0):
  48. frame = inspect.currentframe()
  49. assert frame is not None
  50. calling_frame = frame.f_back
  51. assert calling_frame is not None
  52. code, lasti = calling_frame.f_code, calling_frame.f_lasti
  53. key = (code, lasti)
  54. if key not in _cache:
  55. first = lasti // 2 + 1
  56. instructions = list(dis.get_instructions(calling_frame.f_code))
  57. unpack = instructions[first]
  58. if unpack.opname == 'STORE_FAST' or unpack.opname == 'STORE_NAME':
  59. # just a single dim, not a list
  60. name = unpack.argval
  61. ctor = Dim if lists == 0 else DimList
  62. _cache[key] = lambda: ctor(name=name)
  63. else:
  64. assert unpack.opname == 'UNPACK_SEQUENCE'
  65. ndims = unpack.argval
  66. names = tuple(extract_name(instructions[first + 1 + i]) for i in range(ndims))
  67. first_list = len(names) - lists
  68. _cache[key] = lambda: tuple(Dim(n) if i < first_list else DimList(name=n) for i, n in enumerate(names))
  69. return _cache[key]()
  70. def _dim_set(positional, arg):
  71. def convert(a):
  72. if isinstance(a, Dim):
  73. return a
  74. else:
  75. assert isinstance(a, int)
  76. return positional[a]
  77. if arg is None:
  78. return positional
  79. elif not isinstance(arg, (Dim, int)):
  80. return tuple(convert(a) for a in arg)
  81. else:
  82. return (convert(arg),)