# torch/jit/_decompositions.py
import torch
from torch import Tensor
aten = torch.ops.aten
from typing import Optional, List, Dict, Set
import inspect
from torch.fx.operator_schemas import get_signature_for_torch_op
import warnings

# Maps the string form of an aten operator schema (one entry per matched
# overload) to the TorchScript function that decomposes it into simpler ops.
decomposition_table: Dict[str, torch.jit.ScriptFunction] = {}
# Names of all registered decomposition functions; registration asserts
# uniqueness because jit serialization needs unique function names.
function_name_set: Set[str] = set()
  10. def check_decomposition_has_type_annotations(f):
  11. inspect_empty = inspect._empty # type: ignore[attr-defined]
  12. sig = inspect.signature(f)
  13. for param in sig.parameters.values():
  14. assert param.annotation != inspect_empty, \
  15. "No signature on param {name} for function {func}".format(name=param.name, func=f.name)
  16. assert sig.return_annotation != inspect_empty, "No return annotation for function {func}".format(func=f.name)
  17. def signatures_match(decomposition_sig, torch_op_sig):
  18. decomp_params = decomposition_sig.parameters
  19. op_params = torch_op_sig.parameters
  20. if len(decomp_params) != len(op_params):
  21. return False
  22. for decomp_param, op_param in zip(decomp_params.values(), op_params.values()):
  23. # can't check full equality yet because not all fields are correcly deduced
  24. # in the torch_op_sig - like default value
  25. # can't check 'kind' bc
  26. # kwarg-only values with defaults not yet supported in TS
  27. inspect_empty = inspect._empty # type: ignore[attr-defined]
  28. for field in ['name', 'annotation']:
  29. if field == 'name' and decomp_param.name == "self":
  30. warnings.warn("PyTorch uses 'input' instead of 'self' on public api")
  31. if getattr(decomp_param, field) != getattr(op_param, field):
  32. return False
  33. decomp_default = decomp_param.default
  34. op_default = op_param.default
  35. # default value not always correctly inferred as being present on torch schema,
  36. # but if specified on both they should be equal
  37. if decomp_default != inspect_empty and op_default != inspect_empty:
  38. if decomp_default != op_default:
  39. return False
  40. return decomposition_sig.return_annotation == torch_op_sig.return_annotation
  41. def register_decomposition(aten_op, registry=None):
  42. def decomposition_decorator(f):
  43. nonlocal registry
  44. if registry is None:
  45. registry = decomposition_table
  46. check_decomposition_has_type_annotations(f)
  47. torch_op_sigs, torch_op_schemas = get_signature_for_torch_op(aten_op, return_schemas=True)
  48. decomposition_sig = inspect.signature(f)
  49. found_index = None
  50. for i, torch_op_sig in enumerate(torch_op_sigs):
  51. if signatures_match(decomposition_sig, torch_op_sig):
  52. found_index = i
  53. break
  54. assert found_index is not None, "Could not find matching signature: " + str(f)
  55. # Need unique name for jit function serialization
  56. assert f.__name__ not in function_name_set, "Duplicated function name {}".format(f.__name__)
  57. function_name_set.add(f.__name__)
  58. scripted_func = torch.jit.script(f)
  59. torch._C._jit_pass_inline(scripted_func.graph)
  60. for _ in range(2):
  61. torch._C._jit_pass_peephole(scripted_func.graph)
  62. torch._C._jit_pass_constant_propagation(scripted_func.graph)
  63. registry[str(torch_op_schemas[found_index])] = scripted_func
  64. return f
  65. return decomposition_decorator
  66. # TODO: replace torch.sigmoid -> aten.sigmoid
@register_decomposition(aten.var)
def var_decomposition(input: Tensor, dim: Optional[List[int]] = None, correction: Optional[int] = None,
                      keepdim: bool = False) -> Tensor:
    """Decomposition of aten::var into primitive aten ops: mean, subtract,
    square, sum, divide.

    NOTE: this body is compiled by TorchScript via the decorator, so its
    exact statement shapes (e.g. the ``dim_i: List[int]`` refinement) matter.
    """
    if dim is None:
        # TorchScript-friendly way to default to the empty dim list.
        dim_i: List[int] = []
        dim = dim_i

    if isinstance(dim, (tuple, list)) and len(dim) == 0:
        # Empty dim list: n counts every element, matching a full reduction.
        # (Assumes aten.mean/aten.sum treat [] as reduce-over-all-dims --
        # consistent with the numel() count here; verify against aten docs.)
        n = input.numel()
    else:
        # Otherwise n is the product of the sizes of the reduced dimensions.
        n = 1
        for dim_i in dim:  # type: ignore[assignment]
            n *= input.shape[dim_i]  # type: ignore[call-overload]

    # keepdim=True on the mean so it broadcasts against ``input``.
    mean = aten.mean(input, dim, True)
    sub = input - mean
    sq = sub * sub
    # NOTE: ``sum`` shadows the builtin, but only locally.
    sum = aten.sum(sq, dim, keepdim)

    # Apply the requested degrees-of-freedom correction (e.g. 1 for the
    # unbiased estimator); None leaves the population variance divisor.
    if correction is not None:
        n = n - correction

    return sum / n
  86. @register_decomposition(aten.var)
  87. def var(input: Tensor, unbiased: bool = True) -> Tensor:
  88. return var_decomposition(input, correction=(1 if unbiased else 0))