# unsupported_tensor_ops.py
  1. import torch.jit
  2. from textwrap import dedent
  3. from typing import Dict, Any
  4. def execWrapper(code, glob, loc):
  5. exec(code, glob, loc)
  6. def _gen_unsupported_methods_properties():
  7. tensor_attrs = set(filter(lambda x: x[0] != "_", dir(torch.Tensor)))
  8. tensor = torch.tensor([2])
  9. funcs_template = dedent('''
  10. def func(x):
  11. return x.{op}()
  12. ''')
  13. deprecated_apis = {"volatile", "resize", "reinforce", "new", "name", "map2_", "has_names", "grad_fn", "resize_as"}
  14. tensor_attrs = tensor_attrs - deprecated_apis
  15. properties = []
  16. methods = []
  17. sorted_tensor_attrs = sorted(tensor_attrs, key=lambda x: x.lower())
  18. for attr in sorted_tensor_attrs:
  19. funcs_str = funcs_template.format(op=attr)
  20. scope: Dict[str, Any] = {}
  21. execWrapper(funcs_str, globals(), scope)
  22. try:
  23. cu = torch.jit.CompilationUnit(funcs_str)
  24. except Exception as e:
  25. if "nonexistent attribute" not in repr(e):
  26. continue
  27. attr_repr = repr(getattr(tensor, attr))
  28. if "bound method" in attr_repr or "built-in method" in attr_repr:
  29. methods.append(attr)
  30. else:
  31. properties.append(attr)
  32. mapped_methods = ("\t* :meth:`~torch.Tensor." + x + r"`" for x in methods)
  33. mapped_properties = ("\t* :attr:`~torch.Tensor." + x + r"`" for x in properties)
  34. return "\n".join(mapped_methods), "\n".join(mapped_properties)
  35. def _list_unsupported_tensor_ops():
  36. header = """\n\n
  37. Unsupported Tensor Methods
  38. ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  39. """
  40. methods, properties = _gen_unsupported_methods_properties()
  41. return header + "\n" + methods + """
  42. Unsupported Tensor Properties
  43. ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  44. """ + "\n" + properties
# Generate this module's docstring at import time. Building it compiles one
# TorchScript snippet per public torch.Tensor attribute, so importing this
# module does real work.
__doc__ = _list_unsupported_tensor_ops()