123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399 |
- import collections
- import dataclasses
- import enum
- from typing import Any, Optional, Union
- from torch._guards import GuardSource, Source
- from . import utils
- from .bytecode_transformation import create_instruction
- from .utils import enum_repr, rename_implicit
# Lookup used by NNModuleSource.guard_source(): promotes LOCAL/GLOBAL to their
# *_NN_MODULE counterparts and leaves *_NN_MODULE sources unchanged.
_GUARD_SOURCE_NN_MODULE = dict(
    [
        (GuardSource.LOCAL, GuardSource.LOCAL_NN_MODULE),
        (GuardSource.GLOBAL, GuardSource.GLOBAL_NN_MODULE),
        (GuardSource.LOCAL_NN_MODULE, GuardSource.LOCAL_NN_MODULE),
        (GuardSource.GLOBAL_NN_MODULE, GuardSource.GLOBAL_NN_MODULE),
    ]
)
# Lookup used by NotNNModuleSource.guard_source(): demotes *_NN_MODULE sources
# back to plain LOCAL/GLOBAL and leaves LOCAL/GLOBAL unchanged.
_GUARD_SOURCE_NOT_NN_MODULE = dict(
    [
        (GuardSource.LOCAL, GuardSource.LOCAL),
        (GuardSource.GLOBAL, GuardSource.GLOBAL),
        (GuardSource.LOCAL_NN_MODULE, GuardSource.LOCAL),
        (GuardSource.GLOBAL_NN_MODULE, GuardSource.GLOBAL),
    ]
)
def is_constant_source(source):
    """Return True if ``source`` is a constant source.

    A ``ConstantSource`` always qualifies. Any other source qualifies when its
    ``guard_source()`` is ``GuardSource.CONSTANT``. Sources whose
    ``guard_source()`` raises ``NotImplementedError`` are reported as
    non-constant.
    """
    if isinstance(source, ConstantSource):
        return True
    try:
        kind = source.guard_source()
    except NotImplementedError:
        return False
    return kind == GuardSource.CONSTANT
def is_input_source(source):
    """Return True if ``source.guard_source()`` is a local or global kind.

    The accepted kinds are LOCAL, GLOBAL and their *_NN_MODULE variants.
    """
    kind = source.guard_source()
    return kind in (
        GuardSource.LOCAL,
        GuardSource.GLOBAL,
        GuardSource.LOCAL_NN_MODULE,
        GuardSource.GLOBAL_NN_MODULE,
    )
@dataclasses.dataclass
class LocalSource(Source):
    """Source for a local variable, identified by ``local_name``."""

    local_name: str

    def guard_source(self):
        return GuardSource.LOCAL

    def name(self):
        # The guard name is the local's name passed through rename_implicit.
        return rename_implicit(self.local_name)

    def reconstruct(self, codegen):
        load_local = codegen.create_load(self.local_name)
        return [load_local]
@dataclasses.dataclass
class LocalInputSource(LocalSource):
    """A ``LocalSource`` that additionally records an integer ``pos``.

    NOTE(review): ``pos`` is presumably the local's position among the frame's
    inputs -- confirm against callers.
    """

    pos: int
@dataclasses.dataclass
class RandomValueSource(Source):
    """Source for entry ``random_call_index`` of the output's random-values variable."""

    random_call_index: int

    def guard_source(self):
        return GuardSource.RANDOM_VALUE

    def name(self):
        label = f"random_value_{self.random_call_index}"
        return rename_implicit(label)

    def reconstruct(self, codegen):
        # Emits: <random_values_var>[random_call_index]
        container = codegen.create_load(codegen.tx.output.random_values_var)
        position = codegen.create_load_const(self.random_call_index)
        subscript = create_instruction("BINARY_SUBSCR")
        return [container, position, subscript]
@dataclasses.dataclass
class GlobalSource(Source):
    """Source for a global variable, identified by ``global_name``."""

    global_name: str

    def guard_source(self):
        return GuardSource.GLOBAL

    def name(self):
        # Globals are guarded under their own name, without renaming.
        return self.global_name

    def reconstruct(self, codegen):
        load_global = codegen.create_load_global(self.global_name, add=True)
        return [load_global]
