123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296 |
- import collections
- from typing import Any, Callable, Dict, List, Optional, Set
- from .. import variables
- from ..exc import unimplemented
- from ..source import AttrSource, Source
- from ..utils import dict_values, identity, istype, odict_values
class MutableLocal:
    """
    Identity-only marker attached to a tracked value (list, iter, etc.) that
    was created in local scope, signalling that analysis may mutate it
    without the change leaking out.

    Equality and hashing are both based on object identity, so a marker is
    only ever equal to itself.
    """

    def __eq__(self, other):
        # Identity comparison: two distinct markers never compare equal.
        return other is self

    def __hash__(self):
        # Kept consistent with __eq__: hash on the object's identity.
        return id(self)
class HasPostInit(type):
    """
    Metaclass that runs ``__post_init__`` right after normal construction.

    Instantiating a class that uses this metaclass performs the usual
    ``__new__``/``__init__`` sequence first. It then calls ``__post_init__``
    on the new instance with exactly the same positional and keyword
    arguments.
    """

    def __call__(cls, *args, **kwargs):
        instance = type.__call__(cls, *args, **kwargs)
        instance.__post_init__(*args, **kwargs)
        return instance
class VariableTracker(metaclass=HasPostInit):
    """
    Base class for tracked locals and stack values

    VariableTracker instances are immutable and should be copied in
    order to change them.

    Because the metaclass is HasPostInit, every construction (including the
    ones performed by clone()) runs __init__ followed by __post_init__ with
    the same arguments.
    """

    # fields to leave unmodified in apply()
    _nonvar_fields = ["value"]

    @staticmethod
    def propagate(*vars: Any) -> Dict[str, Set]:
        """
        Combine the guards from many VariableTracker into **kwargs for a new instance.

        Each positional argument may be a VariableTracker or a (possibly
        nested) list, tuple, dict_values or odict_values of them. Containers
        are walked recursively and every VariableTracker's ``guards`` are
        unioned together.

        Returns:
            ``{"guards": <set of all collected guards>}``, suitable for
            ``SomeVariable(..., **VariableTracker.propagate(...))``.

        Raises:
            AssertionError: if a leaf is not a VariableTracker (the message
            comes from ``typestr``).
        """
        guards = set()

        def visit(var):
            # Exact type() check: subclasses of these containers are not
            # recursed into and must themselves be VariableTrackers.
            if type(var) in (list, tuple, dict_values, odict_values):
                for i in var:
                    visit(i)
            else:
                assert isinstance(var, VariableTracker), typestr(var)
                guards.update(var.guards)

        visit(vars)
        return {
            "guards": guards,
        }

    def clone(self, **kwargs):
        """
        Shallow copy with some (optional) changes.

        Re-invokes the constructor with the current ``__dict__`` overlaid by
        ``kwargs``. This relies on every instance attribute also being a
        keyword argument accepted by ``__init__``. Attribute values are
        shared, not copied.
        """
        args = dict(self.__dict__)
        args.update(kwargs)
        return self.__class__(**args)

    @classmethod
    def copy(cls, value):
        """
        Deeper (but not full) copy, leaving FX and user objects alone.

        Implemented as ``apply(identity, value)``, where ``identity`` is
        presumably a no-op transform. VariableTrackers and
        list/tuple/dict/OrderedDict containers are rebuilt. Any other object,
        and any field named in ``_nonvar_fields``, is reused as-is.
        """
        return cls.apply(identity, value)

    @classmethod
    def apply(
        cls,
        fn: Callable[["VariableTracker"], "VariableTracker"],
        value,
        cache: Optional[Dict[int, Any]] = None,
        # Whether we should skip applying to this var
        skip_fn: Callable[["VariableTracker"], bool] = lambda _: False,
    ):
        """
        Walk this object and call fn on all the VariableTracker
        instances to produce a new VariableTracker with the results.

        Args:
            fn: applied to every VariableTracker reached.
                For a non-skipped tracker it receives a clone whose attributes
                (other than ``_nonvar_fields``) have already been walked.
                For a skipped tracker it receives the original object.
            value: a VariableTracker, or a list/tuple/dict/OrderedDict
                (walked recursively). Any other object is returned unchanged.
            cache: maps ``id(obj)`` to ``(result, obj)``, so an object
                reachable along several paths is transformed only once.
                A fresh dict is created when omitted.
            skip_fn: when it returns True for a tracker, that tracker's
                fields are not walked and it is not cloned.

        Returns:
            The transformed structure.
        """
        if cache is None:
            cache = dict()

        idx = id(value)
        if idx in cache:
            return cache[idx][0]

        if isinstance(value, VariableTracker):
            if not skip_fn(value):
                updated_dict = dict(value.__dict__)
                for key in updated_dict.keys():
                    if key not in value._nonvar_fields:
                        updated_dict[key] = cls.apply(
                            fn, updated_dict[key], cache, skip_fn
                        )
                # clone() re-runs the constructor (and therefore __post_init__)
                # with the walked field values.
                result = fn(value.clone(**updated_dict))
            else:
                result = fn(value)

        elif istype(value, list):
            result = [cls.apply(fn, v, cache, skip_fn) for v in value]
        elif istype(value, tuple):
            result = tuple(cls.apply(fn, v, cache, skip_fn) for v in value)
        elif istype(value, collections.OrderedDict):
            # Each (key, value) item tuple goes through apply(), so keys are
            # walked as well, unlike in the plain-dict branch below
            # (assuming istype() is an exact-type check).
            result = collections.OrderedDict(
                cls.apply(fn, v, cache, skip_fn) for v in value.items()
            )
        elif istype(value, dict):
            result = {
                k: cls.apply(fn, v, cache, skip_fn) for k, v in list(value.items())
            }
        else:
            result = value

        # save `value` to keep it alive and ensure id() isn't reused
        # NOTE(review): the cache entry is written only after the children
        # are processed, so a reference cycle back to `value` would recurse
        # without ever hitting the cache.
        cache[idx] = (result, value)
        return result

    def add_guard(self, guard):
        """Return a clone whose guards are this tracker's guards plus ``guard``."""
        return self.clone(guards=set.union(self.guards, {guard}))

    def add_guards(self, guards):
        """
        Return a clone whose guards also include every guard in ``guards``.

        If ``guards`` is None, ``self`` is returned unchanged. Any other value
        that is not a ``set`` fails the assertion.
        """
        if guards is None:
            return self
        assert isinstance(guards, set)
        return self.clone(guards=set.union(self.guards, guards))

    def add_options(self, options, *more):
        """
        Merge guards from one or more option sources into a clone.

        Each of ``options``/``more`` may be either of:
          * a VariableTracker, whose ``guards`` are used;
          * a dict such as the one returned by ``propagate``, whose
            ``"guards"`` entry (default: empty set) is used.
        Sources are folded left to right.
        """
        if more:
            return self.add_options(options).add_options(*more)
        if isinstance(options, VariableTracker):
            return self.add_guards(options.guards)
        assert isinstance(options, dict)
        return self.add_guards(options.get("guards", set()))

    def __str__(self):
        # Base rendering shows only the concrete class name.
        return f"{self.__class__.__name__}()"

    def __repr__(self):
        return str(self)

    def python_type(self):
        """Return the Python type modeled by this tracker; the base implementation always raises."""
        raise NotImplementedError(f"{self} has no type")

    def as_python_constant(self):
        """For constants"""
        raise NotImplementedError(f"{self} is not a constant")

    def is_python_constant(self):
        """Return True if ``as_python_constant()`` does not raise NotImplementedError."""
        try:
            self.as_python_constant()
            return True
        except NotImplementedError:
            return False

    def as_specialized(self, tx):
        """
        For specialized variables, return itself,
        For unspecialized variables, convert to constant variable and return.

        The base implementation always returns ``self``.
        """
        return self

    def can_make_guard(self):
        """
        Return True if ``make_guard(None)`` succeeds.

        Only NotImplementedError is treated as "cannot". Any other exception
        raised by ``source.make_guard`` propagates.
        """
        try:
            self.make_guard(None)
            return True
        except NotImplementedError:
            return False

    def make_guard(self, fn):
        """
        Build a guard for this tracker's source using ``fn``.

        Raises:
            NotImplementedError: if this tracker has no (truthy) source.
        """
        if self.source:
            return self.source.make_guard(fn)
        raise NotImplementedError()

    def replace_guards(self, guards, *fns):
        """
        Return a new guard set for this tracker's source.

        Guards in ``guards`` (which may be None) whose ``name`` equals
        ``self.source.name()`` are dropped. Then one new guard per ``fn`` in
        ``fns`` is added via ``self.source.make_guard(fn)``.

        The input collection is not modified. ``self.source`` must be set.
        """
        name = self.source.name()
        new_guards = {g for g in (guards or []) if g.name != name}
        new_guards.update(self.source.make_guard(fn) for fn in fns)
        return new_guards

    def const_getattr(self, tx, name: str) -> Any:
        """getattr(self, name) returning a python constant"""
        raise NotImplementedError()

    def var_getattr(self, tx, name: str) -> "VariableTracker":
        """
        getattr(self, name) returning a new variable.

        The base implementation only handles attributes that
        ``const_getattr`` resolves to a literal (per
        ``ConstantVariable.is_literal``). The result is a ConstantVariable
        that carries this tracker's guards. When this tracker has a source,
        the result also gets an ``AttrSource`` for the attribute.

        Raises:
            NotImplementedError: if ``const_getattr`` raises it, or if the
            value is not a literal.
        """
        options = VariableTracker.propagate(self)
        value = self.const_getattr(tx, name)
        if not variables.ConstantVariable.is_literal(value):
            raise NotImplementedError()
        if self.source:
            options["source"] = AttrSource(self.source, name)
        return variables.ConstantVariable(value, **options)

    def is_proxy(self):
        """Return True if ``as_proxy()`` does not raise NotImplementedError."""
        try:
            self.as_proxy()
            return True
        except NotImplementedError:
            return False

    def as_proxy(self):
        # Base implementation: not representable as a proxy.
        raise NotImplementedError(str(self))

    def reconstruct(self, codegen):
        # Base implementation: no code generation available.
        raise NotImplementedError()

    def unpack_var_sequence(self, tx):
        # Base implementation: not unpackable into a sequence of trackers.
        raise NotImplementedError()

    def has_unpack_var_sequence(self, tx):
        """
        Return True if ``unpack_var_sequence(tx)`` does not raise NotImplementedError.

        Note that this actually performs the unpack, so its cost is paid here.
        """
        try:
            self.unpack_var_sequence(tx)
            return True
        except NotImplementedError:
            return False

    def num_parameters(self):
        # Base implementation reports this as unsupported.
        unimplemented(f"num_parameters: {self}")

    def call_hasattr(self, tx, name: str) -> "VariableTracker":
        # Base implementation reports hasattr() as unsupported.
        unimplemented(f"hasattr: {repr(self)}")

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        # Base implementation reports calling this value as unsupported.
        unimplemented(f"call_function {self} {args} {kwargs}")

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """
        Generic method-call fallbacks shared by all trackers.

        Handles:
          * ``__len__`` on anything that can be unpacked into a sequence,
            returning a ConstantVariable with the element count;
          * ``__getattr__(name)`` with a single constant argument and no
            kwargs, delegating to ``var_getattr``.
        Everything else is reported via ``unimplemented``.
        """
        if name == "__len__" and self.has_unpack_var_sequence(tx):
            assert not (args or kwargs)
            # has_unpack_var_sequence() already unpacked once; this unpacks again.
            return variables.ConstantVariable(
                len(self.unpack_var_sequence(tx)), **VariableTracker.propagate(self)
            )
        elif (
            name == "__getattr__"
            and len(args) == 1
            and args[0].is_python_constant()
            and not kwargs
        ):
            return self.var_getattr(tx, args[0].as_python_constant()).add_options(
                self, args[0]
            )
        # NOTE(review): the other methods in this class call unimplemented()
        # without `raise`, which suggests it raises on its own. If so, this
        # `raise` is never reached.
        raise unimplemented(f"call_method {self} {name} {args} {kwargs}")

    def __init__(
        self,
        guards: Optional[Set] = None,
        source: Optional[Source] = None,
        mutable_local: Optional[MutableLocal] = None,
        recursively_contains: Optional[Set] = None,
    ):
        """
        Args:
            guards: guards attached to this value. A falsy value (None or an
                empty set) is replaced by a fresh empty set.
            source: where this value came from; used by make_guard(),
                replace_guards() and var_getattr().
            mutable_local: marker set when the value was created locally and
                may be mutated.
            recursively_contains: precomputed set of contained markers. When
                None, it is computed in __post_init__.
        """
        super().__init__()
        self.guards = guards or set()
        self.source = source
        self.mutable_local = mutable_local
        self.recursively_contains = (
            recursively_contains  # provides hint to replace_all when replacing vars
        )

    def __post_init__(self, *args, **kwargs):
        """
        Compute ``recursively_contains`` when the caller did not supply it.

        Called by the HasPostInit metaclass after ``__init__``. ``args`` and
        ``kwargs`` are the constructor arguments and are unused here.

        The resulting set holds the ``mutable_local`` markers and the
        ``recursively_contains`` sets of the trackers reachable from this
        instance's fields without passing through another tracker. It also
        holds this instance's own ``mutable_local``, because ``fn`` receives
        a clone of ``self`` that carries the same marker.
        """
        if self.recursively_contains is None:
            self.recursively_contains = set()

            def aggregate_mutables(var):
                self.recursively_contains.update(var.recursively_contains)
                if var.mutable_local is not None:
                    self.recursively_contains.add(var.mutable_local)

                return var

            # Nested trackers are skipped: they are not cloned or descended
            # into, and contribute their already-computed recursively_contains.
            # apply() does clone `self`, but that clone inherits the non-None
            # set through clone(). Its own __post_init__ therefore skips this
            # branch; keeping this call inside the `if` is what prevents
            # unbounded recursion.
            VariableTracker.apply(
                aggregate_mutables, self, skip_fn=lambda var: var is not self
            )

        assert None not in self.recursively_contains
def typestr(*objs):
    """
    Describe one or more objects for diagnostic messages.

    How the result is built:
      * a single VariableTracker is rendered with ``str()``;
      * any other single object is rendered as its type name;
      * several objects are each described recursively and joined with
        single spaces;
      * zero objects yield ``""``.
    """
    if len(objs) != 1:
        return " ".join(typestr(item) for item in objs)
    (only,) = objs
    if isinstance(only, VariableTracker):
        return str(only)
    return type(only).__name__
|