123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476 |
- import abc
- import enum
- import functools
- import inspect
- import itertools
- import types
- from typing import Dict, List
- import torch
- from .. import variables
- from ..bytecode_transformation import create_instruction
- from ..exc import unimplemented
- from ..source import AttrSource, ConstantSource, DefaultsSource, GetItemSource
- from ..utils import istensor, istype, make_cell
- from .base import typestr, VariableTracker
def wrap_bound_arg(tx, val, options, source=None):
    """Turn a value bound to a function parameter into a VariableTracker.

    Dicts, tuples and lists are rebuilt element-wise (recursively); literals,
    ``torch.device`` / ``torch.dtype`` values become ConstantVariables; plain
    functions, enums and classes get dedicated trackers; tensors go through
    ``VariableBuilder``. Anything else must already be a VariableTracker and
    is returned unchanged.

    ``source`` is only attached where one is known: source propagation is best
    effort, since not every object encountered has a source to begin with.
    """
    # ``source`` is passed separately because container elements recurse with
    # their own (possibly absent) source; it must never ride along in options.
    assert (
        "source" not in options
    ), "Source needs to be separate from options due to recursive calls for lists/dicts"

    def recurse(item):
        return wrap_bound_arg(tx, item, options, source=getattr(item, "source", None))

    if isinstance(val, dict):
        wrapped_items = {key: recurse(item) for key, item in val.items()}
        return variables.ConstDictVariable(wrapped_items, dict, **options)

    if isinstance(val, (tuple, list)):
        list_cls = variables.BaseListVariable.cls_for(type(val))
        return list_cls([recurse(item) for item in val], **options)

    constant_like = variables.ConstantVariable.is_literal(val) or istype(
        val, (torch.Size, torch.device, torch.dtype)
    )
    if constant_like:
        return variables.ConstantVariable(val, **options)

    if isinstance(val, types.FunctionType):
        return variables.UserFunctionVariable(val, source=source, **options)

    if isinstance(val, enum.Enum):
        return variables.EnumVariable(val, source=source, **options)

    if isinstance(val, (type, abc.ABCMeta)):
        return variables.UserDefinedClassVariable(val, source=source, **options)

    if istensor(val):
        from torch._dynamo.variables.builder import VariableBuilder

        return VariableBuilder(tx, source=source, **options)(val)

    assert isinstance(val, VariableTracker), typestr(val)
    return val
def wrap_args_kwargs(tx, result, options):
    """Wrap raw tuple/dict values in ``result`` (in place) via ``wrap_bound_arg``.

    Per the original intent these are the packed ``*args`` / ``**kwargs``
    containers produced by signature binding; other values (including lists)
    are left untouched.
    """
    # Collect the names up front so ``result`` is not mutated while iterating it.
    packed_names = [
        name for name, value in result.items() if isinstance(value, (tuple, dict))
    ]
    for name in packed_names:
        result[name] = wrap_bound_arg(tx, result[name], options)
def init_cellvars(parent, result, code):
    """Allocate a new tracked cell for every name in ``code.co_cellvars``.

    When a cell variable is also a bound argument (present in ``result``), its
    value is moved out of ``result`` and stored as the cell's initial contents.

    Returns a dict mapping each cell-variable name to its tracked cell.
    """
    side_effects = parent.output.side_effects
    cells = {}
    for cellvar in code.co_cellvars:
        new_cell = side_effects.track_cell_new()
        cells[cellvar] = new_cell
        if cellvar in result:
            side_effects.store_cell(new_cell, result.pop(cellvar))
    return cells
