# This file establishes the public comptime interface to Dynamo.
# This allows Dynamo users to execute arbitrary Python code while
# Dynamo is symbolically evaluating their original programs.
#
# The goal of the public API is to give users rope, without actually
# leaking private implementation details of Dynamo.
import dis
import traceback

from .exc import unimplemented
- class ComptimeVar:
- """
- A ComptimeVar represents a Python value, at some particular point
- in time, in the Python code we are symbolically evaluating with
- torchdynamo. This must be distinguished from a runtime value, as
- at compile-time there are some properties of the variable we
- do not know (for example, if the ComptimeVar represents a Tensor,
- we only know metadata about the tensor; we do NOT know what the
- actual data in the Tensor is.)
- """
- def __init__(self, v):
- self.__variable = v
- def as_proxy(self):
- """
- Returns an fx.Proxy (or tuple/list of fx.Proxy) representing
- this variable in the FX graph we are assembling to pass
- to the user compiler.
- This method only works for variables we actually track in
- the FX graph, aka Tensors (and ints, if you are compiling
- with dynamic shapes). In particular, if you have a list
- or tuple of tensors, you will get a list/tuple of proxies
- (not a single proxy representing the entire list/tuple).
- """
- return self.__variable.as_proxy()
- def is_proxy(self):
- """
- Returns True if as_proxy() would succeed.
- """
- return self.__variable.is_proxy()
- def as_fake(self):
- """
- Returns a "fake" value (either a FakeTensor or a SymInt)
- representing the variable in question. This only works
- for variables that denote Tensor or int. You can use
- this to query metadata; e.g., v.as_fake().size(0) will
- tell you the compile-time known size of the tensor.
- WARNING: Do NOT mutate the returned tensor.
- """
- return self.__variable.as_proxy().node.meta["example_value"]
- def python_type(self):
- """
- Returns what type(v) would have returned for the variable
- at compile time.
- """
- return self.__variable.python_type()
- def as_python_constant(self):
- """
- Returns the Python value this variable would have, but only if it is
- completely known at compile-time (e.g., it is constant).
- WARNING: Do NOT mutate the returned constant. The returned constant
- may or may not correspond to the actual value this variable may take
- on at runtime; for example, if the variable in question is a constant
- list, we may return a copy of that list.
- """
- return self.__variable.as_python_constant()
- def is_python_constant(self):
- """
- Returns True if as_python_constant would succeed.
- """
- return self.__variable.is_python_constant()
- def _i_will_not_complain_if_bc_breaks_VariableTracker(self):
- """
- Returns the internal data structure VariableTracker that Dynamo uses
- to represent variables at compile time. There are no BC guarantees on
- this API and WE RESERVE THE RIGHT TO BREAK YOUR CODE if you rely on
- it.
- """
- return self.__variable
- def __repr__(self):
- # TODO: The default repr is pretty bad, do better
- return repr(self.__variable)
- # TODO: API for adding a custom guard
- class ComptimeContext:
- """
- This context class provides access to a public API for Dynamo's internals.
- If there is something here you would find useful that is missing, please
- file a feature request at https://github.com/pytorch/pytorch/
- """
- def __init__(self, tx):
- self.__tx = tx
- def get_local(self, name: str, *, stacklevel=0) -> ComptimeVar:
- """
- Retrieve the compile-time known information about a local.
- """
- tx = self.__get_tx(stacklevel)
- return ComptimeVar(tx.symbolic_locals[name])
- def graph_break(self, msg="ComptimeContext.graph_break"):
- """
- Manually trigger a graph break
- """
- unimplemented(msg)
- def graph(self):
- """
- Retrieve the partially constructed FX graph that would be
- passed to the user compiler after compilation.
- """
- return self.__tx.output.graph
- def print_graph(self, *, verbose=True, file=None):
- """
- Print the partially constructed FX graph that would be passed
- to the user compiler after compilation.
- """
- print(
- self.__tx.output.graph.python_code("self", verbose=verbose).src, file=file
- )
- def __get_tx(self, stacklevel):
- tx = self.__tx
- for _ in range(stacklevel):
- tx = tx.parent
- return tx
- def print_disas(self, *, file=None, stacklevel=0):
- """
- Print the current series of opcodes being executed (not including
- parent frames), including where you are in the particular opcode
- stream.
- """
- tx = self.__get_tx(stacklevel)
- print(
- dis.Bytecode(
- tx.f_code,
- current_offset=tx.instructions[tx.instruction_pointer].offset,
- ).dis(),
- file=file,
- )
- def print_value_stack(self, *, file=None, stacklevel=0):
- """
- Print the current Python value stack. Note that this is NOT the same
- as the traceback; use print_bt() to print that. Note that at
- stacklevel=0, this will typically be empty, as comptime cannot
- currently be used in an expression context where there would be
- intermediates on the stack. If you would find this useful, please
- file a bug at https://github.com/pytorch/pytorch/
- NB: Stack grows downwards in our print
- """
- # TODO: improve printing
- tx = self.__get_tx(stacklevel)
- for s in tx.stack:
- print(f"- {s}", file=file)
- def print_locals(self, *, file=None, stacklevel=0):
- """
- Print all of the locals available in the current context.
- By default this view is very limited; you can get more information
- about any individual local using get_local().
- """
- # TODO: improve by improving the VariableTracker printing
- tx = self.__get_tx(stacklevel)
- for k, v in tx.symbolic_locals.items():
- print(f"{k} = {v}", file=file)
- def print_bt(self, *, file=None, stacklevel=0):
- """
- Print the user code backtrace, starting at the beginning of the
- frame Dynamo started evaluating. Note that this MAY NOT go all
- the way to the torch.compile invocation, as we may have done
- a graph break and are compiling an intermediate frame as the
- starting point. If you think the other behavior would be better,
- file a bug at https://github.com/pytorch/pytorch/
- """
- stack = []
- tx = self.__get_tx(stacklevel)
- while tx is not None:
- stack.append(tx.frame_summary())
- tx = getattr(tx, "parent", None)
- print(
- "".join(traceback.StackSummary.from_list(reversed(stack)).format()),
- file=file,
- )
- def print_guards(self, *, file=None):
- """
- Print the currently installed guards for the Dynamo context.
- This does NOT include guards associated with variables that
- may or may not be installed in the future if those variables
- are used.
- """
- # TODO: improve print format, current guard format is extremely
- # verbose
- print(
- "\n".join(f"-{str(guard)}" for guard in sorted(self.__tx.output.guards)),
- file=file,
- )
- def _i_will_not_complain_if_bc_breaks_InstructionTranslator(self):
- """
- Returns the internal data structure InstructionTranslator that Dynamo
- uses to track state of symbolic evaluation. There are no BC
- guarantees on this API and WE RESERVE THE RIGHT TO BREAK YOUR CODE if
- you rely on it.
- """
- return self.__tx
# Convenience wrappers that are more compact to use

