123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440 |
- import collections
- import dataclasses
- import functools
- import inspect
- from typing import Dict, List
- from .. import variables
- from ..bytecode_transformation import create_instruction
- from ..eval_frame import skip_code
- from ..exc import unimplemented
- from ..source import AttrSource, GlobalWeakRefSource
- from ..utils import global_key_name, istensor
- from .base import MutableLocal, VariableTracker
- from .constant import ConstantVariable
- from .tensor import TensorVariable
class ConstDictVariable(VariableTracker):
    """Tracks a dict-like object whose keys are known while tracing.

    ``items`` maps raw Python keys to ``VariableTracker`` values. A key is
    either a Python constant or the ``specialized_value`` of a
    ``TensorVariable`` (see ``get_key``). ``user_cls`` is the Python type
    reported by ``python_type``.
    """

    def __init__(self, items, user_cls, recursively_contains=None, **kwargs):
        super().__init__(recursively_contains=recursively_contains, **kwargs)
        # Inherit the guards carried by every tracked value.
        self.guards.update(VariableTracker.propagate(items.values())["guards"])
        self.items = items
        self.user_cls = user_cls

    def as_proxy(self):
        """Return a plain dict mapping each key to its value's proxy."""
        return {k: v.as_proxy() for k, v in self.items.items()}

    def as_python_constant(self):
        """Return a plain dict mapping each key to its value's constant."""
        return {k: v.as_python_constant() for k, v in self.items.items()}

    def python_type(self):
        """Return the user-visible class of the tracked dict."""
        return self.user_cls

    def reconstruct(self, codegen):
        """Emit bytecode loading every key/value pair.

        Returns the trailing ``BUILD_MAP`` instruction list that assembles
        the pairs pushed by this method.
        """
        for key, value in self.items.items():
            if istensor(key):
                # Tensor keys are fetched by loading the global named by
                # global_key_name(key) and calling it with no arguments
                # (presumably a weak reference -- see _key_to_var).
                codegen.extend_output(
                    [
                        codegen.create_load_global(global_key_name(key), add=True),
                        create_instruction("CALL_FUNCTION", 0),
                    ]
                )
            else:
                codegen.append_output(codegen.create_load_const(key))
            codegen(value)
        return [create_instruction("BUILD_MAP", len(self.items))]

    def getitem_const(self, arg: VariableTracker):
        """Return the value stored under ``arg``'s key, with options from self and arg.

        Raises ``KeyError`` when the key is not present in ``self.items``.
        """
        return self.items[ConstDictVariable.get_key(arg)].add_options(self, arg)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Evaluate the dict method ``name`` against the tracked items.

        Handles ``__getitem__``, ``items``, ``keys``, ``values``, ``__len__``,
        ``__setitem__``, ``pop``, ``get``, ``update`` and ``__contains__``
        when their arguments are supported. Mutating methods additionally
        require ``self.mutable_local``. Anything else is delegated to the
        base class.
        """
        from . import ConstantVariable, TupleVariable

        options = VariableTracker.propagate(self, args, kwargs.values())
        val = self.items

        if name == "__getitem__":
            return self.getitem_const(args[0])
        elif name == "items":
            assert not (args or kwargs)
            return TupleVariable(
                [
                    TupleVariable(
                        [
                            ConstDictVariable._key_to_var(
                                tx,
                                k,
                                **options,
                            ),
                            v,
                        ],
                        **options,
                    )
                    for k, v in val.items()
                ],
                **options,
            )
        elif name == "keys":
            assert not (args or kwargs)
            return TupleVariable(
                [
                    ConstDictVariable._key_to_var(
                        tx,
                        k,
                        **options,
                    )
                    for k in val.keys()
                ],
                **options,
            )
        elif name == "values":
            assert not (args or kwargs)
            return TupleVariable(list(val.values()), **options)
        elif name == "__len__":
            assert not (args or kwargs)
            return ConstantVariable(len(self.items), **options)
        elif (
            name == "__setitem__"
            and args
            and ConstDictVariable.is_valid_key(args[0])
            and self.mutable_local
        ):
            assert not kwargs and len(args) == 2
            k = ConstDictVariable.get_key(args[0])
            if istensor(k):
                tx.store_dict_key(global_key_name(k), k)
            # Work on a copy; the updated variable replaces self via replace_all.
            newval = collections.OrderedDict(val)
            newval[k] = args[1]

            new_rec_contains = self.recursively_contains.union(
                args[1].recursively_contains
            )
            if args[1].mutable_local is not None:
                new_rec_contains.add(args[1].mutable_local)

            return tx.replace_all(
                self,
                self.modifed(newval, new_rec_contains, **options),
            )
        elif (
            name in ("pop", "get")
            and args
            and ConstDictVariable.is_valid_key(args[0])
            and ConstDictVariable.get_key(args[0]) not in self.items
            and len(args) == 2
        ):
            # missing item, return the default value
            return args[1].add_options(options)
        elif (
            name == "get"
            and len(args) == 1
            and not kwargs
            and ConstDictVariable.is_valid_key(args[0])
            and ConstDictVariable.get_key(args[0]) not in self.items
        ):
            # dict.get(key) without a default yields None for a missing key
            return ConstantVariable(None, **options)
        elif (
            name == "pop"
            and args
            and ConstDictVariable.is_valid_key(args[0])
            and self.mutable_local
        ):
            newval = collections.OrderedDict(val)
            result = newval.pop(ConstDictVariable.get_key(args[0]))
            # recursively_contains is passed as None here rather than derived
            # from the remaining items.
            tx.replace_all(self, self.modifed(newval, None, **options))
            return result.add_options(options)
        elif (
            name == "update"
            and args
            and isinstance(args[0], ConstDictVariable)
            and self.mutable_local
        ):
            newval = collections.OrderedDict(val)
            newval.update(args[0].items)
            new_rec_contains = self.recursively_contains.union(
                args[0].recursively_contains
            )
            result = self.modifed(
                newval, recursively_contains=new_rec_contains, **options
            )
            return tx.replace_all(self, result)
        elif (
            name in ("get", "__getattr__")
            and args
            and ConstDictVariable.is_valid_key(args[0])
            and ConstDictVariable.get_key(args[0]) in self.items
        ):
            result = self.items[ConstDictVariable.get_key(args[0])]
            return result.add_options(options)
        elif (
            name == "__contains__" and args and ConstDictVariable.is_valid_key(args[0])
        ):
            return ConstantVariable(
                ConstDictVariable.get_key(args[0]) in self.items, **options
            )
        else:
            return super().call_method(tx, name, args, kwargs)

    def modifed(self, items, recursively_contains, **options):
        """a copy of self with different items

        The method name's spelling is kept as-is because
        DefaultDictVariable in this file calls it by this name.
        """
        return self.clone(
            items=items, recursively_contains=recursively_contains, **options
        )

    def unpack_var_sequence(self, tx):
        """Return the dict's keys as a list of variables (iteration order)."""
        options = VariableTracker.propagate([self])
        return [ConstDictVariable._key_to_var(tx, k, **options) for k in self.items]

    @classmethod
    def get_key(cls, arg: VariableTracker):
        """Convert a key variable into the raw key used to index ``items``.

        A ``TensorVariable`` with a non-None ``specialized_value`` yields that
        value; anything else yields ``arg.as_python_constant()``.
        """
        if isinstance(arg, TensorVariable) and arg.specialized_value is not None:
            return arg.specialized_value
        else:
            return arg.as_python_constant()

    @classmethod
    def is_valid_key(cls, key):
        """Return True when ``get_key`` can turn ``key`` into a raw dict key."""
        return key.is_python_constant() or (
            isinstance(key, TensorVariable) and key.specialized_value is not None
        )

    @classmethod
    def _key_to_var(cls, tx, key, **options):
        """Wrap a raw dict key back into a variable.

        Tensor keys are built through ``VariableBuilder`` with a
        ``GlobalWeakRefSource``; every other key must be a literal and
        becomes a ``ConstantVariable`` carrying ``options``.
        """
        from .builder import VariableBuilder

        if istensor(key):
            return VariableBuilder(tx, GlobalWeakRefSource(global_key_name(key)))(key)
        else:
            assert ConstantVariable.is_literal(key)
            return ConstantVariable(key, **options)