class BaseUserFunctionVariable(VariableTracker):
    """Common behaviour for variables that stand for user-written Python functions.

    Relies on ``get_code``, ``get_function`` and ``self_args``, which are
    expected to be supplied by subclasses (see the subclasses in this module).
    """

    def get_filename(self):
        """Return the file name recorded on the function's code object."""
        code = self.get_code()
        return code.co_filename

    def get_name(self):
        """Return the function name recorded on the function's code object."""
        code = self.get_code()
        return code.co_name

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        """Inline the function, passing any bound receiver(s) ahead of ``args``."""
        positional = [*self.self_args(), *args]
        return tx.inline_user_function_return(self, positional, kwargs)

    def num_parameters(self):
        """Count the parameters in the signature of ``get_function()``."""
        signature = inspect.signature(self.get_function())
        return len(signature.parameters)

    def closure_vars(self, tx):
        """Return closure variables; this base implementation has none."""
        return {}
class UserFunctionVariable(BaseUserFunctionVariable):
    """A user-defined Python function that is traced by inlining it.

    Wraps a ``types.FunctionType`` (or ``torch.jit.ScriptFunction``) after
    unwrapping ``torch._dynamo.optimize`` / ``script_if_tracing`` wrappers to
    reach the original Python function.
    """

    def __init__(self, fn, is_constant=False, **kwargs):
        super().__init__(**kwargs)
        # Functions carrying the ``_dynamo_marked_constant`` marker are treated
        # as constants for the purposes of compilation: ``call_function`` runs
        # them eagerly instead of inlining them.
        # NOTE(review): the ``is_constant`` argument itself is ignored; the flag
        # is always recomputed from ``fn``. Presumably the parameter exists so
        # that re-creating this variable from its attributes does not fail --
        # confirm before relying on the argument.
        self.is_constant = bool(getattr(fn, "_dynamo_marked_constant", False))

        assert isinstance(
            fn, (types.FunctionType, torch.jit.ScriptFunction)
        ), f"expected FunctionType found {typestr(fn)} {fn}"
        # unpack @torch._dynamo.optimize()(fn) wrapped function
        fn = inspect.getattr_static(fn, "_torchdynamo_inline", fn)
        # unpack torch.jit.script_if_tracing
        if inspect.getattr_static(fn, "__script_if_tracing_wrapper", False):
            fn = inspect.getattr_static(fn, "__original_fn", fn)
        self.fn: types.FunctionType = fn

    def self_args(self):
        # A plain function has no bound receiver.
        return []

    def get_function(self):
        return self.fn

    def get_code(self):
        return self.fn.__code__

    def python_type(self):
        return types.FunctionType

    def has_self(self):
        return getattr(self.fn, "__self__", None) is not None

    def get_globals(self):
        return self.fn.__globals__

    def bind_args(self, parent, args, kwargs):
        """Bind call arguments to this function's parameters for inlining.

        Returns ``(result, closure_cells)``:
          * ``result`` maps parameter names (minus those turned into cells) and
            free-variable names to VariableTrackers;
          * ``closure_cells`` maps the function's cell variables to newly
            tracked cells (see ``init_cellvars``).
        """
        assert not self.is_constant
        options = VariableTracker.propagate([self])
        tx = parent.output.root_tx
        wrap = functools.partial(wrap_bound_arg, tx=tx, options=options)

        fn: types.FunctionType = self.fn
        defaults = fn.__defaults__ or []
        defaults_sources = [
            None if self.source is None else DefaultsSource(self.source, idx)
            for idx, _ in enumerate(defaults)
        ]
        # Stand-in function with the same code/signature whose defaults are
        # already wrapped, so that signature binding + apply_defaults yields
        # VariableTrackers for omitted arguments.
        fake_func = types.FunctionType(
            fn.__code__,
            fn.__globals__,
            fn.__name__,
            tuple(
                [
                    wrap(val=arg, source=source)
                    for arg, source in zip(defaults, defaults_sources)
                ]
            ),
            fn.__closure__,
        )
        if fn.__kwdefaults__:
            kwdefaults_sources = {
                k: None
                if self.source is None
                else DefaultsSource(self.source, k, is_kw=True)
                for k in fn.__kwdefaults__
            }
            fake_func.__kwdefaults__ = {
                k: wrap(val=v, source=kwdefaults_sources[k])
                for k, v in fn.__kwdefaults__.items()
            }

        bound = inspect.signature(fake_func).bind(*args, **kwargs)
        bound.apply_defaults()
        result = dict(bound.arguments.items())

        # Wrap the raw tuple/dict containers for *args / **kwargs.
        wrap_args_kwargs(tx, result, options)
        closure_cells = init_cellvars(parent, result, fn.__code__)
        closure = self.fn.__closure__ or ()
        assert len(closure) == len(self.fn.__code__.co_freevars)
        for idx, (name, cell) in enumerate(zip(self.fn.__code__.co_freevars, closure)):
            if name == "__class__":
                # The compiler-created ``__class__`` cell (used by zero-argument
                # ``super()``) holds the enclosing class. Its source must be the
                # cell's contents: ``fn.__class__`` would instead name the
                # function type itself.
                source = (
                    AttrSource(
                        GetItemSource(AttrSource(self.source, "__closure__"), idx),
                        "cell_contents",
                    )
                    if self.source
                    else None
                )
                result[name] = variables.UserDefinedClassVariable(
                    cell.cell_contents,
                    source=source,
                )
            else:
                var = tx.match_nested_cell(name, cell)
                if var is not None:
                    # optimization for cleaner codegen
                    result[name] = var
                elif self.source:
                    from .builder import VariableBuilder

                    side_effects = parent.output.side_effects
                    if cell in side_effects:
                        out = side_effects[cell]
                    else:
                        closure_cell = GetItemSource(
                            AttrSource(self.source, "__closure__"), idx
                        )
                        closure_cell_contents = AttrSource(
                            closure_cell, "cell_contents"
                        )
                        contents_var = VariableBuilder(parent, closure_cell_contents)(
                            cell.cell_contents
                        )

                        if (
                            closure_cell_contents.name()
                            not in tx.mutated_closure_cell_contents
                        ):
                            # Optimistically don't allocate the cell, to
                            # reduce the number of side effects. This is
                            # important for cond, as without it, any accesses
                            # to closures create side effects and cond doesn't
                            # support side effects. If we're wrong and this
                            # closure cell gets written to, we will restart
                            # the analysis with this cell's name in the
                            # mutated list here
                            result[name] = contents_var
                            continue

                        # cells are written to with "cell_contents",
                        # so the source should just be the closure_cell, not its contents
                        out = side_effects.track_cell_existing(closure_cell, cell)
                        side_effects.store_cell(
                            out,
                            contents_var,
                        )

                    result[name] = out
                else:
                    unimplemented("inline with __closure__")

        return result, closure_cells

    def export_freevars(self, parent, child):
        # Nothing to export for a top-level user function.
        pass

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        if self.is_constant:
            # Execute eagerly on concrete values and register the result.
            options = VariableTracker.propagate(self, args, kwargs.values())
            return invoke_and_store_as_constant(
                tx, self.fn, self.get_name(), options, args, kwargs
            )

        return super().call_function(tx, args, kwargs)