@dataclasses.dataclass
class GlobalWeakRefSource(Source):
    """Source for the value obtained by calling global ``global_name`` with no arguments.

    NOTE(review): the global presumably holds a weak reference, and calling it
    dereferences it -- confirm at the call sites that create this source.
    """

    global_name: str

    def guard_source(self):
        return GuardSource.GLOBAL

    def name(self):
        return f"{self.global_name}()"

    def reconstruct(self, codegen):
        # Load the global, then call it with zero arguments.
        instructions = [codegen.create_load_global(self.global_name, add=True)]
        instructions.append(create_instruction("CALL_FUNCTION", 0))
        return instructions
@dataclasses.dataclass
class AttrSource(Source):
    """Source for the attribute access ``base.member``.

    A dotted ``member`` such as ``"a.b.c"`` is split into a chain of
    single-attribute sources: ``AttrSource(AttrSource(AttrSource(base, "a"),
    "b"), "c")``.
    """

    base: Source
    member: str

    def __init__(self, base, member):
        super().__init__()
        assert base, "Can't construct an AttrSource without a valid base source"
        if "." in member:
            # Peel off the last component; the recursive call splits the rest.
            parent, _, member = member.rpartition(".")
            base = AttrSource(base, parent)
        self.base = base
        self.member = member
        assert self.base is not None

    @staticmethod
    def _needs_getattr(member):
        """Return True if ``<base>.<member>`` would not read attribute ``member``."""
        import keyword
        import unicodedata

        return (
            # Kept from the original rule so numeric names (e.g. "0") still
            # render as getattr(...).
            member.isnumeric()
            # Dotted access is a syntax error for non-identifiers ("a-b", "")
            # and for keywords ("class").
            or not member.isidentifier()
            or keyword.iskeyword(member)
            # Python NFKC-normalizes identifiers written in source, so e.g.
            # "Ⅻ" would be read back as attribute "XII".
            or unicodedata.normalize("NFKC", member) != member
        )

    def reconstruct(self, codegen):
        return self.base.reconstruct(codegen) + codegen.create_load_attrs(self.member)

    def guard_source(self):
        return self.base.guard_source()

    def name(self):
        """Return a Python expression that reads this attribute from ``base``.

        Uses ``getattr(base, 'member')`` whenever the dotted form would be
        invalid or would resolve to a different attribute name.
        """
        if self._needs_getattr(self.member):
            return f"getattr({self.base.name()}, {self.member!r})"
        return f"{self.base.name()}.{self.member}"
class TensorProperty(enum.Enum):
    """Which tensor property a ``TensorPropertySource`` reads."""

    # Assigned positionally: SIZE=0, STRIDE=1, STORAGE_OFFSET=2.
    SIZE, STRIDE, STORAGE_OFFSET = 0, 1, 2
@dataclasses.dataclass
class TensorPropertySource(Source):
    """Source for ``size()[idx]``, ``stride()[idx]`` or ``storage_offset()`` of ``base``.

    ``__post_init__`` enforces that ``idx`` is provided for SIZE and STRIDE and
    is ``None`` for STORAGE_OFFSET.
    """

    base: Source
    prop: TensorProperty
    idx: Optional[int] = None  # None for STORAGE_OFFSET

    def __post_init__(self):
        assert self.base is not None
        wants_index = self.prop is not TensorProperty.STORAGE_OFFSET
        assert (self.idx is not None) == wants_index

    def reconstruct(self, codegen):
        raise NotImplementedError()

    def guard_source(self):
        return self.base.guard_source()

    def name(self):
        prop = self.prop
        if prop is TensorProperty.STORAGE_OFFSET:
            assert self.idx is None
            return f"{self.base.name()}.storage_offset()"
        if prop is TensorProperty.SIZE:
            return f"{self.base.name()}.size()[{self.idx}]"
        if prop is TensorProperty.STRIDE:
            return f"{self.base.name()}.stride()[{self.idx}]"
        raise AssertionError(f"unhandled {self.prop}")
