123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410 |
- import collections
- import contextlib
- import functools
- import importlib
- import inspect
- import random
- import types
- from typing import Dict, List
- import torch.nn
- from .. import variables
- from ..exc import unimplemented
- from ..guards import GuardBuilder
- from ..source import AttrSource, ODictGetItemSource, RandomValueSource
- from ..utils import is_namedtuple_cls, namedtuple_fields
- from .base import MutableLocal, VariableTracker
- from .misc import NullContextVariable
class UserDefinedVariable(VariableTracker):
    """Common base for trackers of user-defined classes and their instances."""
class UserDefinedClassVariable(UserDefinedVariable):
    """Tracks a user-defined class object itself (not an instance of it)."""

    def __init__(self, value, **kwargs):
        super().__init__(**kwargs)
        # The wrapped class object.
        self.value = value

    def as_python_constant(self):
        return self.value

    def python_type(self):
        # type() of a class object is its metaclass.
        return type(self.value)

    def var_getattr(self, tx, name: str) -> "VariableTracker":
        """Handle ``cls.<name>``.

        The attribute is looked up with ``inspect.getattr_static`` so no
        descriptors run at trace time. staticmethods and classmethods are
        unwrapped into function/method variables; attributes stored in the
        class ``__dict__`` or whose value is a literal are wrapped with
        ``VariableBuilder`` when a source exists, or as a ``ConstantVariable``
        for literals without a source. Everything else -- including attributes
        that do not exist at all -- is deferred to the base class.
        """
        from . import ConstantVariable
        from .builder import VariableBuilder

        options = VariableTracker.propagate(self)
        source = AttrSource(self.source, name) if self.source is not None else None

        # ``None`` is a legitimate attribute value, so a private sentinel is
        # used to tell "attribute is missing" apart from "attribute is None".
        # Previously a miss became ``None`` and was traced as a constant.
        missing = object()
        try:
            obj = inspect.getattr_static(self.value, name)
        except AttributeError:
            obj = missing

        if obj is missing:
            return super().var_getattr(tx, name)

        if isinstance(obj, staticmethod):
            return variables.UserFunctionVariable(
                obj.__get__(self.value), source=source, **options
            )
        elif isinstance(obj, classmethod):
            return variables.UserMethodVariable(
                obj.__func__, self, source=source, **options
            )

        if name in getattr(self.value, "__dict__", {}) or ConstantVariable.is_literal(
            obj
        ):
            if source:
                return VariableBuilder(tx, source)(obj).add_options(options)
            elif ConstantVariable.is_literal(obj):
                return ConstantVariable(obj, **options)

        return super().var_getattr(tx, name)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Handle ``cls.<name>(...)``.

        Only the zero-argument ``__subclasses__()`` call (when the class does
        not define its own ``__subclasses__``) is special-cased: it returns a
        list of class variables, each sourced via its defining module.
        """
        if (
            name == "__subclasses__"
            and len(args) == 0
            and not kwargs
            and "__subclasses__" not in self.value.__dict__
        ):
            options = VariableTracker.propagate(self, args, kwargs.values())
            # The returned list is a fresh, mutable object.
            options["mutable_local"] = MutableLocal()
            subs_as_vars: List[VariableTracker] = list()
            for sub in self.value.__subclasses__():
                source = AttrSource(tx.import_source(sub.__module__), sub.__name__)
                subs_as_vars.append(
                    variables.UserDefinedClassVariable(sub, source=source)
                )
            return variables.ListVariable(subs_as_vars, **options)

        return super().call_method(tx, name, args, kwargs)

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        """Handle ``cls(...)``, i.e. constructing an instance.

        Supported cases:
          * ``contextlib.nullcontext`` / ``torch.autograd.profiler.profile``
            become a ``NullContextVariable``;
          * namedtuple classes become a ``NamedTupleVariable`` (literal field
            defaults are honoured);
          * classes using ``object.__new__`` that support mutation side
            effects are tracked as a new object and their ``__init__`` traced;
          * dataclasses matched by ``DataClassVariable``.
        Anything else is deferred to the base class.
        """
        from . import ConstantVariable
        from ..side_effects import SideEffects

        options = VariableTracker.propagate(self, args, kwargs.values())

        if self.value in (
            contextlib.nullcontext,
            torch.autograd.profiler.profile,
        ):
            return NullContextVariable(**options)
        elif is_namedtuple_cls(self.value):
            fields = namedtuple_fields(self.value)
            # Classes accepted by is_namedtuple_cls may not all carry
            # ``_field_defaults``; treat such classes as having no defaults.
            field_defaults = getattr(self.value, "_field_defaults", {})
            cls_name = self.value.__name__

            if len(args) > len(fields):
                unimplemented(f"namedtuple {cls_name}: too many positional arguments")

            # ``None`` marks a field not yet supplied by the call; real
            # arguments are always VariableTrackers.
            items = list(args) + [None] * (len(fields) - len(args))
            for field_name, field_var in kwargs.items():
                if field_name not in fields:
                    unimplemented(
                        f"namedtuple {cls_name}: unexpected keyword {field_name}"
                    )
                idx = fields.index(field_name)
                if items[idx] is not None:
                    unimplemented(
                        f"namedtuple {cls_name}: multiple values for {field_name}"
                    )
                items[idx] = field_var

            # Fill any remaining fields from literal defaults.
            for idx, field_name in enumerate(fields):
                if items[idx] is not None:
                    continue
                if field_name in field_defaults and ConstantVariable.is_literal(
                    field_defaults[field_name]
                ):
                    items[idx] = ConstantVariable(field_defaults[field_name])
                else:
                    unimplemented(
                        f"namedtuple {cls_name}: missing argument {field_name}"
                    )

            return variables.NamedTupleVariable(
                items, self.value, **VariableTracker.propagate(self, items)
            )
        elif (
            inspect.getattr_static(self.value, "__new__", None) in (object.__new__,)
            and SideEffects.cls_supports_mutation_side_effects(self.value)
            and self.source
        ):
            var = tx.output.side_effects.track_object_new(
                self.source, self.value, UserDefinedObjectVariable, options
            )
            return var.add_options(var.call_method(tx, "__init__", args, kwargs))
        elif variables.DataClassVariable.is_matching_cls(self.value):
            options["mutable_local"] = MutableLocal()
            return variables.DataClassVariable.create(self.value, args, kwargs, options)

        return super().call_function(tx, args, kwargs)

    def const_getattr(self, tx, name):
        # ``__name__`` is answered directly from the wrapped class.
        if name == "__name__":
            return self.value.__name__
        return super().const_getattr(tx, name)
class UserDefinedObjectVariable(UserDefinedVariable):
    """
    Tracks an instance of a user-defined type.

    Catch-all for objects where we only know the type.
    """

    def __init__(self, value, value_type=None, **kwargs):
        super().__init__(**kwargs)
        # The wrapped instance and its exact type.
        self.value = value
        self.value_type = value_type or type(value)
        # An explicit value_type must match the instance's exact type
        # (subclass instances are rejected).
        assert type(value) is self.value_type

    def __str__(self):
        inner = self.value_type.__name__
        # For callable-like wrappers, the type name is uninformative; show
        # the wrapped object's own __name__ instead.
        if inner in [
            "builtin_function_or_method",
            "getset_descriptor",
            "method_descriptor",
            "method",
        ]:
            inner = str(getattr(self.value, "__name__", None))
        return f"{self.__class__.__name__}({inner})"

    def python_type(self):
        return self.value_type

    @staticmethod
    @functools.lru_cache(None)
    def _supported_random_functions():
        """Return the set of ``random`` functions whose calls can be traced."""
        fns = {
            random.random,
            random.randint,
            random.randrange,
            random.uniform,
        }
        return fns

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Handle ``obj.<name>(...)`` for methods looked up on the type.

        Only applies when ``name`` is not shadowed by an instance attribute.
        Special cases: ``object.__init__`` (no-op returning ``None``), the
        ``OrderedDict`` ``keys``/``items``/``__getitem__`` methods, and
        methods defined as plain Python functions on the class, which are
        inlined. Everything else is deferred to the base class.
        """
        from . import ConstantVariable, TupleVariable, UserMethodVariable

        options = VariableTracker.propagate(self, args, kwargs.values())

        if name not in getattr(self.value, "__dict__", {}):
            try:
                method = inspect.getattr_static(type(self.value), name)
            except AttributeError:
                method = None

            if method is object.__init__:
                return ConstantVariable(None, **options)

            if method is collections.OrderedDict.keys and self.source:
                # subclass of OrderedDict: keys are burned in as constants and
                # guarded with ODICT_KEYS.
                assert not (args or kwargs)
                keys = list(self.value.keys())
                assert all(map(ConstantVariable.is_literal, keys))
                return TupleVariable(
                    [ConstantVariable(k, **options) for k in keys], **options
                ).add_guard(self.source.make_guard(GuardBuilder.ODICT_KEYS))

            if (
                method is collections.OrderedDict.items
                and isinstance(self.value, collections.OrderedDict)
                and self.source
            ):
                assert not (args or kwargs)
                items = []
                # Reuse the guarded keys() path so items() shares its guard.
                keys = self.call_method(tx, "keys", [], {})
                options = VariableTracker.propagate(self, args, kwargs.values(), keys)
                for key in keys.unpack_var_sequence(tx):
                    items.append(
                        TupleVariable(
                            [key, self.odict_getitem(tx, key)],
                            **options,
                        )
                    )
                return TupleVariable(items, **options)

            if method is collections.OrderedDict.__getitem__ and len(args) == 1:
                assert not kwargs
                return self.odict_getitem(tx, args[0])

            # Methods written in Python on the class are inlined by tracing
            # them as a bound method of this object.
            if isinstance(method, types.FunctionType):
                source = (
                    None
                    if self.source is None
                    else AttrSource(AttrSource(self.source, "__class__"), name)
                )
                # TODO(jansel): add a guard to check for monkey patching?
                return UserMethodVariable(
                    method, self, source=source, **options
                ).call_function(tx, args, kwargs)

        return super().call_method(tx, name, args, kwargs)

    def is_supported_random(self):
        """Return True if the wrapped object is a traceable ``random`` function."""
        try:
            return self.value in self._supported_random_functions()
        except TypeError:
            # TypeError: unhashable type
            return False

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        """Handle calling the wrapped object.

        Calls to a supported ``random`` function with all-constant arguments
        are executed once at trace time and recorded in ``tx.random_calls``;
        the Python random state is snapshotted before the first such call.
        The result is wrapped as an unspecialized primitive backed by a
        ``RandomValueSource``. Everything else is deferred to the base class.
        """
        from .builder import VariableBuilder

        if (
            self.is_supported_random()
            and all(k.is_python_constant() for k in args)
            and all(v.is_python_constant() for v in kwargs.values())
        ):
            args = [x.as_python_constant() for x in args]
            kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
            random_call_index = len(tx.random_calls)
            if random_call_index == 0:
                tx.output.initial_random_state = random.getstate()
            example_value = self.value(*args, **kwargs)
            source = RandomValueSource(random_call_index)
            tx.random_calls.append((self.value, args, kwargs))
            return VariableBuilder(tx, source).wrap_unspecialized_primitive(
                example_value
            )

        return super().call_function(tx, args, kwargs)

    def _check_for_getattribute(self):
        """Report ``unimplemented`` if the type overrides ``__getattribute__``
        with a Python function. Returns None."""
        try:
            if isinstance(
                inspect.getattr_static(type(self.value), "__getattribute__"),
                types.FunctionType,
            ):
                unimplemented("UserDefinedObjectVariable with custom __getattribute__")
        except AttributeError:
            pass

    def _check_for_getattr(self):
        """Return the type's ``__getattr__`` (or None if absent or it is
        ``torch.nn.Module.__getattr__``)."""
        try:
            getattr_fn = inspect.getattr_static(type(self.value), "__getattr__")
        except AttributeError:
            getattr_fn = None
        if getattr_fn is torch.nn.Module.__getattr__:
            # ignore this case of getattr
            getattr_fn = None
        return getattr_fn

    def _getattr_static(self, name):
        """Look up ``name`` on the instance; raises AttributeError if absent."""
        if (
            isinstance(self.value, torch.nn.Module)
            or "__slots__" in self.value.__class__.__dict__
        ):
            # getattr_static doesn't work on these, so fall back to a real
            # getattr (which may run user code at trace time).
            subobj = getattr(self.value, name)
        else:
            subobj = inspect.getattr_static(self.value, name)
        return subobj

    def var_getattr(self, tx, name):
        """Handle ``obj.<name>``.

        As in Python, the attribute is looked up normally first. The type's
        ``__getattr__`` (when it is a Python function) is traced only if that
        lookup fails. Found attributes are dispatched by kind:
          * properties are inlined;
          * static/class/plain functions become function/method variables;
          * instance-dict entries, literals, tensors and modules are wrapped
            with ``VariableBuilder`` (or ``ConstantVariable`` for sourceless
            literals);
          * attributes of non-callable, non-``torch.optim`` ``torch.*``
            objects are wrapped with ``VariableBuilder``, building an
            import-based source when none is available.
        Anything else becomes a ``GetAttrVariable``.
        """
        from . import ConstantVariable
        from .builder import VariableBuilder

        options = VariableTracker.propagate(self)
        value = self.value
        source = AttrSource(self.source, name) if self.source else None
        self._check_for_getattribute()
        getattr_fn = self._check_for_getattr()

        # ``None`` is a legitimate attribute value, so a private sentinel
        # marks a failed lookup (previously a miss was traced as ``None``).
        missing = object()
        try:
            subobj = self._getattr_static(name)
        except AttributeError:
            subobj = missing

        if subobj is missing:
            # Python only falls back to ``__getattr__`` when normal lookup
            # fails; tracing it unconditionally was wrong for attributes that
            # exist.
            if isinstance(getattr_fn, types.FunctionType):
                return variables.UserMethodVariable(
                    getattr_fn, self, source=source, **options
                ).call_function(tx, [ConstantVariable(name)], {})
            elif getattr_fn is not None:
                unimplemented("UserDefined with non-function __getattr__")

        if isinstance(subobj, property):
            return variables.UserMethodVariable(
                subobj.fget, self, source=source, **options
            ).call_function(tx, [], {})
        elif isinstance(subobj, staticmethod):
            return variables.UserFunctionVariable(
                subobj.__get__(self.value), source=source, **options
            )
        elif isinstance(subobj, classmethod):
            return variables.UserMethodVariable(
                subobj.__func__, self, source=source, **options
            )
        elif isinstance(subobj, types.FunctionType):
            return variables.UserMethodVariable(subobj, self, source=source, **options)

        if (
            name in getattr(value, "__dict__", {})
            or ConstantVariable.is_literal(subobj)
            or isinstance(
                subobj,
                (
                    torch.Tensor,
                    torch.nn.Module,
                ),
            )
        ):
            if source:
                return VariableBuilder(tx, source)(subobj).add_options(options)
            elif ConstantVariable.is_literal(subobj):
                return ConstantVariable(subobj, **options)

        if (
            subobj is not missing
            and name not in getattr(value, "__dict__", {})
            and type(value).__module__.startswith("torch.")
            and "torch.optim" not in type(value).__module__
            and not callable(value)
        ):
            if not source:
                # Rebuild a source as <module>.<TypeName>.<name>; the type must
                # be reachable under its own name in its defining module.
                assert getattr(
                    importlib.import_module(type(value).__module__),
                    type(value).__name__,
                ) is type(value)
                source = AttrSource(
                    AttrSource(
                        tx.import_source(type(value).__module__), type(value).__name__
                    ),
                    name,
                )
            return VariableBuilder(tx, source)(subobj).add_options(options)

        options["source"] = source
        if isinstance(
            subobj,
            (
                torch.distributions.constraints._Interval,
                torch.distributions.constraints._Real,
                torch.distributions.constraints.Constraint,
            ),
        ):
            return UserDefinedObjectVariable(subobj, **options)
        if name == "__class__":
            return UserDefinedClassVariable(type(self.value), **options)

        return variables.GetAttrVariable(self, name, **options)

    def call_hasattr(self, tx, name: str) -> "VariableTracker":
        """Evaluate ``hasattr(obj, name)`` at trace time, guarded with HASATTR."""
        if not self.source:
            unimplemented("hasattr no source")
        options = VariableTracker.propagate(self)
        options["guards"].add(
            AttrSource(self.source, name).make_guard(GuardBuilder.HASATTR)
        )
        # A static lookup cannot predict what a custom __getattribute__ or
        # __getattr__ would do, so those cases are not handled here.
        self._check_for_getattribute()
        if self._check_for_getattr() is not None:
            unimplemented("hasattr with custom __getattr__")

        try:
            self._getattr_static(name)
        except AttributeError:
            return variables.ConstantVariable(False, **options)
        return variables.ConstantVariable(True, **options)

    def odict_getitem(self, tx, key):
        """Wrap ``OrderedDict.__getitem__(self.value, key)`` for a constant key,
        sourced via ``ODictGetItemSource``."""
        from .builder import VariableBuilder

        return VariableBuilder(
            tx,
            ODictGetItemSource(self.source, key.as_python_constant()),
        )(
            collections.OrderedDict.__getitem__(self.value, key.as_python_constant())
        ).add_options(
            key, self
        )
|