# _dataclass_impls.py

# Functions for synthesizing magic methods for JIT-compiled dataclasses
import os
from functools import partial
from torch._jit_internal import is_optional, FAKE_FILENAME_PREFIX
from torch._sources import ParsedDef, SourceContext
from typing import Callable, Dict, List
import ast
import dataclasses
import inspect


def _get_fake_filename(cls, method_name):
    return os.path.join(FAKE_FILENAME_PREFIX, cls.__name__, method_name)

def compose_fn(cls, name: str, body_lines: List[str], signature: str) -> ParsedDef:
    body = '\n'.join(f'  {b}' for b in body_lines)
    decl = f'def {name}{signature}:\n{body}'

    # Parse the function declaration
    try:
        py_ast = ast.parse(decl)
    except SyntaxError as e:
        # This should only happen if there's some unforeseeable change
        # in the dataclasses module that makes our synthesized code fail
        raise RuntimeError(
            f"TorchScript failed to synthesize dataclass method '{name}' for class '{cls.__name__}'. "
            "Please file a bug report at <https://github.com/pytorch/pytorch/issues>"
        ) from e

    fake_filename = _get_fake_filename(cls, name)

    # Wrap the parsed AST and its source context into a ParsedDef for the TorchScript frontend
    return ParsedDef(
        py_ast,
        ctx=SourceContext(
            source=decl,
            filename=fake_filename,
            file_lineno=0,
            leading_whitespace_len=0
        ),
        source=decl,
        filename=fake_filename,
        file_lineno=0
    )

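# Illustration (hypothetical, not part of the original file): compose_fn just glues a
# signature and a list of body lines into a standalone declaration and parses it.
# For example, compose_fn(SomeCls, '__hash__', ["raise NotImplementedError('...')"],
# '(self) -> int') parses source shaped like
#
#   def __hash__(self) -> int:
#     raise NotImplementedError('...')
#
# and returns a ParsedDef whose SourceContext points at a fake filename of the form
# <FAKE_FILENAME_PREFIX>/SomeCls/__hash__, so TorchScript diagnostics can reference
# the synthesized source.
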
def synthesize__init__(cls) -> ParsedDef:
    # Supporting default factories the way people expect would effectively require us to
    # allow compiling lambda functions, which is not currently supported.
    if any(field.default_factory is not dataclasses.MISSING for field in dataclasses.fields(cls)):
        raise NotImplementedError("Default factory initializers are not supported in TorchScript dataclasses")

    # Simply read off the generated __init__ signature from CPython's implementation. It'll be
    # almost correct except for InitVar annotations, which we need to handle specially.
    signature = inspect.signature(cls.__init__)

    # Handle InitVars if needed (only works on Python 3.8+, when a `type` attribute was added to InitVar);
    # see this CPython commit: https://github.com/python/cpython/commit/01ee12ba35a333e8a6a25c4153c4a21838e9585c
    init_vars: List[str] = []
    params = []
    for name, param in signature.parameters.items():
        ann = param.annotation

        if isinstance(ann, dataclasses.InitVar):
            # The TorchScript interpreter can't handle InitVar annotations, so we unwrap the underlying type here
            init_vars.append(name)
            params.append(param.replace(annotation=ann.type))  # type: ignore[attr-defined]
        else:
            params.append(param)

    signature = signature.replace(parameters=params)

    body = [
        # Assign all attributes to self
        f'self.{field.name} = {field.name}'
        for field in dataclasses.fields(cls)
        if field.init and field.name not in init_vars
    ]
    # Call the user's impl of __post_init__ if it exists, forwarding any InitVars to it
    if hasattr(cls, '__post_init__'):
        body.append('self.__post_init__(' + ', '.join(init_vars) + ')')

    return compose_fn(cls, '__init__', body or ['pass'], signature=str(signature))

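# Illustration (hypothetical, not part of the original file): for
#
#   @dataclass
#   class Point:
#     x: float
#     y: float = 0.0
#
# synthesize__init__(Point) composes source equivalent to
#
#   def __init__(self, x: float, y: float = 0.0) -> None:
#     self.x = x
#     self.y = y
#
# and any InitVar parameters would keep their position in the signature (with the
# InitVar wrapper unwrapped) and be forwarded to __post_init__ rather than assigned.
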
# This is a placeholder at the moment since the TorchScript interpreter doesn't call __repr__
def synthesize__repr__(cls) -> ParsedDef:
    return compose_fn(
        cls, '__repr__',
        [f"return '{cls.__name__}(" + ", ".join([
            f"{field.name}=self.{field.name}"
            for field in dataclasses.fields(cls) if field.repr
        ]) + ")'"],
        signature='(self) -> str'
    )

def synthesize__hash__(cls) -> ParsedDef:
    return compose_fn(
        cls, '__hash__',
        [
            # This is just a placeholder to prevent compilation from failing; it won't actually be
            # called right now because the TorchScript interpreter doesn't invoke custom __hash__
            # implementations
            "raise NotImplementedError('__hash__ is not supported for dataclasses in TorchScript')"
        ],
        signature='(self) -> int'
    )

# Implementation for __eq__ and __ne__
def synthesize_equality(cls, name: str, converse: str) -> ParsedDef:
    return synthesize_comparison(cls, name, allow_eq=True, raise_on_none=False, inner=[
        f"if val1 {converse} val2: return False"
    ])

def synthesize_inequality(cls, name: str, op: str, allow_eq: bool) -> ParsedDef:
    return synthesize_comparison(cls, name, allow_eq, raise_on_none=True, inner=[
        f"if val1 {op} val2: return True",
        f"elif val2 {op} val1: return False",
    ])

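# Illustration (not part of the original file): the `inner` lines are the per-field
# comparison template. For __eq__ (converse '!=') the template is
#
#   if val1 != val2: return False
#
# so any differing field short-circuits to False, while for __lt__ (op '<') it is
#
#   if val1 < val2: return True
#   elif val2 < val1: return False
#
# which falls through to the next field when the values tie, giving the field-by-field
# lexicographic ordering that CPython's generated ordering methods also use.
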
def synthesize_comparison(cls, name: str, allow_eq: bool, raise_on_none: bool, inner: List[str]) -> ParsedDef:
    body = []
    for field in dataclasses.fields(cls):
        if not field.compare:
            continue

        body.extend([
            f"val1 = self.{field.name}",
            f"val2 = other.{field.name}",
        ])
        body.extend(
            inner if not is_optional(field.type) else [
                # Type refinement for optional fields; we need this to avoid type errors from the interpreter
                "if val1 is not None and val2 is not None:",
                *['  ' + line for line in inner],
                "elif (val1 is None) != (val2 is None):",
                f"  raise TypeError('Cannot compare {cls.__name__} with None')" if raise_on_none else "  return False"
            ]
        )

    # Objects whose compared fields never decided the result fall through to here;
    # allow_eq determines what that means for this operator
    body.append(f"return {allow_eq}")
    return compose_fn(cls, name, body, signature=f'(self, other: {cls.__name__}) -> bool')

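# Illustration (hypothetical, not part of the original file): for
#
#   @dataclass
#   class Box:
#     size: int
#     tag: Optional[str] = None
#
# synthesize_inequality(Box, '__lt__', op='<', allow_eq=False) composes source like
#
#   def __lt__(self, other: Box) -> bool:
#     val1 = self.size
#     val2 = other.size
#     if val1 < val2: return True
#     elif val2 < val1: return False
#     val1 = self.tag
#     val2 = other.tag
#     if val1 is not None and val2 is not None:
#       if val1 < val2: return True
#       elif val2 < val1: return False
#     elif (val1 is None) != (val2 is None):
#       raise TypeError('Cannot compare Box with None')
#     return False
#
# Note how the Optional field gets the None-refinement wrapper while the plain int
# field compares directly.
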
DATACLASS_MAGIC_METHODS: Dict[str, Callable] = {
    "__init__": synthesize__init__,
    "__repr__": synthesize__repr__,
    "__hash__": synthesize__hash__,
    "__eq__": partial(synthesize_equality, name="__eq__", converse="!="),
    "__ne__": partial(synthesize_equality, name="__ne__", converse="=="),
    "__lt__": partial(synthesize_inequality, name="__lt__", op="<", allow_eq=False),
    "__le__": partial(synthesize_inequality, name="__le__", op="<", allow_eq=True),
    "__gt__": partial(synthesize_inequality, name="__gt__", op=">", allow_eq=False),
    "__ge__": partial(synthesize_inequality, name="__ge__", op=">", allow_eq=True),
}
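
# Usage sketch (hypothetical; the real call sites live in the TorchScript class
# compilation path). Each entry maps a magic method name to a synthesizer that,
# given a dataclass, returns a ParsedDef the frontend can compile in place of the
# CPython-generated method:
#
#   @dataclasses.dataclass
#   class Point:
#     x: float
#     y: float
#
#   parsed = DATACLASS_MAGIC_METHODS['__eq__'](Point)
#   print(parsed.source)  # the synthesized `def __eq__(self, other: Point) -> bool: ...`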