# Functions for synthesizing magic methods for JIT-compiled dataclasses
import ast
import dataclasses
import inspect
import os
from functools import partial
from typing import Callable, Dict, List

from torch._jit_internal import FAKE_FILENAME_PREFIX, is_optional
from torch._sources import ParsedDef, SourceContext


def _get_fake_filename(cls, method_name):
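    """Return a fake filename (under FAKE_FILENAME_PREFIX) identifying the
    synthesized ``method_name`` of ``cls``, so error messages have something
    recognizable to point at.
    """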
    return os.path.join(FAKE_FILENAME_PREFIX, cls.__name__, method_name)


def compose_fn(cls, name: str, body_lines: List[str], signature: str) -> ParsedDef:
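    """Assemble source for a synthesized method and parse it into a ParsedDef.

    A minimal sketch of the source this produces, assuming a hypothetical class
    ``Point`` and the inputs ``name='__eq__'``, ``body_lines=['return True']``,
    and ``signature='(self, other: Point) -> bool'``::

        def __eq__(self, other: Point) -> bool:
          return True
    """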
    body = '\n'.join(f'  {b}' for b in body_lines)
    decl = f'def {name}{signature}:\n{body}'

    # Parse the function declaration
    try:
        py_ast = ast.parse(decl)
    except SyntaxError as e:
        # This should only happen if there's some unforeseeable change
        # in the dataclasses module that makes our synthesized code fail
        raise RuntimeError(
            f"TorchScript failed to synthesize dataclass method '{name}' for class '{cls.__name__}'. "
            "Please file a bug report at <https://github.com/pytorch/pytorch/issues>"
        ) from e

    fake_filename = _get_fake_filename(cls, name)
    # Bundle the parsed AST with its source context so the JIT frontend can
    # compile it like any other method
    return ParsedDef(
        py_ast,
        ctx=SourceContext(
            source=decl,
            filename=fake_filename,
            file_lineno=0,
            leading_whitespace_len=0
        ),
        source=decl,
        filename=fake_filename,
        file_lineno=0
    )


def synthesize__init__(cls) -> ParsedDef:
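    """Synthesize a TorchScript-compilable ``__init__`` from the dataclass fields.

    For a hypothetical ``@dataclass``-decorated ``Point`` with fields ``x: int``
    and ``y: int``, the synthesized source would be::

        def __init__(self, x: int, y: int) -> None:
          self.x = x
          self.y = y
    """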
    # Supporting default factories in the way that people expect would require
    # compiling lambda functions, which is not currently supported.
    if any(field.default_factory is not dataclasses.MISSING for field in dataclasses.fields(cls)):
        raise NotImplementedError("Default factory initializers are not supported in TorchScript dataclasses")

    # Simply read off the generated __init__ signature from CPython's implementation. It'll be
    # almost correct except for InitVar annotations, which we need to handle specially.
    signature = inspect.signature(cls.__init__)

    # Handle InitVars if needed (only works on Python 3.8+, when a `type` attribute was added to InitVar);
    # see this CPython commit: https://github.com/python/cpython/commit/01ee12ba35a333e8a6a25c4153c4a21838e9585c
    init_vars: List[str] = []
    params = []
    for name, param in signature.parameters.items():
        ann = param.annotation
        if isinstance(ann, dataclasses.InitVar):
            # The TorchScript interpreter can't handle InitVar annotations, so we unwrap the underlying type here
            init_vars.append(name)
            params.append(param.replace(annotation=ann.type))  # type: ignore[attr-defined]
        else:
            params.append(param)
    signature = signature.replace(parameters=params)

    body = [
        # Assign all attributes to self
        f'self.{field.name} = {field.name}'
        for field in dataclasses.fields(cls)
        if field.init and field.name not in init_vars
    ]
    # Call the user's impl of __post_init__ if it exists
    if hasattr(cls, '__post_init__'):
        body.append('self.__post_init__(' + ', '.join(init_vars) + ')')
    return compose_fn(cls, '__init__', body or ['pass'], signature=str(signature))


# This is a placeholder at the moment since the TorchScript interpreter doesn't call __repr__
def synthesize__repr__(cls) -> ParsedDef:
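    """Synthesize a placeholder ``__repr__``.

    Note that the field references land inside the string literal, so for a
    hypothetical ``Point`` dataclass this returns the fixed string
    ``'Point(x=self.x, y=self.y)'`` rather than interpolating field values;
    that is acceptable for now since the interpreter never calls it.
    """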
    return compose_fn(
        cls, '__repr__',
        [f"return '{cls.__name__}(" + ", ".join([
            f"{field.name}=self.{field.name}"
            for field in dataclasses.fields(cls) if field.repr
        ]) + ")'"],
        signature='(self) -> str'
    )


def synthesize__hash__(cls) -> ParsedDef:
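    """Synthesize a placeholder ``__hash__`` that unconditionally raises."""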
    return compose_fn(
        cls, '__hash__',
        [
            # This is just a placeholder to prevent compilation from failing; it
            # won't actually be called, because the TorchScript interpreter
            # doesn't invoke custom __hash__ implementations
            "raise NotImplementedError('__hash__ is not supported for dataclasses in TorchScript')"
        ],
        signature='(self) -> int'
    )


# Implementation for __eq__ and __ne__
def synthesize_equality(cls, name: str, converse: str) -> ParsedDef:
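    """Shared synthesizer for ``__eq__`` and ``__ne__``.

    Each compared field pair is checked with the converse operator (e.g. ``!=``
    for ``__eq__``): the method short-circuits to False on the first hit and
    returns True once every field has passed.
    """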
    return synthesize_comparison(cls, name, allow_eq=True, raise_on_none=False, inner=[
        f"if val1 {converse} val2: return False"
    ])


def synthesize_inequality(cls, name: str, op: str, allow_eq: bool) -> ParsedDef:
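    """Shared synthesizer for the ordering methods (``__lt__``/``__le__``/``__gt__``/``__ge__``).

    Fields are compared lexicographically: the first strictly ordered pair
    decides the result, ties fall through to the next field, and ``allow_eq``
    decides the outcome when all compared fields are equal.
    """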
    return synthesize_comparison(cls, name, allow_eq, raise_on_none=True, inner=[
        f"if val1 {op} val2: return True",
        f"elif val2 {op} val1: return False",
    ])


def synthesize_comparison(cls, name: str, allow_eq: bool, raise_on_none: bool, inner: List[str]) -> ParsedDef:
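    """Common scaffolding for the comparison synthesizers above.

    Loads each comparable field pair into ``val1``/``val2`` and splices in the
    caller-provided ``inner`` lines. A sketch of the body synthesized for
    ``__lt__`` on a hypothetical ``Point`` dataclass with int fields ``x``
    and ``y``::

        val1 = self.x
        val2 = other.x
        if val1 < val2: return True
        elif val2 < val1: return False
        val1 = self.y
        val2 = other.y
        if val1 < val2: return True
        elif val2 < val1: return False
        return False
    """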
    body = []
    for field in dataclasses.fields(cls):
        if not field.compare:
            continue

        body.extend([
            f"val1 = self.{field.name}",
            f"val2 = other.{field.name}",
        ])
        body.extend(
            inner if not is_optional(field.type) else [
                # Type refinement for optional fields; we need this to avoid type errors from the interpreter
                "if val1 is not None and val2 is not None:",
                *['  ' + line for line in inner],
                "elif (val1 is None) != (val2 is None):",
                f"  raise TypeError('Cannot compare {cls.__name__} with None')" if raise_on_none else "  return False"
            ]
        )

    body.append(f"return {allow_eq}")
    return compose_fn(cls, name, body, signature=f'(self, other: {cls.__name__}) -> bool')


DATACLASS_MAGIC_METHODS: Dict[str, Callable] = {
    "__init__": synthesize__init__,
    "__repr__": synthesize__repr__,
    "__hash__": synthesize__hash__,
    "__eq__": partial(synthesize_equality, name="__eq__", converse="!="),
    "__ne__": partial(synthesize_equality, name="__ne__", converse="=="),
    "__lt__": partial(synthesize_inequality, name="__lt__", op="<", allow_eq=False),
    "__le__": partial(synthesize_inequality, name="__le__", op="<", allow_eq=True),
    "__gt__": partial(synthesize_inequality, name="__gt__", op=">", allow_eq=False),
    "__ge__": partial(synthesize_inequality, name="__ge__", op=">", allow_eq=True),
}
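
# A minimal sketch of how this table is consumed (illustrative only: `Point` is
# a hypothetical dataclass, and the real caller is the TorchScript frontend,
# which looks up the synthesizer for each magic method it needs to compile):
#
#   @dataclasses.dataclass
#   class Point:
#       x: int
#       y: int
#
#   parsed = DATACLASS_MAGIC_METHODS["__eq__"](Point)
#   print(parsed.source)
#   # def __eq__(self, other: Point) -> bool:
#   #   val1 = self.x
#   #   val2 = other.x
#   #   if val1 != val2: return False
#   #   val1 = self.y
#   #   val2 = other.y
#   #   if val1 != val2: return False
#   #   return True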