class UserMethodVariable(UserFunctionVariable):
    """Some unsupported user-defined method"""

    def __init__(self, fn, obj, **kwargs):
        super().__init__(fn=fn, **kwargs)
        # The bound receiver; it is prepended to call arguments via self_args().
        self.obj = obj

    def __str__(self):
        return f"{type(self).__name__}({self.fn}, {self.obj})"

    def self_args(self):
        return [self.obj]

    def python_type(self):
        return types.MethodType

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        """Call the method, delegating to the module for torch.nn / constant methods.

        When the receiver is an ``NNModuleVariable`` and the method either lives
        in a ``torch.nn.*`` module or is marked constant, the call is routed
        through ``NNModuleVariable.call_method``; otherwise it is inlined.
        """
        if isinstance(self.obj, variables.NNModuleVariable):
            # __module__ may be None on some functions, hence the explicit check.
            defining_module = getattr(self.fn, "__module__", "")
            from_torch_nn = defining_module is not None and defining_module.startswith(
                "torch.nn."
            )
            if from_torch_nn or self.is_constant:
                delegated = self.obj.call_method(
                    tx, self.fn.__name__, args, kwargs, constant=self.is_constant
                )
                return delegated.add_options(self)
        return super().call_function(tx, args, kwargs)

    def num_parameters(self):
        # The receiver is already bound, so it is not counted.
        return super().num_parameters() - 1
