123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324 |
- import torch.jit
- from torch.jit._builtins import _find_builtin
- import inspect
- import textwrap
# this file is for generating documentation using sphinx autodoc
# > help(torch.jit.supported_ops) will also give a nice list of the
# supported ops programmatically
- def _hidden(name):
- return name.startswith('_') and not name.startswith('__')
- def _emit_type(type):
- return str(type)
def _emit_arg(indent, i, arg):
    """Format one argument as ``name : type`` with an optional ``=default``.

    Args:
        indent: number of spaces placed in front of continuation arguments.
        i: position of the argument in the list being rendered.
        arg: object exposing ``name``, ``type`` and ``default_value``.

    Returns:
        The rendered argument.  When ``i > 0`` the text is preceded by a
        newline and ``indent`` spaces so it starts on its own line.
    """
    text = "{} : {}".format(arg.name, _emit_type(arg.type))
    default_value = arg.default_value
    if default_value is not None:
        text = "{}={}".format(text, str(default_value))
    if i <= 0:
        # The first argument stays on the same line as the opening paren.
        return text
    return "\n{}{}".format(" " * indent, text)
- def _emit_args(indent, arguments):
- return ",".join(_emit_arg(indent, i, arg) for i, arg in enumerate(arguments))
def _emit_ret(ret):
    """Render a single return value; only its ``type`` attribute is shown."""
    ret_type = ret.type
    return _emit_type(ret_type)
- def _emit_rets(returns):
- if len(returns) == 1:
- return _emit_ret(returns[0])
- return "Tuple[{}]".format(", ".join(_emit_ret(r) for r in returns))
def _emit_schema(mod, name, schema, arg_start=0, padding=4):
    """Render ``schema`` as ``[mod.]name(args) -> returns``.

    Args:
        mod: prefix placed before ``name`` with a dot, or ``None`` for no prefix.
        name: function name to display.
        schema: object exposing ``arguments`` and ``returns``.
        arg_start: index of the first argument to show; earlier ones are dropped.
        padding: extra spaces added to the continuation-line indent.

    Returns:
        The rendered signature string.
    """
    qualified_name = name if mod is None else "{}.{}".format(mod, name)
    # Continuation arguments are indented by the qualified name's length,
    # plus one, plus the requested padding.
    continuation_indent = len(qualified_name) + 1 + padding
    args_text = _emit_args(continuation_indent, schema.arguments[arg_start:])
    rets_text = _emit_rets(schema.returns)
    return "{}({}) -> {}".format(qualified_name, args_text, rets_text)
def _get_tensor_ops():
    """Collect rendered schemas for the non-hidden attributes of ``torch.Tensor``.

    For every attribute name of ``torch.Tensor`` that is not hidden, the
    schemas registered for ``aten::<name>`` are fetched and those that look
    like tensor methods are rendered with the ``self`` argument dropped.

    Returns:
        tuple: ``("Supported Tensor Methods", methods)`` where ``methods`` is
        a list of rendered schema strings.
    """
    def is_tensor_method(schema):
        # A tensor method has a first argument named ``self`` whose type is
        # a subtype of the tensor type.
        arguments = schema.arguments
        if len(arguments) == 0:
            return False
        receiver = arguments[0]
        if receiver.name != 'self':
            return False
        return bool(receiver.type.isSubtypeOf(torch._C.TensorType.get()))

    methods = []
    public_names = [attr for attr in dir(torch.Tensor) if not _hidden(attr)]
    for attr in public_names:
        for schema in torch._C._jit_get_schemas_for_operator("aten::" + attr):
            if is_tensor_method(schema):
                methods.append(_emit_schema('Tensor', attr, schema, arg_start=1))
    return "Supported Tensor Methods", methods
def _get_nn_functional_ops():
    """Collect rendered schemas for ``torch.nn.functional`` and known builtins.

    Two sources are scanned:

    1. Functions found in ``torch.nn.functional`` whose name does not start
       with an underscore and whose defining module name contains
       ``'torch.nn.functional'``.  Each is compiled with ``torch.jit.script``
       and skipped if compilation or rendering fails.
    2. Non-hidden attributes of every module listed in
       ``torch.jit._builtins._modules_containing_builtins`` for which
       ``_find_builtin`` returns an operator name.

    Returns:
        tuple: ``("Supported PyTorch Functions", functions)`` where
        ``functions`` is a list of rendered schema strings.

    Raises:
        RuntimeError: if the defining module of a candidate function cannot
            be determined.
    """
    functions = []

    # Iterate over torch.nn.functional
    mod = torch.nn.functional
    name = mod.__name__
    for elem in dir(mod):
        attr = getattr(mod, elem)
        if not inspect.isfunction(attr) or elem.startswith('_'):
            # Ignore non-functions and internal methods
            continue

        attr_module = inspect.getmodule(attr)
        if not attr_module:
            raise RuntimeError(f'Module for {attr} not found')

        if 'torch.nn.functional' not in attr_module.__name__:
            # Ignore functions from outside torch.nn.functional
            continue

        try:
            # compile fn, get schema
            scripted = torch.jit.script(attr)
            schema = scripted.schema
            functions.append(_emit_schema(name, elem, schema))
        except Exception:
            # Best effort: skip interpolate / boolean dispatched things.
            # ``Exception`` (not a bare except) so Ctrl-C and SystemExit
            # still propagate during a long docs build.
            pass

    # Iterate over modules that we know contain a lot of builtins
    for mod in torch.jit._builtins._modules_containing_builtins:
        name = mod.__name__
        for elem in dir(mod):
            if _hidden(elem):
                # remove _tan but not __and__
                continue
            builtin = _find_builtin(getattr(mod, elem))
            if builtin is None:
                continue
            functions.extend(
                _emit_schema(name, elem, schema)
                for schema in torch._C._jit_get_schemas_for_operator(builtin)
            )

    return "Supported PyTorch Functions", functions
