supported_ops.py

import torch.jit
from torch.jit._builtins import _find_builtin
import inspect
import textwrap

# this file is for generating documentation using sphinx autodoc
# > help(torch.jit.supported_ops) will also give a nice listing of the
# supported ops programmatically
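# The listing itself is built by _list_supported_ops() and assigned to this
# module's __doc__ at the bottom of the file, which is the text that autodoc
# (and help()) pick up.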
def _hidden(name):
    return name.startswith('_') and not name.startswith('__')


def _emit_type(type):
    return str(type)


def _emit_arg(indent, i, arg):
    v = "{} : {}".format(arg.name, _emit_type(arg.type))
    default = arg.default_value
    if default is not None:
        v = "{}={}".format(v, str(default))
    if i > 0:
        v = "\n{}{}".format(" " * indent, v)
    return v


def _emit_args(indent, arguments):
    return ",".join(_emit_arg(indent, i, arg) for i, arg in enumerate(arguments))


def _emit_ret(ret):
    return _emit_type(ret.type)


def _emit_rets(returns):
    if len(returns) == 1:
        return _emit_ret(returns[0])
    return "Tuple[{}]".format(", ".join(_emit_ret(r) for r in returns))


def _emit_schema(mod, name, schema, arg_start=0, padding=4):
    if mod is None:
        qualified_name = name
    else:
        qualified_name = "{}.{}".format(mod, name)
    schema_str = "{}({}) -> {}".format(qualified_name,
                                       _emit_args(len(qualified_name) + 1 + padding, schema.arguments[arg_start:]),
                                       _emit_rets(schema.returns))
    return schema_str
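# For illustration only (a rough sketch, not executed here): given the
# TorchScript schema of torch.nn.functional.relu, a call like
# _emit_schema('torch.nn.functional', 'relu', schema) would produce something
# roughly like
#
#     torch.nn.functional.relu(input : Tensor,
#                              inplace : bool=False) -> Tensor
#
# where every argument after the first is wrapped onto its own line and
# indented by len(qualified_name) + 1 + padding spaces.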
def _get_tensor_ops():
    def is_tensor_method(schema):
        if len(schema.arguments) == 0:
            return False
        self = schema.arguments[0]
        if self.name != 'self':
            return False
        if not self.type.isSubtypeOf(torch._C.TensorType.get()):
            return False
        return True

    methods = []
    # discover methods
    for elem in dir(torch.Tensor):
        if not _hidden(elem):
            schemas = torch._C._jit_get_schemas_for_operator("aten::" + elem)
            for schema in schemas:
                if is_tensor_method(schema):
                    methods.append(_emit_schema('Tensor', elem, schema, arg_start=1))

    return "Supported Tensor Methods", methods
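# Since arg_start=1 skips the implicit `self` argument, a Tensor method entry
# produced above reads roughly like "Tensor.matmul(other : Tensor) -> Tensor"
# (illustrative only; the exact set of entries depends on the installed build).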
def _get_nn_functional_ops():
    functions = []

    # Iterate over torch.nn.functional
    mod = torch.nn.functional
    name = mod.__name__
    for elem in dir(torch.nn.functional):
        attr = getattr(mod, elem)
        if not inspect.isfunction(attr) or _hidden(elem[0]):
            # Ignore non-functions and internal methods
            continue

        attr_module = inspect.getmodule(attr)
        if not attr_module:
            raise RuntimeError(f'Module for {attr} not found')

        if 'torch.nn.functional' not in attr_module.__name__:
            # Ignore functions from outside torch.nn.functional
            continue

        try:
            # compile fn, get schema
            scripted = torch.jit.script(attr)
            schema = scripted.schema
            functions.append(_emit_schema(name, elem, schema))
        except:  # noqa: B001,E722
            # Skip interpolate / boolean dispatched things
            pass

    # Iterate over modules that we know contain a lot of builtins
    for mod in torch.jit._builtins._modules_containing_builtins:
        name = mod.__name__
        for elem in dir(mod):
            builtin = _find_builtin(getattr(mod, elem))
            if builtin is not None:
                schemas = torch._C._jit_get_schemas_for_operator(builtin)
                for schema in schemas:
                    # remove _tan but not __and__
                    if not _hidden(elem):
                        functions.append(_emit_schema(name, elem, schema))

    return "Supported PyTorch Functions", functions


def _get_builtins_helper():
    builtins = []
    for fn, _builtin_name in torch.jit._builtins._builtin_ops:
        mod = inspect.getmodule(fn)

        if not hasattr(fn, '__name__'):
            # typing classes
            continue
        if not mod:
            continue
        if _hidden(fn.__name__) or _hidden(fn.__qualname__) or _hidden(mod.__name__):
            # skip internal-only methods
            continue
        if 'torch._C' in mod.__name__:
            continue

        builtins.append((fn, _builtin_name))

    return builtins


def _is_math_fn(fn):
    mod = inspect.getmodule(fn)
    if not mod:
        raise RuntimeError(f'Module for {fn} not found')
    return mod.__name__ == 'math'


def _get_torchscript_builtins():
    functions = []
    builtins = filter(lambda fn: not _is_math_fn(fn[0]), _get_builtins_helper())
    builtins_list = list(builtins)
    # Iterate over the specially added builtins
    for fn, _builtin_name in builtins_list:
        mod = inspect.getmodule(fn)
        if not mod:
            raise RuntimeError(f'Module for {fn} not found')
        builtin = _find_builtin(fn)
        if builtin is not None:
            schemas = torch._C._jit_get_schemas_for_operator(builtin)
            for schema in schemas:
                functions.append(_emit_schema(mod.__name__, fn.__name__, schema))

    return "TorchScript Builtin Functions", functions
def _get_math_builtins():
    functions = []
    builtins = filter(lambda fn: _is_math_fn(fn[0]), _get_builtins_helper())
    builtins_list = list(builtins)
    # Iterate over the specially added builtins
    for fn, _builtin_name in builtins_list:
        mod = inspect.getmodule(fn)
        if not mod:
            raise RuntimeError(f'Module for {fn} not found')
        builtin = _find_builtin(fn)
        if builtin is not None:
            schemas = torch._C._jit_get_schemas_for_operator(builtin)
            for schema in schemas:
                schema_str = _emit_schema(mod.__name__, fn.__name__, schema)
                if 'Tensor' in schema_str:
                    # Skip Tensor ops that have the same name as math functions
                    # (they will show up in the tensor methods section)
                    continue
                # Append the formatted string (not the raw schema object) so the
                # entries match the style of the other sections
                functions.append(schema_str)

    return "``math`` Module", functions
def _get_global_builtins():
    # Taken from the 'globals' map in torch/csrc/jit/frontend/ir_emitter.cpp
    supported_builtins = [
        'print',
        'tuple',
        'float',
        'complex',
        'int',
        'bool',
        'str',
        'getattr',
        'hasattr',
        'isinstance',
        'len',
        'hex',
        'oct',
        'round',
        'hash',
        'min',
        'max',
        'abs',
        'all',
        'divmod',
        'list',
        'ord',
        'chr',
        'bin',
        'range',
        'zip',
        'enumerate',
        'sorted',
    ]

    op_renames = {
        'bool': 'aten::Bool',
        'int': 'aten::Int',
        'float': 'aten::Float',
        'complex': 'aten::Complex',
        'abs': 'prim::abs',
        'max': 'prim::max',
        'min': 'prim::min',
        'range': 'fake::does_not_exist',
    }

    schemaless_op_explanations = {
        'print': 'Print any value',
        'tuple': 'Lists cannot be converted to tuples with this method since their size is not statically known',
        'getattr': 'Attribute name must be a literal string',
        'hasattr': 'Attribute name must be a literal string',
        'isinstance': 'Result is static',
        'zip': 'Arguments must be iterable. See :ref:`Iterables <jit_iterables>` for details.',
        'enumerate': 'Arguments must be iterable. See :ref:`Iterables <jit_iterables>` for details.',
        'range': 'Can only be used as an iterator in a for loop',
    }

    magic_methods = [
        ('complex', '__complex__'),
        ('float', '__float__'),
        ('int', '__int__'),
        ('bool', '__bool__'),
        ('str', '__str__'),
        ('len', '__len__'),
        ('hex', '__hex__'),
        ('oct', '__oct__'),
    ]

    magic_methods_rows = []
    for fn, magic_method in magic_methods:
        magic_methods_rows.append('"{}", "``{}``"'.format(fn, magic_method))

    schematized_ops = []
    schemaless_ops = []

    for fn in supported_builtins:
        op_name = 'aten::{}'.format(fn)
        if fn in op_renames:
            op_name = op_renames[fn]
        schemas = torch._C._jit_get_schemas_for_operator(op_name)
        for s in schemas:
            schematized_ops.append(_emit_schema(None, fn, s, padding=0))
        if len(schemas) > 0:
            schematized_ops.append('')
        else:
            table_row = '":any:`{}`", "{}"'.format(fn, schemaless_op_explanations[fn])
            schemaless_ops.append(table_row)

    schematized_ops_str = '\n'.join(schematized_ops)
    schemaless_ops_str = '\n'.join(schemaless_ops)
    magic_methods_rows_str = '\n'.join(magic_methods_rows)
    schematized_ops_str = textwrap.indent(schematized_ops_str, '\t')
    schemaless_ops_str = textwrap.indent(schemaless_ops_str, '\t')
    magic_methods_rows_str = textwrap.indent(magic_methods_rows_str, '\t')
    section = """
The functions in the following table are supported but do not have a static schema

.. csv-table::
    :header: "Function", "Note"

{}

The following functions will use the corresponding magic method on :any:`TorchScript classes`

.. csv-table::
    :header: "Function", "Magic Method"

{}

These built-in functions use the schema

.. rst-class:: codeblock-height-limiter

::

{}
""".format(schemaless_ops_str, magic_methods_rows_str, schematized_ops_str)

    return "Python Built-in Functions", section


def _list_supported_ops():
    def emit_block(decls):
        return '\n.. rst-class:: codeblock-height-limiter\n\n::\n\n{}\n'.format(''.join(' {}\n\n'.format(d) for d in decls))

    body = ''
    op_gathering_fns = (
        _get_tensor_ops,
        _get_nn_functional_ops,
        _get_torchscript_builtins,
        _get_global_builtins,
        _get_math_builtins,
    )
    for fn in op_gathering_fns:
        header, items = fn()
        link_target = header.replace('`', '').replace('-', '').lower().replace(' ', '-')
        if isinstance(items, str):
            section = "{}\n{}\n{}\n".format(header, '~' * len(header), items)
        else:
            section = "{}\n{}\n{}".format(header, '~' * len(header), emit_block(items))
        section = '.. _{}:'.format(link_target) + '\n\n' + section
        body += section

    return body


__doc__ = _list_supported_ops()
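# Usage sketch (assumes a working torch install): the generated listing can be
# inspected interactively via help(torch.jit.supported_ops), or printed by
# running this file directly.
if __name__ == '__main__':
    print(__doc__)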