meta_tracer.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268
  1. import torch
  2. import torch.fx
  3. import warnings
  4. import functools
  5. import builtins
  6. from typing import Any, Callable, Dict, Optional, Union
  7. def embedding_override(self, input):
  8. return torch.empty(*input.shape, self.weight.shape[-1], device='meta')
  9. def nn_layernorm_override(self, input):
  10. return input
  11. def torch_relu_override(x):
  12. return x
  13. def torch_nn_relu_override(self, x):
  14. return x
  15. def functional_relu_override(x, inplace=False):
  16. assert not inplace, 'dont support inplace functional.relu for metatensor analysis'
  17. return x
  18. def torch_where_override(condition, x, y):
  19. # torch.where returns the broadcasted tensor of condition, x, and y,
  20. # so hack it by using addition
  21. return condition.to(device='meta') + x.to(device='meta') + y.to(device='meta')
  22. def torch_abs_override(input, *, out=None):
  23. assert out is None, 'Dont support in-place abs for MetaTensor analysis'
  24. return input
  25. manual_meta_overrides : Dict[Callable, Callable] = {
  26. torch.nn.Embedding: embedding_override,
  27. torch.nn.LayerNorm: nn_layernorm_override,
  28. torch.relu: torch_relu_override,
  29. torch.nn.functional.relu: functional_relu_override,
  30. torch.nn.ReLU: torch_nn_relu_override,
  31. torch.where: torch_where_override,
  32. torch.abs: torch_abs_override,
  33. }
  34. def gen_constructor_wrapper(target):
  35. @functools.wraps(target)
  36. def wrapper(*args, **kwargs):
  37. proxy = None
  38. def check_has_proxy(v):
  39. if isinstance(v, torch.fx.Proxy):
  40. nonlocal proxy
  41. proxy = v
  42. torch.fx.node.map_aggregate(args, check_has_proxy)
  43. torch.fx.node.map_aggregate(kwargs, check_has_proxy)
  44. if proxy is not None:
  45. return proxy.tracer.create_proxy('call_function', target, args, kwargs)
  46. else:
  47. return target(*args, **kwargs)
  48. return wrapper, target
  49. class MetaProxy(torch.fx.Proxy):
  50. def install_tensor_meta(self, tensor_meta):
  51. self._tensor_meta = tensor_meta
  52. def size(self, dim=None):
  53. if hasattr(self, '_tensor_meta') and self._tensor_meta is not None:
  54. return self._tensor_meta.size(*[dim] if dim else [])
  55. return self.tracer.create_proxy('call_method', 'size', (self, dim) if dim else (self,), {})
  56. def dim(self):
  57. if hasattr(self, '_tensor_meta') and self._tensor_meta is not None:
  58. return self._tensor_meta.dim()
  59. return self.tracer.create_proxy('call_method', 'dim', (self,), {})
  60. @property
  61. def shape(self):
  62. if hasattr(self, '_tensor_meta') and self._tensor_meta is not None:
  63. return self._tensor_meta.shape
  64. return self.tracer.create_proxy('call_function', builtins.getattr, (self, 'shape'), {})
  65. @property
  66. def dtype(self):
  67. if hasattr(self, '_tensor_meta') and self._tensor_meta is not None:
  68. return self._tensor_meta.dtype
  69. return self.tracer.create_proxy('call_function', builtins.getattr, (self, 'dtype'), {})
  70. @property
  71. def device(self):
  72. # Hack so we can track when devices are used. During meta-tensor propagation,
  73. # replace these values with a constant 'meta'
  74. return MetaDeviceAttribute(self, 'device')
  75. def __getattr__(self, k):
  76. if k == '_tensor_meta':
  77. return self.__getattribute__(k)
  78. # note: not added to the graph yet, if this is a method call
  79. # we peephole optimize to the method invocation
  80. return MetaAttribute(self, k)
  81. class MetaAttribute(MetaProxy):
  82. def __init__(self, root, attr: str):
  83. self.root = root
  84. self.attr = attr
  85. self.tracer = root.tracer
  86. self._node = None
  87. @property
  88. def node(self):
  89. # the node for attributes is added lazily, since most will just be method calls
  90. # which do not rely on the getitem call
  91. if self._node is None:
  92. self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node
  93. return self._node
  94. def __call__(self, *args, **kwargs):
  95. return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs)
  96. class MetaDeviceAttribute(MetaAttribute):
  97. pass
  98. def proxys_to_metas(v):
  99. if isinstance(v, MetaDeviceAttribute):
  100. return 'meta'
  101. if isinstance(v, torch.fx.Proxy):
  102. assert isinstance(v, MetaProxy), f'Expected MetaProxy but got {type(v)}'
  103. assert hasattr(v, '_tensor_meta'), 'MetaProxy does not have an associated meta'
  104. return v._tensor_meta
  105. return v
  106. class MetaTracer(torch.fx.Tracer):
  107. allow_insert_stateless_mods : bool = True
  108. _TORCH_METHODS_TO_PATCH = ['arange', 'zeros', 'ones', 'full_like', 'eye']
  109. def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None):
  110. rv = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn)
  111. if kind == 'placeholder' and target in self.meta_args:
  112. rv.install_tensor_meta(self.meta_args[target])
  113. return rv
  114. if target in self.orig_fns:
  115. # NOTE: tensor constructors in PyTorch define the `device` argument as
  116. # *kwargs-only*. That is why this works. If you add methods to
  117. # _TORCH_METHODS_TO_PATCH that do not define `device` as kwarg-only,
  118. # this will break and you will likely see issues where we cannot infer
  119. # the size of the output.
  120. if 'device' in kwargs:
  121. kwargs['device'] = 'meta'
  122. try:
  123. args_metas = torch.fx.node.map_aggregate(args, proxys_to_metas)
  124. kwargs_metas = torch.fx.node.map_aggregate(kwargs, proxys_to_metas)
  125. if kind == 'call_function':
  126. meta_target = manual_meta_overrides.get(target, target)
  127. meta_out = meta_target(*args_metas, **kwargs_metas)
  128. elif kind == 'call_method':
  129. meta_out = getattr(args_metas[0], target)(*args_metas[1:], **kwargs_metas)
  130. elif kind == 'call_module':
  131. assert hasattr(self, 'orig_forward')
  132. self._disable_module_getattr = True
  133. try:
  134. mod = self.root.get_submodule(target)
  135. mod_type = type(mod)
  136. if mod_type in manual_meta_overrides:
  137. meta_out = manual_meta_overrides[mod_type](mod, *args_metas, **kwargs_metas)
  138. else:
  139. meta_out = self.orig_forward(*args_metas, **kwargs_metas)
  140. finally:
  141. self._disable_module_getattr = False
  142. elif kind == 'get_attr':
  143. self._disable_module_getattr = True
  144. try:
  145. attr_itr = self.root
  146. atoms = target.split('.')
  147. for atom in atoms:
  148. attr_itr = getattr(attr_itr, atom)
  149. assert isinstance(attr_itr, torch.Tensor)
  150. meta_out = attr_itr.to(device='meta')
  151. finally:
  152. self._disable_module_getattr = False
  153. else:
  154. return rv
  155. # TODO
  156. assert isinstance(rv, torch.fx.Proxy), 'Dont support composite output yet'
  157. rv.install_tensor_meta(meta_out)
  158. except Exception as e:
  159. warnings.warn(f'Could not compute metadata for {kind} target {target}: {e}')
  160. return rv
  161. def getattr(self, attr, attr_val, parameter_proxy_cache):
  162. if getattr(self, '_disable_module_getattr', False):
  163. return attr_val
  164. else:
  165. return super().getattr(attr, attr_val, parameter_proxy_cache)
  166. def call_module(self, m, forward, args, kwargs):
  167. self.orig_forward = forward
  168. return super().call_module(m, forward, args, kwargs)
  169. def _insert_module_as_submodule(self, mod: torch.nn.Module) -> str:
  170. """
  171. Helper method which tries to insert a module that was not declared as submodule.
  172. """
  173. idx = 0
  174. mod_name = mod.__class__.__name__.lower()
  175. path = f"{mod_name}_{idx}"
  176. while hasattr(self.root, path):
  177. path = f"{mod_name}_{idx}"
  178. idx += 1
  179. self.root.add_module(path, mod)
  180. return path
  181. def path_of_module(self, mod: torch.nn.Module) -> str:
  182. try:
  183. return super().path_of_module(mod)
  184. except NameError as e:
  185. if self.allow_insert_stateless_mods and len(list(mod.parameters())) == 0 and len(list(mod.buffers())) == 0:
  186. path = self._insert_module_as_submodule(mod)
  187. self.prev_module = path
  188. return path
  189. raise
  190. def proxy(self, node):
  191. return MetaProxy(node, self)
  192. def trace(self, root, meta_args : Dict[str, torch.Tensor], concrete_args=None):
  193. assert isinstance(meta_args, dict)
  194. self.meta_args = meta_args
  195. self.patched_torch_methods = {
  196. target: gen_constructor_wrapper(getattr(torch, target)) for target in self._TORCH_METHODS_TO_PATCH
  197. }
  198. self.orig_fns = set()
  199. for name, (wrapper, orig) in self.patched_torch_methods.items():
  200. setattr(torch, name, wrapper)
  201. self.orig_fns.add(orig)
  202. try:
  203. graph = super().trace(root, concrete_args)
  204. graph._tracer_extras = {'meta_args': meta_args}
  205. return graph
  206. finally:
  207. for name, (_, orig) in self.patched_torch_methods.items():
  208. setattr(torch, name, orig)
  209. def symbolic_trace(root : Union[torch.nn.Module, Callable[..., Any]],
  210. meta_args : Dict[str, torch.Tensor] = None,
  211. concrete_args: Optional[Dict[str, Any]] = None) -> torch.fx.GraphModule:
  212. tracer = MetaTracer()
  213. graph = tracer.trace(root, meta_args, concrete_args)
  214. name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__
  215. gm = torch.fx.GraphModule(tracer.root, graph, name)
  216. return gm