class WrappedUserMethodVariable(UserMethodVariable):
    """A user method whose call is traced inside ``context``.

    ``context.enter(tx)`` runs before the method is called and
    ``context.exit(tx)`` runs afterwards.
    """

    def __init__(self, wrapped, context, **kwargs):
        # fn/obj are taken from ``wrapped``; drop any copies in kwargs so they
        # are not passed twice.
        kwargs.pop("fn", None)
        kwargs.pop("obj", None)
        super().__init__(wrapped.fn, wrapped.obj, **kwargs)
        self.wrapped = wrapped
        self.context = context

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        self.context.enter(tx)
        try:
            return super().call_function(tx, args, kwargs)
        finally:
            # Exit even when tracing the body raises, so enter/exit stay
            # balanced instead of leaving the context's changes in place.
            self.context.exit(tx)
class WrappedUserFunctionVariable(UserFunctionVariable):
    """A user function whose call is traced inside ``context``.

    ``context.enter(tx)`` runs before the function is called and
    ``context.exit(tx)`` runs afterwards.
    """

    def __init__(self, wrapped, context, **kwargs):
        # fn is taken from ``wrapped``; drop any copies in kwargs so nothing is
        # passed twice.
        kwargs.pop("fn", None)
        kwargs.pop("obj", None)
        super().__init__(wrapped.fn, **kwargs)
        self.wrapped = wrapped
        self.context = context

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        self.context.enter(tx)
        try:
            return super().call_function(tx, args, kwargs)
        finally:
            # Exit even when tracing the body raises, so enter/exit stay
            # balanced instead of leaving the context's changes in place.
            self.context.exit(tx)
def invoke_and_store_as_constant(tx, fn, name, options, args, kwargs):
    """Call ``fn`` eagerly on concrete values and register its result.

    Tensor arguments are replaced by their real values; every other argument
    must be convertible with ``as_python_constant()``. The result is
    registered on ``tx.output`` under ``name`` with a ``ConstantSource(name)``.
    """

    def to_python(var):
        if isinstance(var, variables.TensorVariable):
            return var.get_real_value()
        return var.as_python_constant()

    real_args = list(map(to_python, args))
    real_kwargs = {key: to_python(value) for key, value in kwargs.items()}
    outcome = fn(*real_args, **real_kwargs)
    return tx.output.register_attr_or_module(
        outcome,
        name,
        source=ConstantSource(name),
        **options,
    )