class DefaultDictVariable(ConstDictVariable):
    """A ConstDictVariable for ``collections.defaultdict``.

    A ``__getitem__`` on a missing key creates a fresh default value when
    ``default_factory`` is ``list``, ``tuple`` or ``dict``.
    """

    def __init__(self, items, user_cls, default_factory=None, **kwargs):
        super().__init__(items, user_cls, **kwargs)
        assert user_cls is collections.defaultdict
        self.default_factory = default_factory

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Handle ``__getitem__`` with default creation; delegate the rest.

        Raises ``KeyError`` for a missing key when no default factory is set.
        Factories other than ``list``, ``tuple`` and ``dict`` are reported
        through ``unimplemented``.
        """
        from . import ListVariable, TupleVariable

        options = VariableTracker.propagate(self, args, kwargs.values())

        # Everything except item lookup behaves like a regular dict.
        if name != "__getitem__":
            return super().call_method(tx, name, args, kwargs)

        key = ConstDictVariable.get_key(args[0])
        if key in self.items:
            return self.getitem_const(args[0])

        factory = self.default_factory
        if factory is None:
            raise KeyError(f"{key}")

        if istensor(key):
            tx.store_dict_key(global_key_name(key), key)

        updated_items = collections.OrderedDict(self.items)

        # Build the value the factory would have produced.
        if factory is list:
            default_var = ListVariable([], mutable_local=MutableLocal())
        elif factory is tuple:
            default_var = TupleVariable([], mutable_local=MutableLocal())
        elif factory is dict:
            default_var = ConstDictVariable({}, dict, mutable_local=MutableLocal())
        else:
            unimplemented(
                f"defaultdict with default_factory = {self.default_factory}"
            )

        updated_items[key] = default_var

        contained = self.recursively_contains.union(default_var.recursively_contains)
        if default_var.mutable_local is not None:
            contained.add(default_var.mutable_local)

        # Swap in a copy of self that holds the newly inserted default.
        replacement = self.modifed(updated_items, contained, **options)
        tx.replace_all(self, replacement)
        return default_var
class DataClassVariable(ConstDictVariable):
    """
    This is a bit of a hack to deal with
    transformers.file_utils.ModelOutput() from huggingface.

    ModelOutput causes trouble because it is a mix of a dataclass and an
    OrderedDict and it calls super() methods implemented in C.
    """

    # ModelOutput() excludes None, though generic dataclasses don't
    include_none = False

    @staticmethod
    @functools.lru_cache(None)
    def _patch_once():
        """Mark the code of every callable in ModelOutput's namespace via skip_code.

        Cached with ``lru_cache`` so the work happens at most once.
        """
        from transformers.file_utils import ModelOutput

        for obj in ModelOutput.__dict__.values():
            if callable(obj):
                skip_code(obj.__code__)

    @staticmethod
    def is_matching_cls(cls):
        """Return True if ``cls`` subclasses ModelOutput; False if transformers is missing."""
        try:
            from transformers.file_utils import ModelOutput

            return issubclass(cls, ModelOutput)
        except ImportError:
            return False

    @classmethod
    def is_matching_object(cls, obj):
        """Return True if ``obj`` is an instance of a matching class."""
        return cls.is_matching_cls(type(obj))

    @classmethod
    def create(cls, user_cls, args, kwargs, options):
        """Build a variable for the call ``user_cls(*args, **kwargs)``.

        The arguments are bound to the class signature and defaults applied.
        Arguments that are ``VariableTracker`` instances become items. Other
        values (defaults) are wrapped as constants when ``include_none`` is
        set; otherwise they must be ``None`` and are left out.
        """
        DataClassVariable._patch_once()

        skip_code(user_cls.__init__.__code__)
        keys = [f.name for f in dataclasses.fields(user_cls)]
        bound = inspect.signature(user_cls).bind(*args, **kwargs)
        bound.apply_defaults()
        assert set(bound.arguments.keys()) == set(keys)
        items = collections.OrderedDict()
        for key in keys:
            val = bound.arguments[key]
            if isinstance(val, VariableTracker):
                items[key] = val
            else:
                if cls.include_none:
                    assert variables.ConstantVariable.is_literal(val)
                    items[key] = variables.ConstantVariable(val)
                else:
                    assert val is None, f"unexpected {val}"

        # The lone populated field may be any field, not only the first one,
        # so test membership before indexing to avoid a KeyError.
        # NOTE(review): presumably the iterator-constructor path in
        # ModelOutput.__post_init__ only concerns the first field -- confirm.
        if (
            len(items) == 1
            and keys[0] in items
            and not isinstance(items[keys[0]], variables.TensorVariable)
        ):
            unimplemented("DataClassVariable iterator constructor")
            # TODO(jansel): implement unpacking logic in ModelOutput.__post_init__

        return cls(items, user_cls, **options)

    @classmethod
    def wrap(cls, builder, obj):
        """Build a variable from an existing instance ``obj``.

        Each dataclass field present on ``obj`` is wrapped with a new builder
        whose source is the attribute on ``builder.source``. ``None`` fields
        are left out of the items unless ``include_none`` is set, but their
        variables still contribute to the propagated options.
        """
        user_cls = type(obj)
        keys = [f.name for f in dataclasses.fields(user_cls)]

        excluded = []
        items = collections.OrderedDict()
        for key in keys:
            # __init__ function of a dataclass might not have yet defined the key
            if hasattr(obj, key):
                val = getattr(obj, key)
                var = builder.__class__(
                    tx=builder.tx, source=AttrSource(builder.source, key)
                )(val)
                if val is not None or cls.include_none:
                    items[key] = var
                else:
                    excluded.append(var)
        return cls(
            items, user_cls, **VariableTracker.propagate(excluded, items.values())
        )

    def __init__(self, items, user_cls, **options):
        super().__init__(items, user_cls, **options)
        assert self.is_matching_cls(user_cls)

    def as_proxy(self):
        """Not supported for this variable."""
        raise NotImplementedError()

    def reconstruct(self, codegen):
        """Emit bytecode that rebuilds the object as ``user_cls(**items)``.

        Loads the class, then each item value, and returns the instructions
        loading the keyword-name tuple followed by ``CALL_FUNCTION_KW``.
        """
        codegen.extend_output([codegen._create_load_const(self.user_cls)])
        keys = tuple(self.items.keys())
        for key in keys:
            codegen(self.items[key])
        return [
            codegen.create_load_const(keys),
            create_instruction("CALL_FUNCTION_KW", len(keys)),
        ]

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Evaluate method ``name`` on the tracked object.

        ``__getitem__`` accepts a string key (field lookup) or any other
        constant, which indexes the ``to_tuple`` result. ``to_tuple`` returns
        the item values as a tuple variable. ``__setattr__`` is forwarded as
        ``__setitem__``. Everything else goes to ``ConstDictVariable``.
        """
        options = VariableTracker.propagate(self, args, kwargs.values())
        if name == "__getitem__":
            assert not kwargs and len(args) == 1
            index = args[0].as_python_constant()
            if isinstance(index, str):
                return self.items[index].add_options(options)
            else:
                return (
                    self.call_method(tx, "to_tuple", [], {})
                    .call_method(tx, "__getitem__", args, kwargs)
                    .add_options(options)
                )
        elif name == "to_tuple":
            assert not (args or kwargs)
            return variables.TupleVariable(list(self.items.values()), **options)
        elif name == "__setattr__":
            # Attribute assignment is handled as item assignment.
            name = "__setitem__"
        return super().call_method(tx, name, args, kwargs)

    def var_getattr(self, tx, name: str) -> "VariableTracker":
        """Resolve attribute ``name`` on the tracked object.

        Populated fields are read from the items. When ``include_none`` is
        False, a dataclass field that is not among the items resolves to its
        declared default as a constant. Any other name is delegated to the
        base class.
        """
        if name in self.items:
            return self.call_method(
                tx, "__getitem__", [variables.ConstantVariable(name)], {}
            )
        elif not self.include_none:
            defaults = {f.name: f.default for f in dataclasses.fields(self.user_cls)}
            if name in defaults:
                assert variables.ConstantVariable.is_literal(defaults[name])
                return variables.ConstantVariable(defaults[name]).add_options(self)
        # Return the base-class result instead of dropping it.
        return super().var_getattr(tx, name)
class HFPretrainedConfigVariable(VariableTracker):
    """
    Hack for HuggingFace PretrainedConfig

    Wraps a config object and serves its attributes as constants.
    """

    @staticmethod
    def is_matching_cls(cls):
        """Return True if ``cls`` subclasses PretrainedConfig; False if transformers is missing."""
        try:
            from transformers.configuration_utils import PretrainedConfig

            matches = issubclass(cls, PretrainedConfig)
        except ImportError:
            matches = False
        return matches

    @classmethod
    def is_matching_object(cls, obj):
        """Return True if ``obj`` is an instance of a matching class."""
        obj_type = type(obj)
        return cls.is_matching_cls(obj_type)

    def __init__(self, obj, **kwargs):
        super().__init__(**kwargs)
        self.obj = obj
        assert self.is_matching_cls(type(obj))

    def var_getattr(self, tx, name: str) -> "VariableTracker":
        """Read attribute ``name`` from the wrapped object as a ConstantVariable."""
        from . import ConstantVariable

        attr_value = getattr(self.obj, name)
        return ConstantVariable(attr_value)
|