def graph_break():
    """Manually trigger a graph break at this point when traced by Dynamo."""

    def _cb(ctx):
        ctx.graph_break()

    comptime(_cb)
def print_graph():
    """Print, at compile time, the FX graph assembled so far."""

    def _cb(ctx):
        ctx.print_graph()

    comptime(_cb)
def print_disas(*, stacklevel=0):
    """Print, at compile time, the opcode stream of the requested frame."""

    def _cb(ctx):
        # +1 skips this wrapper's own frame.
        ctx.print_disas(
            stacklevel=ctx.get_local("stacklevel").as_python_constant() + 1
        )

    comptime(_cb)
def print_value_stack(*, stacklevel=0):
    """Print, at compile time, the Python value stack of the requested frame."""

    def _cb(ctx):
        # +1 skips this wrapper's own frame.
        ctx.print_value_stack(
            stacklevel=ctx.get_local("stacklevel").as_python_constant() + 1
        )

    comptime(_cb)
# A more useful variant of print_value_stack that can be used in an
# expression context; e.g., in x + print_value_stack_and_return(y + z)
# you will see x on the stack prior to the addition operation.
def print_value_stack_and_return(e, *, stacklevel=0):
    """Print the value stack like print_value_stack, then return ``e`` unchanged."""

    def _cb(ctx):
        # +1 skips this wrapper's own frame.
        ctx.print_value_stack(
            stacklevel=ctx.get_local("stacklevel").as_python_constant() + 1
        )

    comptime(_cb)
    return e
def print_locals(*, stacklevel=0):
    """Print, at compile time, the locals of the requested frame."""

    def _cb(ctx):
        # +1 skips this wrapper's own frame.
        ctx.print_locals(
            stacklevel=ctx.get_local("stacklevel").as_python_constant() + 1
        )

    comptime(_cb)
def print_bt(*, stacklevel=0):
    """Print, at compile time, the user-code backtrace of the requested frame."""

    def _cb(ctx):
        # +1 skips this wrapper's own frame.
        ctx.print_bt(
            stacklevel=ctx.get_local("stacklevel").as_python_constant() + 1
        )

    comptime(_cb)
def print_guards():
    """Print, at compile time, the guards installed so far."""

    def _cb(ctx):
        ctx.print_guards()

    comptime(_cb)
def comptime(fn):
    """``fn`` gets called at compile time in TorchDynamo; at normal runtime
    this function does nothing (``fn`` is never invoked) and returns None."""
    return None
# Attach the convenience wrappers as attributes of comptime itself, so
# callers can write e.g. comptime.print_graph() without extra imports.
for _helper in (
    graph_break,
    print_graph,
    print_disas,
    print_value_stack,
    print_value_stack_and_return,
    print_locals,
    print_bt,
    print_guards,
):
    setattr(comptime, _helper.__name__, _helper)
del _helper
|