rewriter.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. import ast
  2. import inspect
  3. import textwrap
  4. import copy
  5. import functools
  6. from types import FunctionType
  7. from typing import cast, Union, Callable, Dict, Optional, Any
  8. from torch.fx._symbolic_trace import Tracer
  9. from torch.fx.graph import Graph
  10. from torch._sources import normalize_source_lines
  11. import torch
  12. class AST_Rewriter(ast.NodeTransformer):
  13. """
  14. Take a FunctionType object representing a `forward` method, then
  15. perform an AST rewrite to swap out nodes that are not symbolically
  16. traceable with a callsite to the FX alternative.
  17. To support swapping out an AST node, define a new `visit` method on
  18. that node. For more details, see:
  19. https://docs.python.org/3/library/ast.html#ast.NodeTransformer
  20. """
  21. def rewrite(self, fn: FunctionType):
  22. # Normalize the source lines
  23. sourcelines, _ = inspect.getsourcelines(fn)
  24. sourcelines = normalize_source_lines(sourcelines)
  25. source = ''.join(sourcelines)
  26. normalized_str = textwrap.dedent(source)
  27. # Rewrite the original AST
  28. source_ast = ast.parse(normalized_str)
  29. dest_ast = ast.fix_missing_locations(self.visit(source_ast))
  30. # Pull out the compiled fucntion from the newly-created Module
  31. code = compile(dest_ast, "", "exec")
  32. globals_dict = copy.copy(fn.__globals__)
  33. keys_before = set(globals_dict.keys())
  34. exec(code, globals_dict)
  35. new_keys = list(set(globals_dict.keys()) - keys_before)
  36. assert len(new_keys) == 1
  37. fn_compiled = globals_dict[new_keys[0]]
  38. # return the compiled function with the original globals
  39. def change_func_globals(f, globals):
  40. """Based on https://stackoverflow.com/a/13503277/2988730 (@unutbu)"""
  41. # __globals__ is a private member of the function class
  42. # so we have to copy the function, f, all of its member, except f.__globals__
  43. g = FunctionType(
  44. f.__code__,
  45. globals,
  46. name=f.__name__,
  47. argdefs=f.__defaults__,
  48. closure=f.__closure__,
  49. )
  50. g = functools.update_wrapper(g, f)
  51. g.__kwdefaults__ = copy.copy(f.__kwdefaults__)
  52. return g
  53. # Return the correct FunctionType object
  54. return change_func_globals(fn_compiled, globals=fn.__globals__)
  55. def visit_Assert(self, node):
  56. """
  57. Swap out the Assert node (Python's `assert`) with a callsite to the
  58. symbolically-traceable torch._assert function
  59. """
  60. # Create the Call node
  61. n = ast.parse('torch._assert()', mode='eval')
  62. assert isinstance(n, ast.Expression)
  63. call_node = n.body
  64. assert isinstance(call_node, ast.Call)
  65. msg = node.msg if node.msg else ast.Constant(value="", kind=None)
  66. call_node.args = [node.test, msg]
  67. # Ensure that the new node conforms to the Python AST grammar
  68. expr_wrapper = ast.Expr(value=call_node)
  69. # Return the new Call node to signify that we want to use it as
  70. # a replacement for the original _assert node
  71. return ast.copy_location(expr_wrapper, node)
  72. def visit_AnnAssign(self, node):
  73. """
  74. Swap out Python's AnnAssign with an Assign node where the annotation function is called.
  75. Example:
  76. Original:
  77. y: Tensor_Type(1,2,3, Dyn) = f2(x)
  78. Output:
  79. y = annotate(f2(x),Tensor_Type((1,2,3,Dyn)))
  80. """
  81. return ast.Assign(targets=[node.target], value=ast.Call(
  82. func=ast.Name(id='annotate', ctx=ast.Load()),
  83. args=[node.value, node.annotation], keywords=[]))
  84. class RewritingTracer(Tracer):
  85. def trace(self, root: Union[torch.nn.Module, Callable], concrete_args: Optional[Dict[str, Any]] = None) -> Graph:
  86. return super().trace(_rewrite(root), concrete_args)
  87. def _rewrite(fn: Union[torch.nn.Module, Callable]) -> Union[torch.nn.Module, Callable]:
  88. if isinstance(fn, torch.nn.Module):
  89. # Rewrite this module's `forward` as well as the `forward`s of
  90. # all of this module's recursive descendents. Return the new,
  91. # rewritten module hierarchy.
  92. def rewrite_module(m : torch.nn.Module):
  93. class RewrittenModule(torch.nn.Module):
  94. def __init__(self, orig):
  95. super().__init__()
  96. for k, v in orig.__dict__.items():
  97. if isinstance(v, torch.nn.Module):
  98. self.__dict__[k] = copy.copy(rewrite_module(v))
  99. else:
  100. self.__dict__[k] = copy.copy(v)
  101. RewrittenModule.forward = AST_Rewriter().rewrite(cast(FunctionType, m.forward))
  102. return RewrittenModule(m)
  103. return rewrite_module(fn)
  104. else:
  105. # Rewrite this single free function
  106. return AST_Rewriter().rewrite(cast(FunctionType, fn))