- import functools
- import operator
- from typing import Dict, List, Optional
- import torch
- import torch.fx
- from .. import config, variables
- from ..bytecode_transformation import create_instruction
- from ..exc import unimplemented
- from ..source import GetItemSource
- from ..utils import namedtuple_fields, proxy_args_kwargs
- from .base import MutableLocal, VariableTracker
- from .constant import ConstantVariable
class BaseListVariable(VariableTracker):
    """Shared behavior for list-like trackers (list, tuple, slice, range,
    torch.Size): constant folding, constant indexing, unpacking, and
    element-wise comparison helpers."""

    @staticmethod
    def cls_for(obj):
        """Return the VariableTracker subclass that models python type ``obj``."""
        return {
            iter: ListIteratorVariable,
            list: ListVariable,
            slice: SliceVariable,
            torch.Size: SizeVariable,
            tuple: TupleVariable,
        }[obj]

    def __init__(
        self,
        items: List[VariableTracker],
        recursively_contains=None,
        regen_guards=True,
        **kwargs,
    ):
        super().__init__(recursively_contains=recursively_contains, **kwargs)
        assert isinstance(items, list)
        assert all(isinstance(x, VariableTracker) for x in items)
        # Sometimes, we know that we have passed in the guards from the items
        # in the list already; regen_guards=False skips re-propagating them.
        if regen_guards:
            self.guards.update(VariableTracker.propagate(items)["guards"])
        self.items: List[VariableTracker] = items

    def _as_proxy(self):
        return [x.as_proxy() for x in self.items]

    def as_python_constant(self):
        return self.python_type()([x.as_python_constant() for x in self.items])

    def as_proxy(self):
        # torch.Size rejects non-int members (it calls __index__ on them), so
        # it cannot hold fx Proxies; SizeVariable overrides as_proxy with
        # special handling and must never fall through to this generic path.
        # NOTE: this previously asserted `is not SizeVariable`, which compared
        # a python type against a VariableTracker subclass and never fired.
        assert self.python_type() is not torch.Size
        return self.python_type()(self._as_proxy())

    def getitem_const(self, arg: VariableTracker):
        """Index with a constant int or slice (``arg`` must be constant)."""
        index = arg.as_python_constant()
        if isinstance(index, slice):
            if self.source is not None:
                return self.clone(
                    items=self.items[index],
                    source=GetItemSource(self.source, index),
                    mutable_local=MutableLocal() if self.mutable_local else None,
                ).add_options(arg, self)
            else:
                return self.clone(
                    items=self.items[index],
                    mutable_local=MutableLocal() if self.mutable_local else None,
                ).add_options(arg, self)
        else:
            assert isinstance(index, int)
            return self.items[index].add_options(arg, self)

    def unpack_var_sequence(self, tx):
        return [x.add_options(self) for x in self.items]

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        options = VariableTracker.propagate(self, args, kwargs.values())
        if name == "__getitem__":
            assert not kwargs and len(args) == 1
            return self.getitem_const(args[0])
        elif (
            name == "__contains__"
            and len(args) == 1
            and args[0].is_python_constant()
            # Constant-fold __contains__ only when every member is constant.
            and all(x.is_python_constant() for x in self.items)
        ):
            assert not kwargs
            search = args[0].as_python_constant()
            result = any(x.as_python_constant() == search for x in self.items)
            return variables.ConstantVariable(result, **options)
        return super().call_method(tx, name, args, kwargs)

    @staticmethod
    def list_compare(tx, op, left, right):
        """Model `left op right` for eq/ne; other comparisons graph-break."""
        from .builtin import BuiltinVariable

        eq_result = BaseListVariable.list_eq(tx, left, right)
        if op is operator.eq:
            return eq_result
        elif op is operator.ne:
            # ne is modeled as not(eq) so both share one implementation.
            return BuiltinVariable(operator.not_).call_function(tx, [eq_result], {})
        else:
            unimplemented(f"list_compare {left} {op} {right}")

    @staticmethod
    def list_eq(tx, left, right):
        from .builtin import BuiltinVariable

        options = VariableTracker.propagate(left, right)
        # Most list-like variables implement comparison ops the same way,
        # so they can re-use this helper.
        # There are quirks though, like how `tuple([2]) == torch.Size([2])`,
        # but `tuple([2]) != list([2])`
        if len(left.items) != len(right.items):
            return ConstantVariable(False, **options)
        if len(left.items) == 0:
            return ConstantVariable(True, **options)

        # Generic list comparison works by iterating over left aka self and right the compared-to list.
        # If we hit here, their lengths are the same and they cannot be expressed as python constants.
        # So, we iterate over the zipped list items.
        comps = []
        for l, r in zip(left.items, right.items):
            comp = BuiltinVariable(operator.eq).call_function(tx, [l, r], {})
            if comp.is_python_constant() and not comp.as_python_constant():
                # early exit in false case
                return comp.add_options(options)
            comps.append(comp)

        # Fold the per-element comparisons together with `and`.
        return functools.reduce(
            lambda a, b: BuiltinVariable(operator.and_).call_function(tx, [a, b], {}),
            comps,
        ).add_options(options)
