1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495 |
- # Copyright (c) Facebook, Inc. and its affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the BSD-style license found in the
- # LICENSE file in the root directory of this source tree.
# Module-wide stack of LevelInfo records: the Dim ``size`` setter pushes one
# entry per bound dim, and Dim finalization pops dead entries off the top.
_vmap_levels = list()
@dataclass()
class LevelInfo:
    """Bookkeeping record for one vmap nesting level entered by a bound Dim."""

    # Nesting level value obtained when the owning Dim was bound.
    level: int
    # Set to False once the owning Dim has been finalized.
    alive: bool = True
class Dim:
    """A named dimension that may be bound to a concrete size.

    Binding a size (through the constructor or by assigning ``size``) enters a
    new vmap nesting level via ``_vmap_increment_nesting`` and pushes a
    ``LevelInfo`` record onto the module-level ``_vmap_levels`` stack. When the
    ``Dim`` is finalized its record is marked dead, and dead records at the top
    of the stack are unwound.
    """

    def __init__(self, name: str, size: Union[None, int] = None):
        """Create a dim called ``name``; bind it to ``size`` if one is given."""
        self.name = name
        self._size = None
        self._vmap_level = None
        if size is not None:
            self.size = size

    def __del__(self):
        """Mark this dim's level dead and unwind dead levels from the top."""
        if self._vmap_level is None:
            # Never (successfully) bound: no level was entered.
            return
        # ``_vmap_stack`` is this dim's index into ``_vmap_levels``, recorded
        # by the ``size`` setter when the LevelInfo was appended.
        _vmap_levels[self._vmap_stack].alive = False
        # Only the top of the stack is ever popped, and only while its record
        # is dead and matches the current nesting level. The emptiness check
        # prevents an IndexError once every record has been unwound.
        while (
            _vmap_levels
            and not _vmap_levels[-1].alive
            and current_level() == _vmap_levels[-1].level
        ):
            _vmap_decrement_nesting()
            _vmap_levels.pop()

    @property
    def size(self):
        """The size this dim is bound to; asserts that the dim is bound."""
        assert self.is_bound
        return self._size

    @size.setter
    def size(self, size: int):
        """Bind the dim to ``size`` (first time) or check it matches.

        Raises:
            DimensionBindError: if the dim is already bound to a different size.
        """
        if self._size is None:
            # Enter the nesting level first: if that call fails, the dim stays
            # unbound instead of being marked bound without a level.
            level = _vmap_increment_nesting(size, 'same')
            self._size = size
            self._vmap_stack = len(_vmap_levels)
            _vmap_levels.append(LevelInfo(level))
            # Published last, so __del__ only unwinds fully registered dims.
            self._vmap_level = level
        elif self._size != size:
            raise DimensionBindError(
                f"Dim '{self}' previously bound to a dimension of size {self._size} cannot bind to a dimension of size {size}")

    @property
    def is_bound(self):
        """True once a size has been assigned."""
        return self._size is not None

    def __repr__(self):
        return self.name
def extract_name(inst):
    """Return the variable name that a store instruction writes to.

    ``inst`` must be a ``STORE_FAST`` or ``STORE_NAME`` instruction; its
    ``argval`` (the target name) is returned.
    """
    opname = inst.opname
    assert opname in ('STORE_FAST', 'STORE_NAME')
    return inst.argval
# Per-call-site memo for dims(): maps (code object, f_lasti) to a
# zero-argument factory that builds the dims for that assignment.
_cache = dict()
def dims(lists=0):
    """Create ``Dim`` objects named after the variables they are assigned to.

    The caller's bytecode is inspected to find the store instruction(s) that
    consume this call's result, so ``batch, channel = dims()`` produces dims
    named ``'batch'`` and ``'channel'``.

    Args:
        lists: number of trailing unpacking targets to create as ``DimList``
            instead of ``Dim``. For a single (non-unpacked) target, any
            non-zero value makes it a ``DimList``.

    Returns:
        A single ``Dim``/``DimList`` for a plain assignment, or a tuple of them
        for an unpacking assignment. New objects are built on every call; only
        the bytecode analysis is cached per call site.
    """
    frame = inspect.currentframe()
    try:
        assert frame is not None
        calling_frame = frame.f_back
        assert calling_frame is not None
        code, lasti = calling_frame.f_code, calling_frame.f_lasti
        key = (code, lasti)
        if key not in _cache:
            instructions = list(dis.get_instructions(code))
            # The consumer of our return value is the first instruction after
            # the call at byte offset ``lasti``. Searching by offset rather than
            # assuming 2-byte instructions (``lasti // 2 + 1``) stays correct
            # when dis.get_instructions skips inline CACHE entries.
            first = next(i for i, inst in enumerate(instructions) if inst.offset > lasti)
            unpack = instructions[first]
            if unpack.opname == 'STORE_FAST' or unpack.opname == 'STORE_NAME':
                # just a single dim, not a list
                name = unpack.argval
                ctor = Dim if lists == 0 else DimList
                _cache[key] = lambda: ctor(name=name)
            else:
                assert unpack.opname == 'UNPACK_SEQUENCE'
                ndims = unpack.argval
                names = tuple(extract_name(instructions[first + 1 + i]) for i in range(ndims))
                # The last ``lists`` names become DimList; the rest are Dim.
                first_list = len(names) - lists
                _cache[key] = lambda: tuple(
                    Dim(n) if i < first_list else DimList(name=n) for i, n in enumerate(names)
                )
        return _cache[key]()
    finally:
        # ``frame`` is this function's own frame, which holds ``frame`` as a
        # local; delete it to break that reference cycle (see inspect docs).
        del frame
def _dim_set(positional, arg):
    """Normalize ``arg`` into a tuple of dims.

    ``None`` selects ``positional`` unchanged. A single ``Dim`` or ``int`` is
    wrapped in a 1-tuple. Any other value is iterated. Ints are used as
    indices into ``positional``, and ``Dim`` objects are kept as they are.
    """
    if arg is None:
        return positional

    def resolve(item):
        if not isinstance(item, Dim):
            assert isinstance(item, int)
            return positional[item]
        return item

    if isinstance(arg, (Dim, int)):
        return (resolve(arg),)
    return tuple(map(resolve, arg))
|