normalize.py

import operator
from typing import Any, Callable, Dict, Tuple, Optional

import torch
import torch.fx as fx
from torch.fx import Transformer, Proxy
from torch.fx.node import Argument, Target, Node, map_aggregate
from torch.fx.operator_schemas import (
    normalize_module,
    normalize_function,
    create_type_hint,
)

from .schema_type_annotation import AnnotateTypesWithSchema


class NormalizeArgs(Transformer):
    """
    Normalize arguments to Python targets. This means that
    `args/kwargs` will be matched up to the module/functional's
    signature and rewritten to exclusively kwargs in positional order
    if `normalize_to_only_use_kwargs` is true. Also populates default
    values. Does not support positional-only parameters or varargs
    parameters (*args, **kwargs).

    If the nodes have 'type' metadata, it will use it to disambiguate
    overloads. Otherwise, it will throw an error.

    Example usage:

        m = torchvision.models.resnet18()
        traced = torch.fx.symbolic_trace(m)
        traced = NormalizeArgs(traced).transform()
    """

    def __init__(
        self, module: torch.fx.GraphModule, normalize_to_only_use_kwargs: bool = True
    ):
        super().__init__(module)
        self.node_map: Dict[Proxy, Node] = {}
        self.normalize_to_only_use_kwargs = normalize_to_only_use_kwargs

    def run_node(self, n: Node) -> Any:
        args, kwargs = self.fetch_args_kwargs_from_env(n)

        def get_type(arg):
            if isinstance(arg, fx.Node):
                return n.meta["type"] if "type" in n.meta else None
            return type(arg)

        # Collect concrete types for the args/kwargs so that overloaded
        # targets can be disambiguated during normalization.
        arg_types = map_aggregate(n.args, get_type)
        assert isinstance(arg_types, tuple)
        arg_types = tuple(create_type_hint(i) for i in arg_types)
        kwarg_types = {k: get_type(v) for k, v in kwargs.items()}
        if n.op == "call_function":
            out = self.call_function(n.target, args, kwargs, arg_types, kwarg_types)
        else:
            out = super().run_node(n)
        if n.op != "output":
            # Carry the original node's metadata over to the rewritten node.
            self.node_map[out] = n
            out.node.meta = n.meta
            out.node.type = n.type
        return out

    def call_function(
        self,
        target: Target,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Any],
        arg_types: Optional[Tuple[Any, ...]] = None,
        kwarg_types: Optional[Dict[str, Any]] = None,
    ):
        assert callable(target)
        new_args_and_kwargs = normalize_function(
            target,
            args,  # type: ignore[arg-type]
            kwargs,
            arg_types,  # type: ignore[arg-type]
            kwarg_types,
            self.normalize_to_only_use_kwargs,
        )
        if new_args_and_kwargs:
            new_args, new_kwargs = new_args_and_kwargs
            return self.tracer.create_proxy(
                "call_function", target, new_args, new_kwargs
            )
        else:
            return super().call_function(target, args, kwargs)

    def call_module(
        self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]
    ):
        assert isinstance(target, str)
        new_args_and_kwargs = normalize_module(
            self.module,
            target,
            args,  # type: ignore[arg-type]
            kwargs,
            self.normalize_to_only_use_kwargs,
        )
        if new_args_and_kwargs:
            new_args, new_kwargs = new_args_and_kwargs
            return super().call_module(target, new_args, new_kwargs)
        else:
            return super().call_module(target, args, kwargs)
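

# A minimal sketch of what NormalizeArgs does to a traced graph. Illustrative
# only; `f` below is a hypothetical function, not part of this module:
#
#     import torch.nn.functional as F
#
#     def f(x):
#         return F.relu(x, False)
#
#     traced = torch.fx.symbolic_trace(f)
#     normalized = NormalizeArgs(traced).transform()
#     # Before: relu(x, False)
#     # After:  relu(input=x, inplace=False), i.e. exclusively kwargs matched
#     # up to relu's signature.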


class NormalizeOperators(AnnotateTypesWithSchema):
    """
    Normalize callsites that are different ways of "spelling" the same
    invocation into a single, canonical call. Currently supports:

    1. Normalizing binary `torch` ops (e.g. torch.add) to the magic-method
       operators they correspond to (e.g. operator.add) when it is possible
       to statically reason that the two calls are equivalent.

    Example usage:

        m = torchvision.models.resnet18()
        traced = torch.fx.symbolic_trace(m)
        traced = NormalizeOperators(traced).transform()
    """

    binary_magic_method_remap: Dict[
        Callable[[Any, Any], Any], Callable[[Any, Any], Any]
    ] = {
        torch.add: operator.add,
        torch.mul: operator.mul,
        torch.sub: operator.sub,
        torch.div: operator.truediv,
        torch.floor_divide: operator.floordiv,
        torch.remainder: operator.mod,
        torch.eq: operator.eq,
        torch.ne: operator.ne,
        torch.lt: operator.lt,
        torch.le: operator.le,
        torch.gt: operator.gt,
        torch.ge: operator.ge,
    }
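
    # Illustrative effect of the remap above (a sketch, not an exhaustive
    # description): a node `call_function[target=torch.add](x, y)` is
    # rewritten to `call_function[target=operator.add](x, y)`, which FX's
    # code generation prints as `x + y` and which dispatches through
    # Tensor.__add__ at runtime.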

    def call_function(
        self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]
    ):
        # Normalize operators according to the magic methods implemented on
        # tensors here:
        # https://github.com/pytorch/pytorch/blob/28c5d90b679c6b38bf4183ec99f16d933c2f1bcd/tools/autograd/templates/python_variable_methods.cpp#L1137 # noqa: B950

        assert callable(target)
        if target in self.binary_magic_method_remap:
            # Only remap pure binary calls; extra args or kwargs (e.g. an
            # `alpha=` argument to torch.add) have no magic-method equivalent.
            if len(args) != 2 or kwargs:
                return super().call_function(target, args, kwargs)
            lhs, rhs = args

            return super().call_function(
                target=self.binary_magic_method_remap[target],
                args=(lhs, rhs),
                kwargs={},
            )

        return super().call_function(target, args, kwargs)
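

# Usage sketch mirroring the docstring examples above; assumes torchvision is
# installed (this module does not itself require it), and must be run as a
# module (python -m ...) since the file uses a relative import. Guarded so
# that importing this module stays side-effect free.
if __name__ == "__main__":
    import torchvision

    m = torchvision.models.resnet18()

    # Match every call up to its signature and rewrite it to kwargs only.
    traced = torch.fx.symbolic_trace(m)
    traced = NormalizeArgs(traced).transform()
    print(traced.graph)

    # Separately, canonicalize binary torch ops to their operator.* spellings.
    traced = torch.fx.symbolic_trace(m)
    traced = NormalizeOperators(traced).transform()
    print(traced.graph)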