class RangeVariable(BaseListVariable):
    """A python ``range()``, stored as its (start, stop, step) triple."""

    def __init__(self, items, **kwargs):
        # Normalize the 1/2/3 argument forms of range() to a full triple.
        start = variables.ConstantVariable(0)
        step = variables.ConstantVariable(1)
        if len(items) == 1:
            (stop,) = items
        elif len(items) == 2:
            start, stop = items
        elif len(items) == 3:
            start, stop, step = items
        else:
            raise AssertionError()
        assert stop is not None
        super().__init__([start, stop, step], **kwargs)

    def python_type(self):
        return range

    def as_python_constant(self):
        return range(*[x.as_python_constant() for x in self.items])

    def as_proxy(self):
        return self.python_type()(*self._as_proxy())

    def unpack_var_sequence(self, tx):
        # Materialize the range into constant elements.
        return [
            variables.ConstantVariable(v).add_options(self)
            for v in self.as_python_constant()
        ]

    def reconstruct(self, codegen):
        # `range` must still refer to the builtin; bail if it was shadowed.
        assert "range" not in codegen.tx.f_globals
        codegen.append_output(codegen.create_load_python_module(range))
        codegen.foreach(self.items)
        return [create_instruction("CALL_FUNCTION", 3)]

    def var_getattr(self, tx, name):
        fields = ["start", "stop", "step"]
        if name in fields:
            return self.items[fields.index(name)].add_options(self)
        unimplemented(f"range.{name}")
class ListVariable(BaseListVariable):
    """A mutable python list; mutation is modeled by replacing this tracker
    with a fresh ListVariable via ``tx.replace_all``."""

    def python_type(self):
        return list

    def reconstruct(self, codegen):
        codegen.foreach(self.items)
        return [create_instruction("BUILD_LIST", len(self.items))]

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        options = VariableTracker.propagate(self, args, kwargs.values())
        mutating = self.mutable_local

        if name == "append" and mutating:
            assert not kwargs
            (value,) = args
            # Track mutable children so aliasing mutations stay visible.
            contains = self.recursively_contains.union(value.recursively_contains)
            if value.mutable_local is not None:
                contains.add(value.mutable_local)
            tx.replace_all(
                self,
                ListVariable(
                    self.items + [value],
                    recursively_contains=contains,
                    regen_guards=False,
                    **options,
                ),
            )
            return ConstantVariable(None)

        if (
            name == "extend"
            and mutating
            and args
            and args[0].has_unpack_var_sequence(tx)
        ):
            assert not kwargs
            (seq,) = args
            return tx.replace_all(
                self,
                ListVariable(
                    [*self.items, *seq.unpack_var_sequence(tx)],
                    regen_guards=False,
                    **options,
                ),
            )

        if name == "insert" and mutating:
            assert not kwargs
            where, value = args
            new_items = list(self.items)
            new_items.insert(where.as_python_constant(), value)
            return tx.replace_all(
                self,
                ListVariable(new_items, regen_guards=False, **options),
            )

        if name == "pop" and mutating:
            assert not kwargs
            new_items = list(self.items)
            # pop() with no args removes the last element, like list.pop.
            popped = new_items.pop(*[a.as_python_constant() for a in args])
            tx.replace_all(
                self,
                ListVariable(new_items, regen_guards=False, **options),
            )
            return popped

        if name == "clear" and mutating:
            assert not kwargs and not args
            return tx.replace_all(
                self,
                ListVariable([], regen_guards=False, **options),
            )

        if (
            name == "__setitem__"
            and mutating
            and args
            and args[0].is_python_constant()
        ):
            assert not kwargs
            key, value = args
            new_items = list(self.items)
            if isinstance(key, SliceVariable):
                # Slice assignment splices in the assigned variable's items.
                new_items[key.as_python_constant()] = list(value.items)
            else:
                new_items[key.as_python_constant()] = value
            return tx.replace_all(
                self,
                ListVariable(new_items, regen_guards=False, **options),
            )

        return super().call_method(tx, name, args, kwargs)