class NestedUserFunctionVariable(BaseUserFunctionVariable):
    """A nested function whose parts are tracked as VariableTrackers.

    The name, code object, defaults, kwdefaults, annotations and closure are
    kept as separate trackers rather than a real function object;
    ``reconstruct`` rebuilds the function with ``MAKE_FUNCTION``.
    """

    def __init__(
        self,
        fn_name,
        code,
        f_globals,
        defaults,
        kwdefaults,
        annotations,
        closure,
        closure_scope,
        **kwargs,
    ):
        super().__init__(**kwargs)
        assert isinstance(fn_name.as_python_constant(), str)
        assert isinstance(code.as_python_constant(), types.CodeType)
        assert isinstance(f_globals, dict)
        self.fn_name = fn_name
        self.code = code
        self.f_globals = f_globals
        self.defaults = defaults
        self.kwdefaults = kwdefaults
        self.annotations = annotations
        self.closure = closure
        # A closure scope is only kept when there is a closure.
        if closure is None:
            closure_scope = None
        self.closure_scope = closure_scope

    def self_args(self):
        return []

    def get_code(self):
        return self.code.as_python_constant()

    def get_function(self):
        """Materialize a real function object (only possible without a closure)."""
        if self.closure:
            raise NotImplementedError()
        func = types.FunctionType(
            self.code.as_python_constant(),
            self.f_globals,
            self.fn_name.as_python_constant(),
        )
        if self.defaults:
            func.__defaults__ = self.defaults.as_python_constant()
        if self.kwdefaults:
            func.__kwdefaults__ = self.kwdefaults.as_python_constant()
        if self.annotations:
            annotations = self.annotations.as_python_constant()
            if isinstance(annotations, tuple):
                # Python 3.10+ MAKE_FUNCTION stores annotations as a flat tuple
                # (name0, value0, name1, value1, ...). Pair even slots with odd
                # slots; itertools.pairwise would yield overlapping pairs such
                # as (value0, name1) and add spurious keys.
                annotations = dict(zip(annotations[::2], annotations[1::2]))
            # TypeError: __annotations__ must be set to a dict object
            assert isinstance(annotations, dict)
            func.__annotations__ = annotations
        return func

    def has_closure(self):
        return self.closure is not None

    def has_self(self):
        return False

    def get_globals(self):
        return self.f_globals

    def bind_args(self, parent, args, kwargs):
        """Bind call arguments; returns ``(result, closure_cells)``.

        ``result`` maps parameter names (minus those turned into cells) to
        VariableTrackers; ``closure_cells`` maps cell variables to new tracked
        cells and free variables to the tracked items of ``self.closure``.
        """
        code = self.get_code()
        # Real function used only to drive signature binding; its closure cells
        # are placeholders, since the real cells come from ``self.closure``.
        func = types.FunctionType(
            code,
            self.f_globals,
            self.fn_name.as_python_constant(),
            tuple(self.defaults.items) if self.defaults else None,
            tuple(make_cell(None) for _ in range(len(code.co_freevars))),
        )
        if self.kwdefaults:
            func.__kwdefaults__ = self.kwdefaults.items

        bound = inspect.signature(func).bind(*args, **kwargs)
        bound.apply_defaults()
        result = dict(bound.arguments.items())

        wrap_args_kwargs(parent.output.root_tx, result, VariableTracker.propagate(self))
        closure_cells = init_cellvars(parent, result, code)

        for idx, name in enumerate(code.co_freevars):
            assert getattr(self.closure.items[idx], name, name) == name
            assert name not in result
            closure_cells[name] = self.closure.items[idx]

        return result, closure_cells

    def export_freevars(self, parent, child):
        """Copy the child's values for this function's free variables back to ``parent``."""
        code = self.get_code()
        for var in code.co_freevars:
            if var in child.symbolic_locals:
                parent.symbolic_locals[var] = child.symbolic_locals[var]

    def reconstruct(self, codegen):
        """Emit bytecode that rebuilds this function via ``MAKE_FUNCTION``.

        Flag bits: 0x01 defaults, 0x02 kwdefaults, 0x04 annotations,
        0x08 closure; each present component is pushed before the code object
        and name.
        """
        flags = 0x00
        if self.defaults:
            flags |= 0x01
            codegen(self.defaults)
        if self.kwdefaults:
            flags |= 0x02
            codegen(self.kwdefaults)
        if isinstance(
            self.annotations, (variables.ConstDictVariable, variables.TupleVariable)
        ):
            flags |= 0x04
            try:
                # Prefer a single constant load when every annotation is constant.
                if isinstance(self.annotations, variables.ConstDictVariable):
                    annotations = {
                        k: v.as_python_constant()
                        for k, v in self.annotations.items.items()
                    }
                else:
                    annotations = tuple(
                        [v.as_python_constant() for v in self.annotations.items]
                    )
                codegen.extend_output([codegen._create_load_const(annotations)])
            except NotImplementedError:
                codegen(self.annotations)
        if self.closure:
            flags |= 0x08
            codegen(self.closure)
        codegen(self.code)
        codegen(self.fn_name)
        return [create_instruction("MAKE_FUNCTION", flags)]
|