_check.py

import ast
import inspect
import textwrap
import warnings

import torch


class AttributeTypeIsSupportedChecker(ast.NodeVisitor):
    """
    Checks the ``__init__`` method of a given ``nn.Module`` to ensure
    that all instance-level attributes can be properly initialized.

    Specifically, we do type inference based on attribute values...even
    if the attribute in question has already been typed using
    Python3-style annotations or ``torch.jit.annotate``. This means that
    setting an instance-level attribute to ``[]`` (for ``List``),
    ``{}`` (for ``Dict``), or ``None`` (for ``Optional``) isn't enough
    information for us to properly initialize that attribute.

    An object of this class can walk a given ``nn.Module``'s AST and
    determine if it meets our requirements or not.

    Known limitations

    1. We can only check the AST nodes for certain constructs; we can't
    ``eval`` arbitrary expressions. This means that function calls,
    class instantiations, and complex expressions that resolve to one of
    the "empty" values specified above will NOT be flagged as
    problematic.

    2. We match on string literals, so if the user decides to use a
    non-standard import (e.g. `from typing import List as foo`), we
    won't catch it.

    Example:
        .. code-block:: python

            class M(torch.nn.Module):
                def fn(self):
                    return []

                def __init__(self):
                    super().__init__()
                    self.x: List[int] = self.fn()

                def forward(self, x: List[int]):
                    self.x = x
                    return 1

        The above code will pass the ``AttributeTypeIsSupportedChecker``
        check since we have a function call in ``__init__``. However,
        it will still fail later with the ``RuntimeError`` "Tried to set
        nonexistent attribute: x. Did you forget to initialize it in
        __init__()?".

    Args:
        nn_module - The instance of ``torch.nn.Module`` whose
            ``__init__`` method we wish to check
    """
    def check(self, nn_module: torch.nn.Module) -> None:
        source_lines = inspect.getsource(nn_module.__class__.__init__)

        # Ignore comments no matter the indentation
        def is_useless_comment(line):
            line = line.strip()
            return line.startswith("#") and not line.startswith("# type:")

        source_lines = "\n".join(
            [l for l in source_lines.split("\n") if not is_useless_comment(l)]
        )

        # This AST only contains the `__init__` method of the nn.Module
        init_ast = ast.parse(textwrap.dedent(source_lines))

        # Get items annotated in the class body
        self.class_level_annotations = list(nn_module.__annotations__.keys())

        # Flag for later
        self.visiting_class_level_ann = False

        self.visit(init_ast)
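
    # Rough sketch (added for illustration, not part of the original source) of
    # the tree ``check`` hands to ``self.visit`` when ``__init__`` contains
    # ``self.x: List[int] = []``:
    #
    #     Module
    #     └── FunctionDef(name="__init__")
    #         ├── Expr(Call(...))        # the super().__init__() call
    #         └── AnnAssign(target=Attribute(value=Name("self"), attr="x"),
    #                       annotation=Subscript(value=Name("List"), ...),
    #                       value=List(elts=[]))
    #
    # ``ast.NodeVisitor.visit`` then dispatches each node to the matching
    # ``visit_<NodeType>`` method defined below.
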
    def _is_empty_container(self, node: ast.AST, ann_type: str) -> bool:
        if ann_type == "List":
            # Assigning `[]` to a `List` type gives you a Node where
            # value=List(elts=[], ctx=Load())
            if not isinstance(node, ast.List):
                return False
            if node.elts:
                return False
        elif ann_type == "Dict":
            # Assigning `{}` to a `Dict` type gives you a Node where
            # value=Dict(keys=[], values=[])
            if not isinstance(node, ast.Dict):
                return False
            if node.keys:
                return False
        elif ann_type == "Optional":
            # Assigning `None` to an `Optional` type gives you a
            # Node where value=Constant(value=None, kind=None)
            if not isinstance(node, ast.Constant):
                return False
            if node.value:  # type: ignore[attr-defined]
                return False

        return True
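
    # For reference (illustrative, not part of the original file), the node
    # shapes matched above, as produced by ``ast.parse``:
    #
    #     ast.parse("self.x: List[int] = []").body[0].value        # List(elts=[])
    #     ast.parse("self.x: Dict[str, int] = {}").body[0].value   # Dict(keys=[], values=[])
    #     ast.parse("self.x: Optional[int] = None").body[0].value  # Constant(value=None)
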
    def visit_Assign(self, node):
        """
        If we're visiting a Call Node (the right-hand side of an
        assignment statement), we won't be able to check the variable
        that we're assigning to (the left-hand side of an assignment).
        Because of this, we need to store this state in ``visit_Assign``.
        (Luckily, we only have to do this if we're assigning to a Call
        Node, i.e. ``torch.jit.annotate``. If we're using normal Python
        annotations, we'll be visiting an AnnAssign Node, which has its
        target built in.)
        """
        try:
            if (
                isinstance(node.value, ast.Call)
                and node.targets[0].attr in self.class_level_annotations
            ):
                self.visiting_class_level_ann = True
        except AttributeError:
            return
        self.generic_visit(node)
        self.visiting_class_level_ann = False
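
    # Illustrative node shape (added for this write-up, not original source) of
    # the statement whose target has to be remembered above:
    #
    #     self.x = torch.jit.annotate(List[int], [])
    #
    # parses to Assign(targets=[Attribute(value=Name("self"), attr="x")],
    # value=Call(...)); the annotation only becomes visible in the nested Call
    # node handled by ``visit_Call`` below, so ``visit_Assign`` records here
    # whether the target was already annotated at the class level.
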
    def visit_AnnAssign(self, node):
        """
        Visit an AnnAssign node in an ``nn.Module``'s ``__init__``
        method and see if it conforms to our attribute annotation rules.
        """
        # If we have a local variable
        try:
            if node.target.value.id != "self":
                return
        except AttributeError:
            return

        # If we have an attribute that's already been annotated at the
        # class level
        if node.target.attr in self.class_level_annotations:
            return

        # TODO @ansley: add `Union` once landed

        # NB: Even though `Tuple` is a "container", we don't want to
        # check for it here. `Tuple` functions as a type with an
        # "infinite" number of subtypes, in the sense that you can have
        # `Tuple[()]`, `Tuple[T1]`, `Tuple[T2]`, `Tuple[T1, T2]`,
        # `Tuple[T2, T1]` and so on, and none of these subtypes can be
        # used in place of the other. Therefore, assigning an empty
        # tuple in `__init__` CORRECTLY means that that variable
        # cannot be reassigned later to a non-empty tuple. Same
        # deal with `NamedTuple`
        containers = {"List", "Dict", "Optional"}

        # If we're not evaluating one of the specified problem types
        try:
            if node.annotation.value.id not in containers:
                return
        except AttributeError:
            # To evaluate a base type (`str`, `int`, etc.), we would
            # have needed to get the name through `node.annotation.id`
            # instead of `node.annotation.value.id`. Seems that we're
            # not evaluating one of our "containers"
            return

        # Check if the assigned variable is empty
        ann_type = node.annotation.value.id
        if not self._is_empty_container(node.value, ann_type):
            return

        warnings.warn(
            "The TorchScript type system doesn't support "
            "instance-level annotations on empty non-base "
            "types in `__init__`. Instead, either 1) use a "
            "type annotation in the class body, or 2) wrap "
            "the type in `torch.jit.Attribute`."
        )
    def visit_Call(self, node):
        """
        Visit a Call node in an ``nn.Module``'s ``__init__``
        method and determine if it's ``torch.jit.annotate``. If so,
        see if it conforms to our attribute annotation rules.
        """
        # If we have an attribute that's already been annotated at the
        # class level
        if self.visiting_class_level_ann:
            return

        # If this isn't a call to `torch.jit.annotate`
        try:
            if (
                node.func.value.value.id != "torch"
                or node.func.value.attr != "jit"
                or node.func.attr != "annotate"
            ):
                self.generic_visit(node)
            elif (
                node.func.value.value.id != "jit"
                or node.func.value.attr != "annotate"
            ):
                self.generic_visit(node)
        except AttributeError:
            # Looks like we didn't even have the right node structure
            # to check for `torch.jit.annotate` in the first place
            self.generic_visit(node)

        # Invariant: we have a `torch.jit.annotate` or a
        # `torch.annotate` call

        # A Call Node for `torch.jit.annotate` should have an `args`
        # list of length 2 where args[0] represents the annotation and
        # args[1] represents the actual value
        if len(node.args) != 2:
            return

        if not isinstance(node.args[0], ast.Subscript):
            return

        # See notes in `visit_AnnAssign` r.e. containers
        containers = {"List", "Dict", "Optional"}

        try:
            ann_type = node.args[0].value.id  # type: ignore[attr-defined]
        except AttributeError:
            return

        if ann_type not in containers:
            return

        # Check if the assigned variable is empty
        if not self._is_empty_container(node.args[1], ann_type):
            return

        warnings.warn(
            "The TorchScript type system doesn't support "
            "instance-level annotations on empty non-base "
            "types in `__init__`. Instead, either 1) use a "
            "type annotation in the class body, or 2) wrap "
            "the type in `torch.jit.Attribute`."
        )