def _get_builtins_helper():
    """Filter ``torch.jit._builtins._builtin_ops`` down to documentable entries.

    An entry ``(fn, builtin_name)`` is dropped when ``fn`` has no
    ``__name__``, when ``inspect.getmodule`` finds no module for it, when its
    name, qualified name or module name is hidden, or when its module name
    contains ``'torch._C'``.

    Returns:
        list: the surviving ``(fn, builtin_name)`` pairs in registry order.
    """
    def is_documentable(fn):
        owner = inspect.getmodule(fn)
        if not hasattr(fn, '__name__'):
            # typing classes
            return False
        if not owner:
            return False
        if _hidden(fn.__name__) or _hidden(fn.__qualname__) or _hidden(owner.__name__):
            # skip internal-only methods
            return False
        return 'torch._C' not in owner.__name__

    return [
        (fn, builtin_name)
        for fn, builtin_name in torch.jit._builtins._builtin_ops
        if is_documentable(fn)
    ]
- def _is_math_fn(fn):
- mod = inspect.getmodule(fn)
- if not mod:
- raise RuntimeError(f'Module for {fn} not found')
- return mod.__name__ == 'math'
def _get_torchscript_builtins():
    """Collect rendered schemas for registered builtins outside ``math``.

    Returns:
        tuple: ``("TorchScript Builtin Functions", functions)`` where
        ``functions`` is a list of rendered schema strings.

    Raises:
        RuntimeError: if the module of a builtin cannot be determined.
    """
    functions = []
    non_math_entries = [
        entry for entry in _get_builtins_helper() if not _is_math_fn(entry[0])
    ]
    # Iterate over the specially added builtins
    for fn, _builtin_name in non_math_entries:
        owner = inspect.getmodule(fn)
        if not owner:
            raise RuntimeError(f'Module for {fn} not found')
        op_name = _find_builtin(fn)
        if op_name is None:
            continue
        for schema in torch._C._jit_get_schemas_for_operator(op_name):
            functions.append(_emit_schema(owner.__name__, fn.__name__, schema))

    return "TorchScript Builtin Functions", functions
def _get_math_builtins():
    """Collect rendered schemas for registered builtins from the ``math`` module.

    Schemas whose rendered text mentions ``Tensor`` are skipped, because
    those operators are listed with the tensor methods instead.

    Returns:
        tuple: ``("``math`` Module", functions)`` where ``functions`` is a
        list of rendered schema strings, in the same format the other
        sections use.

    Raises:
        RuntimeError: if the module of a builtin cannot be determined.
    """
    functions = []
    math_entries = [
        entry for entry in _get_builtins_helper() if _is_math_fn(entry[0])
    ]
    # Iterate over the specially added builtins
    for fn, _builtin_name in math_entries:
        mod = inspect.getmodule(fn)
        if not mod:
            raise RuntimeError(f'Module for {fn} not found')
        builtin = _find_builtin(fn)
        if builtin is None:
            continue
        for schema in torch._C._jit_get_schemas_for_operator(builtin):
            schema_str = _emit_schema(mod.__name__, fn.__name__, schema)
            if 'Tensor' in schema_str:
                # Skip Tensor ops that have the same name as math functions
                # (they will show up in the tensor methods section)
                continue
            # Append the rendered text, not the raw schema object, so this
            # section is formatted like every other one.
            functions.append(schema_str)

    return "``math`` Module", functions
