123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121 |
- import ast
- import inspect
- import textwrap
- import copy
- import functools
- from types import FunctionType
- from typing import cast, Union, Callable, Dict, Optional, Any
- from torch.fx._symbolic_trace import Tracer
- from torch.fx.graph import Graph
- from torch._sources import normalize_source_lines
- import torch
class AST_Rewriter(ast.NodeTransformer):
    """
    Take a FunctionType object representing a `forward` method, then
    perform an AST rewrite to swap out nodes that are not symbolically
    traceable with a callsite to the FX alternative.

    To support swapping out an AST node, define a new `visit` method on
    that node. For more details, see:
    https://docs.python.org/3/library/ast.html#ast.NodeTransformer
    """

    def rewrite(self, fn: FunctionType):
        """
        Return a new function equivalent to ``fn`` with its AST transformed
        by this visitor.

        The source of ``fn`` is re-read, normalized, parsed, transformed,
        compiled and executed in a shallow copy of ``fn.__globals__``. The
        resulting function object is then rebound to the *original* globals
        dict, so later changes to the module are still visible to it.

        Args:
            fn: a function (or bound method) whose source can be retrieved
                with ``inspect.getsourcelines``.

        Returns:
            The rewritten function.

        Raises:
            AssertionError: if the retrieved source does not consist of
                exactly one (possibly decorated) function definition.
        """
        # Normalize the source lines
        sourcelines, _ = inspect.getsourcelines(fn)
        sourcelines = normalize_source_lines(sourcelines)
        source = ''.join(sourcelines)
        normalized_str = textwrap.dedent(source)

        # Rewrite the original AST
        source_ast = ast.parse(normalized_str)
        dest_ast = ast.fix_missing_locations(self.visit(source_ast))

        # Determine which name the rewritten module binds. Reading it from
        # the AST (instead of diffing the globals' keys around `exec`) also
        # works when that name is already present in fn.__globals__, e.g.
        # when rewriting a module-level free function.
        fn_defs = [
            node for node in dest_ast.body
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
        ]
        assert len(fn_defs) == 1, (
            f"expected exactly one function definition in the source of "
            f"{getattr(fn, '__name__', fn)!r}, found {len(fn_defs)}"
        )
        fn_name = fn_defs[0].name

        # Pull out the compiled function from the newly-created Module.
        # Executing in a copy keeps the caller's module namespace untouched.
        code = compile(dest_ast, "", "exec")
        globals_dict = copy.copy(fn.__globals__)
        exec(code, globals_dict)
        fn_compiled = globals_dict[fn_name]

        # return the compiled function with the original globals
        def change_func_globals(f, globals):
            """Based on https://stackoverflow.com/a/13503277/2988730 (@unutbu)"""
            # __globals__ is a read-only attribute of function objects, so
            # build a new function from f's code object and every other
            # relevant attribute of f, substituting the requested globals.
            g = FunctionType(
                f.__code__,
                globals,
                name=f.__name__,
                argdefs=f.__defaults__,
                closure=f.__closure__,
            )
            g = functools.update_wrapper(g, f)
            g.__kwdefaults__ = copy.copy(f.__kwdefaults__)
            return g

        # Return the correct FunctionType object
        return change_func_globals(fn_compiled, globals=fn.__globals__)

    def visit_Assert(self, node):
        """
        Swap out the Assert node (Python's `assert`) with a callsite to the
        symbolically-traceable torch._assert function.

        ``assert test, msg`` becomes ``torch._assert(test, msg)``; a missing
        message is replaced by the empty string.
        """
        # Create the Call node
        n = ast.parse('torch._assert()', mode='eval')
        assert isinstance(n, ast.Expression)
        call_node = n.body
        assert isinstance(call_node, ast.Call)
        msg = node.msg if node.msg else ast.Constant(value="", kind=None)
        call_node.args = [node.test, msg]

        # Ensure that the new node conforms to the Python AST grammar
        expr_wrapper = ast.Expr(value=call_node)

        # Return the new Call node to signify that we want to use it as
        # a replacement for the original _assert node
        return ast.copy_location(expr_wrapper, node)

    def visit_AnnAssign(self, node):
        """
        Swap out Python's AnnAssign with an Assign node where the annotation
        function is called.

        Example:
            Original:
                y: Tensor_Type(1,2,3, Dyn) = f2(x)
            Output:
                y = annotate(f2(x),Tensor_Type((1,2,3,Dyn)))

        A bare annotation with no value (``y: T``) assigns nothing, so there
        is no value to wrap; it is returned unchanged. Building
        ``annotate(None, T)`` from it would yield an AST that fails to
        compile.

        The name ``annotate`` is resolved in the rewritten function's
        globals at call time.
        """
        if node.value is None:
            return node
        new_node = ast.Assign(targets=[node.target], value=ast.Call(
            func=ast.Name(id='annotate', ctx=ast.Load()),
            args=[node.value, node.annotation], keywords=[]))
        # Keep the original line/column info for clearer tracebacks.
        return ast.copy_location(new_node, node)
class RewritingTracer(Tracer):
    """
    FX ``Tracer`` that passes ``root`` through ``_rewrite`` before tracing,
    so statements such as ``assert`` are replaced with traceable calls.
    """

    def trace(self, root: Union[torch.nn.Module, Callable], concrete_args: Optional[Dict[str, Any]] = None) -> Graph:
        """Rewrite ``root``, then trace the rewritten object with the base ``Tracer``."""
        rewritten_root = _rewrite(root)
        return super().trace(rewritten_root, concrete_args)
def _rewrite(fn: Union[torch.nn.Module, Callable]) -> Union[torch.nn.Module, Callable]:
    """
    Run ``AST_Rewriter`` over ``fn``.

    * For a plain callable, the rewritten function is returned.
    * For an ``nn.Module``, a new instance of a freshly created
      ``RewrittenModule`` class is returned. Its ``forward`` is the
      rewritten ``forward`` of the original, and its instance ``__dict__``
      holds shallow copies of the original's ``__dict__`` entries. Any
      ``nn.Module`` stored directly in that ``__dict__`` is rewritten the
      same way, recursively.

    NOTE(review): ``RewrittenModule`` subclasses ``torch.nn.Module`` and
    only receives ``forward``, so other methods defined on the original
    class are not available on the result. Submodules that
    ``nn.Module.__setattr__`` presumably places in ``_modules`` (rather
    than directly in ``__dict__``) are reached only through a shallow copy
    of that dict, so their ``forward`` is likely not rewritten. Confirm
    whether that is intended.
    """
    if not isinstance(fn, torch.nn.Module):
        # Rewrite this single free function
        return AST_Rewriter().rewrite(cast(FunctionType, fn))

    def rewrite_module(orig_module: torch.nn.Module):
        class RewrittenModule(torch.nn.Module):
            def __init__(self, orig):
                super().__init__()
                for attr_name, attr_value in orig.__dict__.items():
                    if isinstance(attr_value, torch.nn.Module):
                        # Nested module: rewrite it before storing the copy.
                        attr_value = rewrite_module(attr_value)
                    self.__dict__[attr_name] = copy.copy(attr_value)

        # Rewrite forward on the class before instantiating it, so the
        # parent is rewritten ahead of its children (as in __init__ above).
        new_forward = AST_Rewriter().rewrite(cast(FunctionType, orig_module.forward))
        RewrittenModule.forward = new_forward
        return RewrittenModule(orig_module)

    return rewrite_module(fn)
|