class TupleVariable(BaseListVariable):
    """An immutable python tuple; all method handling lives in the base."""

    def python_type(self):
        return tuple

    def reconstruct(self, codegen):
        codegen.foreach(self.items)
        return [create_instruction("BUILD_TUPLE", len(self.items))]

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        # Tuples are immutable, so the generic handling suffices.
        return super().call_method(tx, name, args, kwargs)
class SizeVariable(TupleVariable):
    """torch.Size(...)

    May carry a pre-built fx Proxy, because a torch.Size cannot contain
    Proxy members directly (see as_proxy for the full story).
    """

    def __init__(
        self,
        items: List[VariableTracker],
        proxy: Optional[torch.fx.Proxy] = None,
        **kwargs,
    ):
        # If given, this proxy already represents the whole torch.Size.
        self.proxy = proxy
        super().__init__(items, **kwargs)

    def python_type(self):
        return torch.Size

    def as_proxy(self):
        if self.proxy is not None:
            return self.proxy

        # torch.Size needs special handling. Normally, we pun a list-like
        # container to directly contain Proxy/Node objects from FX, and FX
        # knows to look inside containers (via map_aggregate). But torch.Size
        # is weird; although it subclasses from tuple, it doesn't allow
        # members which aren't int-like (rejecting Proxy and Node). This
        # means we can't use the normal representation trick
        # torch.Size([proxy0, proxy1]). I looked into seeing if I could
        # relax torch.Size in PyTorch proper, but if torch.Size constructor
        # sees a type that it doesn't recognize, it will try to call
        # __index__() on it, so there is no BC way to actually change this
        # behavior (though it occurs to me that I could have just added a
        # YOLO no checking alternate constructor.)
        #
        # To work around this problem, I represent a torch.Size proxy as
        # a straight up proxy, that would have been constructed by taking
        # the constituent proxies as arguments. This trick can be generally
        # used for any construct that we need a proxy for but we can't
        # directly represent as an aggregate; I don't see very many examples
        # of this in torchdynamo though!

        # Look for a proxy. If there are none, do the legacy behavior
        tracer = None
        proxies = self._as_proxy()
        for proxy in proxies:
            if isinstance(proxy, torch.fx.Proxy):
                tracer = proxy.tracer
                break

        if tracer is None:
            # All members are plain ints; a real torch.Size is fine.
            return torch.Size(proxies)

        proxy = tracer.create_proxy("call_function", torch.Size, (proxies,), {})
        proxy.node.meta["example_value"] = torch.Size(
            [p.node.meta["example_value"] for p in proxies]
        )
        return proxy

    def reconstruct(self, codegen):
        # Emit: torch.Size((item0, ..., itemN))
        codegen.load_import_from("torch", "Size")
        codegen.foreach(self.items)
        build_torch_size = [
            create_instruction("BUILD_TUPLE", len(self.items)),
            create_instruction("CALL_FUNCTION", 1),
        ]
        return build_torch_size

    def unpack_var_sequence(self, tx):
        return [x.add_options(self) for x in self.items]

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        options = VariableTracker.propagate(self, args, kwargs.values())
        if name == "__getitem__":
            assert not kwargs and len(args) == 1
            if config.dynamic_shapes:
                out = self.get_item_dyn(tx, args[0])
            else:
                out = self.getitem_const(args[0])
            return out
        return super().call_method(tx, name, args, kwargs)

    def get_item_dyn(self, tx, arg: VariableTracker):
        """__getitem__ under dynamic shapes: slicing produces a new traced
        SizeVariable; int indexing is a plain item lookup."""
        index = arg.as_python_constant()
        if isinstance(index, slice):

            def _dynamo_get_item_lambda(target, index):
                return torch.Size.__getitem__(target, index)

            parent_proxy = self.as_proxy()
            proxy = tx.output.create_proxy(
                "call_function",
                _dynamo_get_item_lambda,
                *proxy_args_kwargs([self, arg], {}),
            )
            items = self.items[index]
            # Removed a dead `_unpack_into_example` helper (and its
            # SymNodeVariable import) that was defined but never called.
            # Mirror the indexing into example_value for downstream correctness
            proxy.node.meta["example_value"] = parent_proxy.node.meta["example_value"][
                index
            ]
            return SizeVariable(items, proxy=proxy).add_options(arg, self)
        else:
            assert isinstance(index, int)
            return self.items[index].add_options(arg, self)