@dataclasses.dataclass
class NegateSource(Source):
    """Source for the negation of the value described by ``base``."""

    base: Source

    def __post_init__(self):
        assert self.base is not None

    def guard_source(self):
        return self.base.guard_source()

    def name(self):
        # Spelled as an explicit __neg__() call, not a leading "-", so that
        # the function-stripping regexes still work on this name.
        return f"{self.base.name()}.__neg__()"

    def reconstruct(self, codegen):
        raise NotImplementedError()
@dataclasses.dataclass
class DefaultsSource(Source):
    """Source for a default argument value of the function described by ``base``.

    When ``is_kw`` is False, ``idx_key`` is an ``int`` index into
    ``base.__defaults__``. When ``is_kw`` is True, ``idx_key`` is a ``str`` key
    into ``base.__kwdefaults__``. The guard name is computed once, in
    ``__init__``.
    """

    base: Source
    idx_key: Union[int, str]
    is_kw: bool
    field: str

    def __init__(self, base, idx_key, is_kw=False):
        super().__init__()
        assert (
            base
        ), "Base must be a valid source in order to properly track and guard this Defaults to its origin."
        self.base = base
        self.idx_key = idx_key
        self.is_kw = is_kw
        if not is_kw:
            assert isinstance(idx_key, int)
            self.field = "__defaults__"
            self._name = f"{self.base.name()}.{self.field}[{self.idx_key}]"
            return
        assert isinstance(idx_key, str)
        self.field = "__kwdefaults__"
        self._name = f"{self.base.name()}.{self.field}['{self.idx_key}']"

    def reconstruct(self, codegen):
        # Emits: <base>.<field>[idx_key]
        instrs = self.base.reconstruct(codegen)
        instrs.extend(codegen.create_load_attrs(self.field))
        instrs.append(codegen.create_load_const(self.idx_key))
        instrs.append(create_instruction("BINARY_SUBSCR"))
        return instrs

    def guard_source(self):
        return self.base.guard_source()

    def name(self):
        return self._name
@dataclasses.dataclass
class GetItemSource(Source):
    """Source for ``base[index]``.

    ``index`` is either another ``Source`` (reconstructed and named
    recursively) or a constant value.
    """

    base: Source
    index: Any

    def __post_init__(self):
        assert self.base is not None

    def reconstruct(self, codegen):
        instrs = self.base.reconstruct(codegen)
        if isinstance(self.index, Source):
            key_instrs = self.index.reconstruct(codegen)
        else:
            key_instrs = [codegen.create_load_const(self.index)]
        instrs.extend(key_instrs)
        instrs.append(create_instruction("BINARY_SUBSCR"))
        return instrs

    def guard_source(self):
        return self.base.guard_source()

    def name(self):
        if isinstance(self.index, Source):
            return f"{self.base.name()}[{self.index.name()}]"
        if isinstance(self.index, enum.Enum):
            # Enum members are rendered via enum_repr instead of repr().
            return f"{self.base.name()}[{enum_repr(self.index)}]"
        return f"{self.base.name()}[{self.index!r}]"
@dataclasses.dataclass
class TupleIteratorGetItemSource(GetItemSource):
    """``GetItemSource`` variant that reads ``index`` through ``tuple_iterator_getitem``.

    The helper is imported from ``.utils`` during codegen, and the guard name
    is spelled ``___tuple_iterator_getitem(base, index)``.
    """

    def reconstruct(self, codegen):
        # The helper is loaded onto codegen first; the returned instructions
        # then push the base and the constant index and call it with 2 args.
        codegen.load_import_from(utils.__name__, "tuple_iterator_getitem")
        return [
            *self.base.reconstruct(codegen),
            codegen.create_load_const(self.index),
            create_instruction("CALL_FUNCTION", 2),
        ]

    def name(self):
        return f"___tuple_iterator_getitem({self.base.name()}, {self.index!r})"