def _get_global_builtins():
    """Build the reStructuredText section describing supported Python builtins.

    Each supported builtin is looked up as a JIT operator named
    ``aten::<name>``, unless the rename table gives another operator name.
    Builtins with at least one schema are rendered into a literal block.
    Builtins without any schema go into a table next to a hand-written note.
    A second table pairs builtins with the magic method they use on
    TorchScript classes.

    Returns:
        tuple: ``("Python Built-in Functions", section)`` where ``section``
        is a single reStructuredText string.

    Raises:
        KeyError: if a builtin has no schema and no entry in the
            explanations table.
    """
    # Taken from the 'globals' map in torch/csrc/jit/frontend/ir_emitter.cpp
    supported_builtins = [
        'print',
        'tuple',
        'float',
        'complex',
        'int',
        'bool',
        'str',
        'getattr',
        'hasattr',
        'isinstance',
        'len',
        'hex',
        'oct',
        'round',
        'hash',
        'min',
        'max',
        'abs',
        'all',
        'divmod',
        'list',
        'ord',
        'chr',
        'bin',
        'range',
        'zip',
        'enumerate',
        'sorted',
    ]

    op_renames = {
        'bool': 'aten::Bool',
        'int': 'aten::Int',
        'float': 'aten::Float',
        'complex': 'aten::Complex',
        'abs': 'prim::abs',
        'max': 'prim::max',
        'min': 'prim::min',
        'range': 'fake::does_not_exist',
    }

    schemaless_op_explanations = {
        'print': 'Print any value',
        'tuple': 'Lists cannot be converted to tuples with this method since their size is not statically known',
        'getattr': 'Attribute name must be a literal string',
        'hasattr': 'Attribute name must be a literal string',
        'isinstance': 'Result is static',
        'zip': 'Arguments must be iterable. See :ref:`Iterables <jit_iterables>` for details.',
        'enumerate': 'Arguments must be iterable. See :ref:`Iterables <jit_iterables>` for details.',
        'range': 'Can only be used as an iterator in a for loop',
    }

    magic_methods = [
        ('complex', '__complex__'),
        ('float', '__float__'),
        ('int', '__int__'),
        ('bool', '__bool__'),
        ('str', '__str__'),
        ('len', '__len__'),
        ('hex', '__hex__'),
        ('oct', '__oct__'),
    ]

    magic_methods_rows = [
        '"{}", "``{}``"'.format(fn, magic_method)
        for fn, magic_method in magic_methods
    ]

    schematized_ops = []
    schemaless_ops = []

    for fn in supported_builtins:
        op_name = op_renames.get(fn, 'aten::{}'.format(fn))
        schemas = torch._C._jit_get_schemas_for_operator(op_name)
        if len(schemas) > 0:
            schematized_ops.extend(
                _emit_schema(None, fn, s, padding=0) for s in schemas
            )
            # An empty entry separates one builtin's overloads from the next.
            schematized_ops.append('')
        else:
            table_row = '":any:`{}`", "{}"'.format(fn, schemaless_op_explanations[fn])
            schemaless_ops.append(table_row)

    # Indent the generated rows so they sit inside their directive / literal
    # block in the template below.
    schematized_ops_str = textwrap.indent('\n'.join(schematized_ops), '\t')
    schemaless_ops_str = textwrap.indent('\n'.join(schemaless_ops), '\t')
    magic_methods_rows_str = textwrap.indent('\n'.join(magic_methods_rows), '\t')

    # reStructuredText needs a blank line before each directive and before
    # its content, and directive options must be indented under the directive.
    section = """
The functions in the following table are supported but do not have a static schema

.. csv-table::
    :header: "Function", "Note"

{}

The following functions will use the corresponding magic method on :any:`TorchScript classes`

.. csv-table::
    :header: "Function", "Magic Method"

{}

These built-in functions use the schema

.. rst-class:: codeblock-height-limiter

::

{}
""".format(schemaless_ops_str, magic_methods_rows_str, schematized_ops_str)

    return "Python Built-in Functions", section
def _list_supported_ops():
    """Assemble the reStructuredText text for every supported-op section.

    Each gathering function returns a ``(header, items)`` pair.  String
    ``items`` are inserted verbatim under the header; any other ``items`` are
    rendered as a literal block with one entry per declaration.  Every
    section is preceded by a link target derived from its header.

    Returns:
        str: all sections concatenated in the order they were gathered.
    """
    def emit_block(decls):
        entries = ''.join([' {}\n\n'.format(decl) for decl in decls])
        return '\n.. rst-class:: codeblock-height-limiter\n\n::\n\n{}\n'.format(entries)

    op_gathering_fns = (
        _get_tensor_ops,
        _get_nn_functional_ops,
        _get_torchscript_builtins,
        _get_global_builtins,
        _get_math_builtins,
    )

    sections = []
    for gather in op_gathering_fns:
        header, items = gather()
        # Link target: drop backticks and hyphens, lowercase, spaces -> hyphens.
        link_target = header.replace('`', '').replace('-', '').lower().replace(' ', '-')
        underline = '~' * len(header)
        if isinstance(items, str):
            content = "{}\n{}\n{}\n".format(header, underline, items)
        else:
            content = "{}\n{}\n{}".format(header, underline, emit_block(items))
        sections.append('.. _{}:'.format(link_target) + '\n\n' + content)

    return ''.join(sections)


# The generated listing becomes this module's docstring.
__doc__ = _list_supported_ops()
|