class ShapeVariable(TupleVariable):
    """
    Represents tensor.shape(...) and helps differentiate between a constant
    TupleVariable and ShapeVariable.

    Behaves exactly like TupleVariable; the distinct type is the only signal.
    """

    pass
class NamedTupleVariable(TupleVariable):
    """A namedtuple instance; attribute access maps to positional items."""

    def __init__(self, items, tuple_cls, **kwargs):
        super().__init__(items, **kwargs)
        # The concrete namedtuple class this instance belongs to.
        self.tuple_cls = tuple_cls

    def python_type(self):
        return self.tuple_cls

    def as_python_constant(self):
        return self.python_type()(*[x.as_python_constant() for x in self.items])

    def reconstruct(self, codegen):
        # Prefer ._make when the class provides it; otherwise call the class.
        maker = getattr(self.tuple_cls, "_make", self.tuple_cls)
        codegen.append_output(codegen._create_load_const(maker))
        codegen.foreach(self.items)
        return [
            create_instruction("BUILD_TUPLE", len(self.items)),
            create_instruction("CALL_FUNCTION", 1),
        ]

    def var_getattr(self, tx, name):
        fields = namedtuple_fields(self.tuple_cls)
        if name in fields:
            return self.items[fields.index(name)].add_options(self)
        unimplemented(f"NamedTupleVariable.{name}")

    def call_hasattr(self, tx, name: str) -> "VariableTracker":
        fields = namedtuple_fields(self.tuple_cls)
        return variables.ConstantVariable(
            name in fields, **VariableTracker.propagate(self)
        )
class SliceVariable(BaseListVariable):
    """A python ``slice()``, stored as its (start, stop, step) triple."""

    def __init__(self, items, **kwargs):
        # Fill omitted positions with None, mirroring the slice() constructor.
        start = stop = step = variables.ConstantVariable(None)
        if len(items) == 1:
            (stop,) = items
        elif len(items) == 2:
            start, stop = items
        elif len(items) == 3:
            start, stop, step = items
        else:
            raise AssertionError()

        if any(
            isinstance(bound, variables.TensorVariable) for bound in (start, stop)
        ):
            unimplemented("Dynamic slicing on data-dependent value is not supported")

        super().__init__([start, stop, step], **kwargs)

    def as_proxy(self):
        return slice(*self._as_proxy())

    def python_type(self):
        return slice

    def as_python_constant(self):
        return slice(*[x.as_python_constant() for x in self.items])

    def reconstruct(self, codegen):
        codegen.foreach(self.items)
        return [create_instruction("BUILD_SLICE", len(self.items))]

    def var_getattr(self, tx, name):
        fields = ["start", "stop", "step"]
        if name in fields:
            return self.items[fields.index(name)].add_options(self)
        unimplemented(f"slice.{name}")
class ListIteratorVariable(VariableTracker):
    """An in-progress iterator over a known list of items.

    ``index`` points at the next element to be produced.
    """

    def __init__(self, items, index: int = 0, recursively_contains=None, **kwargs):
        super().__init__(recursively_contains=recursively_contains, **kwargs)
        assert isinstance(items, list)
        # Removing this check as it slows things down too much
        # https://github.com/pytorch/pytorch/pull/87533#issuecomment-1287574492
        # assert all(isinstance(x, VariableTracker) for x in items)
        self.items = items
        self.index = index

    def next_variables(self):
        """Return (next element, advanced iterator); StopIteration when done."""
        assert self.mutable_local
        if self.index >= len(self.items):
            raise StopIteration()
        current = self.items[self.index].add_options(self)
        advanced = ListIteratorVariable(
            self.items,
            self.index + 1,
            mutable_local=MutableLocal(),
            recursively_contains=self.recursively_contains,
            **VariableTracker.propagate([self]),
        )
        return current, advanced

    def as_python_constant(self):
        # A partially consumed iterator cannot be expressed as a constant.
        if self.index > 0:
            raise NotImplementedError()
        return iter([x.as_python_constant() for x in self.items])

    def unpack_var_sequence(self, tx):
        return [x.add_options(self) for x in self.items[self.index :]]

    def reconstruct(self, codegen):
        # Rebuild only the not-yet-consumed tail as a fresh tuple iterator.
        pending = self.items[self.index :]
        codegen.foreach(pending)
        return [
            create_instruction("BUILD_TUPLE", len(pending)),
            create_instruction("GET_ITER"),
        ]
|