@dataclasses.dataclass
class TypeSource(Source):
    """Source for ``type(base)``."""

    base: Source

    def __post_init__(self):
        assert self.base is not None

    def guard_source(self):
        return self.base.guard_source()

    def name(self):
        return f"type({self.base.name()})"

    def reconstruct(self, codegen):
        # builtins.type is loaded onto codegen, then called with the base.
        codegen.load_import_from("builtins", "type")
        return [*self.base.reconstruct(codegen), create_instruction("CALL_FUNCTION", 1)]
@dataclasses.dataclass
class SuperSource(Source):
    """Source for ``super(type, obj)``.

    Guarding follows ``obj``, not ``type``.
    """

    type: Source
    obj: Source

    def __post_init__(self):
        for part in (self.type, self.obj):
            assert part is not None

    def reconstruct(self, codegen):
        codegen.load_import_from("builtins", "super")
        type_instrs = self.type.reconstruct(codegen)
        obj_instrs = self.obj.reconstruct(codegen)
        return type_instrs + obj_instrs + [create_instruction("CALL_FUNCTION", 2)]

    def guard_source(self):
        return self.obj.guard_source()

    def name(self):
        return f"super({self.type.name()}, {self.obj.name()})"
@dataclasses.dataclass
class ODictGetItemSource(Source):
    """Source for ``collections.OrderedDict.__getitem__(base, index)``."""

    base: Source
    index: Any

    def __post_init__(self):
        assert self.base is not None

    def reconstruct(self, codegen):
        # NOTE(review): the private _create_load_const is used for the unbound
        # method, presumably because the public create_load_const only accepts
        # plain constants -- confirm.
        getitem = codegen._create_load_const(collections.OrderedDict.__getitem__)
        return [
            getitem,
            *self.base.reconstruct(codegen),
            codegen.create_load_const(self.index),
            create_instruction("CALL_FUNCTION", 2),
        ]

    def guard_source(self):
        return self.base.guard_source()

    def name(self):
        return f"___odict_getitem({self.base.name()}, {self.index!r})"
@dataclasses.dataclass
class NNModuleSource(Source):
    """Wrapper that marks ``inner`` as reached through an nn.Module.

    Codegen and naming delegate to ``inner`` unchanged. Only the guard source
    is remapped, via ``_GUARD_SOURCE_NN_MODULE``.
    """

    inner: Source

    def guard_source(self):
        inner_kind = self.inner.guard_source()
        return _GUARD_SOURCE_NN_MODULE[inner_kind]

    def name(self):
        return self.inner.name()

    def reconstruct(self, codegen):
        return self.inner.reconstruct(codegen)
# Decorated like the file's other dataclass subclasses (LocalInputSource,
# TupleIteratorGetItemSource). With no new fields it generates the same
# __init__/__eq__/__repr__ it inherits. Without the decorator, any field added
# later would be silently ignored by those methods.
@dataclasses.dataclass
class NotNNModuleSource(NNModuleSource):
    """Wrapper that marks ``inner`` as *not* reached through an nn.Module.

    Codegen and naming are inherited from ``NNModuleSource``. The guard source
    is remapped via ``_GUARD_SOURCE_NOT_NN_MODULE``, which demotes *_NN_MODULE
    kinds back to LOCAL/GLOBAL.
    """

    def guard_source(self):
        return _GUARD_SOURCE_NOT_NN_MODULE[self.inner.guard_source()]
@dataclasses.dataclass
class ConstantSource(Source):
    """Source for a named constant (``GuardSource.CONSTANT``).

    It is reconstructed as a global load with ``add=False``. Guards cannot be
    built from it: ``make_guard`` raises ``NotImplementedError``.
    """

    source_name: str

    def guard_source(self):
        return GuardSource.CONSTANT

    def name(self):
        return self.source_name

    def make_guard(self, fn, is_volatile=False):
        raise NotImplementedError()

    def reconstruct(self, codegen):
        load_global = codegen.create_load_global(self.source_name, add=False)
        return [load_global]
@dataclasses.dataclass
class ShapeEnvSource(Source):
    """Synthetic source for the singleton shape-env guard.

    This guard is registered for every frame. The source has an empty name;
    the actual guard contents come from the ambient ShapeEnv.
    """

    def guard_source(self):
        return GuardSource.SHAPE_ENV

    def name(self):
        return ""
|