123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473 |
- #!/usr/bin/env python3
- from typing import Any, TypeVar, Optional, Tuple, List, NamedTuple, Union, Sequence, Dict, Callable
- import textwrap
- import torch
- from torch._C import TupleType, ListType
- from torch.jit._recursive import wrap_cpp_module
# Generic type variable: lets _inflate_expr declare that it returns a value of
# the same type it was given.
T = TypeVar("T")
# Threshold on tensor storage size above which tensors are refused for raw
# bundling (see _inflate_expr; bundle_large_tensor bypasses this check).
# NOTE(review): presumably measured in storage elements per
# Tensor._typed_storage().size() — confirm against torch storage semantics.
MAX_RAW_TENSOR_SIZE = 16
class InflatableArg(NamedTuple):
    """Compressed ("deflated") representation of one bundled-input argument.

    `value` holds the deflated data that is actually stored in the model; it
    must have the same type as the argument of the function it stands in for.

    `fmt` is a format-string of inflation code. It is formatted with a
    reference to `value` and must evaluate to a value of the same type as
    `value`.

    `fmt_fn` is a format-string containing a full helper-function definition
    used to inflate the compressed data; the function's *name* is the
    formattable slot. The helper must return a value of the same type as
    `value`.

    Note: only top-level InflatableArgs are inflated — an InflatableArg nested
    inside another structure is not. To bundle a structured input, make the
    `fmt` code string build and return the complete structure.
    """

    # Deflated payload stored on the model.
    value: Any
    # Inflation expression template; "{}" means "use the stored value as-is".
    fmt: str = "{}"
    # Optional helper-function definition template (mutually exclusive with a
    # non-trivial `fmt`).
    fmt_fn: str = ""
def bundle_inputs(
    model: torch.jit.ScriptModule,
    inputs: Union[Optional[Sequence[Tuple[Any, ...]]], Dict[Callable, Optional[Sequence[Tuple[Any, ...]]]]],
    info: Optional[Union[List[str], Dict[Callable, List[str]]]] = None,
    *,
    _receive_inflate_expr: Optional[List[str]] = None,
) -> torch.jit.ScriptModule:
    """Create and return a copy of the specified model with inputs attached.

    The original model is not mutated or changed in any way.

    Models with bundled inputs can be invoked in a uniform manner by
    benchmarking and code coverage tools.

    If inputs is passed in as a list then the inputs will be bundled for 'forward'.
    If inputs is instead passed in as a map then all the methods specified in the map
    will have their corresponding inputs bundled. Info should match whatever type is
    chosen for the inputs.

    The returned model will support the following methods:

        `get_all_bundled_inputs_for_<function_name>() -> List[Tuple[Any, ...]]`
            Returns a list of tuples suitable for passing to the model like
            `for inp in model.get_all_bundled_inputs_for_foo(): model.foo(*inp)`

        `get_bundled_inputs_functions_and_info() -> Dict[str, Dict[str: List[str]]]`
            Returns a dictionary mapping function names to a metadata dictionary.
            This nested dictionary maps preset strings like:
                'get_inputs_function_name' -> the name of a function attribute in this model that can be
                    run to get back a list of inputs corresponding to that function.
                'info' -> the user provided extra information about the bundled inputs

    If forward has bundled inputs then these following functions will also be defined on the returned module:

        `get_all_bundled_inputs() -> List[Tuple[Any, ...]]`
            Returns a list of tuples suitable for passing to the model like
            `for inp in model.get_all_bundled_inputs(): model(*inp)`

        `get_num_bundled_inputs() -> int`
            Equivalent to `len(model.get_all_bundled_inputs())`,
            but slightly easier to call from C++.

    Inputs can be specified in one of two ways:

      - The model can define `_generate_bundled_inputs_for_<function_name>`.
        If the user chooses this method inputs[<function>] should map to None

      - The `inputs` argument to this function can be a dictionary mapping functions to a
        list of inputs, of the same form that will be returned by get_all_bundled_inputs_for_<function_name>.
        Alternatively if only bundling inputs for forward the map can be omitted and a singular list of inputs
        can be provided instead.

        The type of the inputs is List[Tuple[Any, ...]]. The outer list corresponds with a
        list of inputs, the inner tuple is the list of args that together make up one input.
        For inputs of functions that take one arg, this will be a tuple of length one. The Any, ...
        is the actual data that makes up the args, e.g. a tensor.

    Info is an optional parameter that maps functions to a list of strings providing extra information about that
    function's bundled inputs. Alternatively if only bundling inputs for forward the map can be omitted and
    a singular list of information can be provided instead. This could be descriptions, expected outputs, etc.
        - Ex: info={model.forward : ['man eating icecream', 'an airplane', 'a dog']}

    This function will attempt to optimize arguments so that (e.g.)
    arguments like `torch.zeros(1000)` will be represented compactly.
    Only top-level arguments will be optimized.
    Tensors in lists or tuples will not.
    """
    # TypeError (rather than a bare Exception) for a bad argument type;
    # backward compatible since TypeError is an Exception subclass.
    if not isinstance(model, torch.jit.ScriptModule):
        raise TypeError("Only ScriptModule is supported.")

    # Clone the model, excluding any bundled-input artifacts it may already
    # carry, so the copy can be augmented afresh.
    ignored_methods, ignored_attrs = _get_bundled_inputs_attributes_and_methods(model)
    clone = torch._C._hack_do_not_use_clone_module_with_class(  # type: ignore[attr-defined]
        model._c,
        ignored_methods,
        ignored_attrs,
    )

    # The above cloning function returns a torch._C.ScriptModule and we need a
    # torch.jit.ScriptModule; wrap_cpp_module performs exactly that conversion.
    cloned_module = wrap_cpp_module(clone)
    if isinstance(inputs, dict):
        assert isinstance(info, dict) or info is None
        augment_many_model_functions_with_bundled_inputs(cloned_module, inputs, _receive_inflate_expr, info)
    else:
        assert isinstance(info, list) or info is None
        augment_model_with_bundled_inputs(cloned_module, inputs, _receive_inflate_expr, info)
    return cloned_module
def augment_model_with_bundled_inputs(
    model: torch.jit.ScriptModule,
    inputs: Optional[Sequence[Tuple[Any, ...]]] = None,
    _receive_inflate_expr: Optional[List[str]] = None,  # For debugging.
    info: Optional[List[str]] = None,  # Optional argument to provide info about forward or its inputs
    skip_size_check=False,
) -> None:
    """Attach bundled sample inputs to `model` for its `forward` function, in place.

    Models with bundled inputs can be invoked in a uniform manner by
    benchmarking and code coverage tools.

    Augmented models will support the following methods:

        `get_all_bundled_inputs() -> List[Tuple[Any, ...]]`
            Returns a list of tuples suitable for passing to the model like
            `for inp in model.get_all_bundled_inputs(): model(*inp)`

        `get_num_bundled_inputs() -> int`
            Equivalent to `len(model.get_all_bundled_inputs())`,
            but slightly easier to call from C++.

        `get_bundled_inputs_functions_and_info() -> Dict[str, Dict[str: List[str]]]`
            Returns a dictionary mapping function names to a metadata dictionary.
            This nested dictionary maps preset strings like:
                'get_inputs_function_name' -> the name of a function attribute in this model that can be
                    run to get back a list of inputs corresponding to that function.
                'info' -> the user provided extra information about the bundled inputs

    Inputs can be specified in one of two ways:

      - The model can define `_generate_bundled_inputs_for_forward`.
        If the user chooses this method inputs should be None.

      - `inputs` is a list of inputs of form List[Tuple[Any, ...]]: a list of
        tuples where the elements of each tuple are the args that make up one
        input.
    """
    if not isinstance(model, torch.jit.ScriptModule):
        raise Exception("Only ScriptModule is supported.")

    forward: Callable = model.forward

    # The multi-function implementation keys its bookkeeping on each
    # function's name, and forward occasionally arrives without one.
    if not hasattr(forward, "__name__"):
        forward.__name__ = 'forward'

    # Delegate to the general machinery with a single-entry map for forward.
    wrapped_info = {forward: info} if info else None
    augment_many_model_functions_with_bundled_inputs(
        model,
        inputs={forward: inputs},
        _receive_inflate_expr=_receive_inflate_expr,
        info=wrapped_info,
        skip_size_check=skip_size_check,
    )
def augment_many_model_functions_with_bundled_inputs(
    model: torch.jit.ScriptModule,
    inputs: Dict[Callable, Optional[Sequence[Tuple[Any, ...]]]],
    _receive_inflate_expr: Optional[List[str]] = None,  # For debugging.
    info: Optional[Dict[Callable, List[str]]] = None,  # Optional argument to provide info about the function or its inputs
    skip_size_check=False,
) -> None:
    """Add bundled sample inputs to a model for an arbitrary list of public functions.

    Mutates `model` in place (contrast with bundle_inputs, which clones first).

    Models with bundled inputs can be invoked in a uniform manner by
    benchmarking and code coverage tools.

    Augmented models will support the following methods:

        `get_all_bundled_inputs_for_<function_name>() -> List[Tuple[Any, ...]]`
            Returns a list of tuples suitable for passing to the model like
            `for inp in model.get_all_bundled_inputs_for_foo(): model.foo(*inp)`

        `get_bundled_inputs_functions_and_info() -> Dict[str, Dict[str: List[str]]]`
            Returns a dictionary mapping function names to a metadata dictionary.
            This nested dictionary maps preset strings like:
                'get_inputs_function_name' -> the name of a function attribute in this model that can be
                    run to get back a list of inputs corresponding to that function.
                'info' -> the user provided extra information about the bundled inputs

    If forward has bundled inputs then these following functions are also defined:

        `get_all_bundled_inputs() -> List[Tuple[Any, ...]]`
            Returns a list of tuples suitable for passing to the model like
            `for inp in model.get_all_bundled_inputs(): model(*inp)`

        `get_num_bundled_inputs() -> int`
            Equivalent to `len(model.get_all_bundled_inputs())`,
            but slightly easier to call from C++.

    Inputs can be specified in one of two ways:

      - The model can define `_generate_bundled_inputs_for_<function_name>`.
        If the user chooses this method inputs[<function>] should map to None

      - The `inputs` argument to this function can be a dictionary mapping functions to a
        list of inputs, of the same form that will be returned by get_all_bundled_inputs_for_<function_name>.

        The type of the inputs is List[Tuple[Any, ...]]. The outer list corresponds with a
        list of inputs, the inner tuple is the list of args that together make up one input.
        For inputs of functions that take one arg, this will be a tuple of length one. The Any, ...
        is the actual data that makes up the args, e.g. a tensor.

    Info is an optional parameter that maps functions to a list of strings providing extra information about that
    function's bundled inputs. This could be descriptions, expected outputs, etc.
        - Ex: info={model.forward : ['man eating icecream', 'an airplane', 'a dog']}

    This function will attempt to optimize arguments so that (e.g.)
    arguments like `torch.zeros(1000)` will be represented compactly.
    Only top-level arguments will be optimized.
    Tensors in lists or tuples will not.
    """
    if not isinstance(model, torch.jit.ScriptModule):
        raise Exception("Only ScriptModule is supported.")

    if not inputs:
        raise Exception("Please provide inputs for at least 1 function")

    # Augmenting twice would re-register attributes and redefine methods, so
    # refuse if the marker methods already exist.
    if hasattr(model, "get_all_bundled_inputs") or hasattr(model, "get_bundled_inputs_functions_and_info"):
        raise Exception(
            "Models can only be augmented with bundled inputs once. "
            "This Model seems to have already been augmented with "
            "bundled inputs. Please start afresh with one that "
            "doesn't have bundled inputs.",
        )

    # Accumulates one TorchScript snippet per function; spliced into the
    # get_bundled_inputs_functions_and_info definition after the loop.
    get_bundled_inputs_functions_and_info_template = ""

    for function, input_list in inputs.items():
        # Resolve a name for the function: plain Python callables carry
        # __name__; otherwise fall back to a `.name` attribute.
        if hasattr(function, "__name__"):
            function_name = function.__name__
        else:
            if hasattr(function, "name"):
                function_name = function.name  # type: ignore[attr-defined]
            else:
                raise Exception(
                    'At least one of your functions has no attribute name please ensure all have one. m.foo.name = "foo"')

        if input_list is not None and not isinstance(input_list, Sequence):
            raise TypeError("Error inputs for function {0} is not a Sequence".format(function_name))

        # Register a typed attribute to hold this function's deflated inputs.
        # schema.arguments[0] is `self`, hence the [1:] slice.
        function_arg_types = [arg.type for arg in function.schema.arguments[1:]]  # type: ignore[attr-defined]
        deflated_inputs_type: ListType = ListType(TupleType(function_arg_types))
        model._c._register_attribute("_bundled_inputs_deflated_{name}".format(name=function_name), deflated_inputs_type, [])

        if hasattr(model, "_generate_bundled_inputs_for_" + function_name):
            if input_list is not None:
                raise Exception(
                    "inputs[{name}] is not None, but _generate_bundled_inputs_for_{name} is already defined".format(
                        name=function_name
                    )
                )
            # Model author already defined _generate_bundled_inputs_for_<function_name>.
        elif input_list is None or len(input_list) == 0:
            raise Exception(
                "inputs for {name} must be specified if _generate_bundled_inputs_for_{name} is not already defined".format(
                    name=function_name,
                )
            )
        else:
            # Iterate over the inputs and args in each input.
            # Accumulate `deflated_inputs` as (possibly) compressed values
            # and `parts` to be joined into the expression that unpacks them.
            deflated_inputs = []
            parts = []
            for inp_idx, args in enumerate(input_list):
                if not isinstance(args, Tuple) and not isinstance(args, List):  # type: ignore[arg-type]
                    raise TypeError(
                        "Error bundled input for function {0} idx: {1} is not a Tuple or a List".format(function_name, inp_idx)
                    )
                deflated_args = []
                parts.append("(")
                for arg_idx, arg in enumerate(args):
                    inflate_helper_fn_name = _get_inflate_helper_fn_name(arg_idx, inp_idx, function_name)
                    deflated, inflater, helper_definition = _inflate_expr(
                        arg,
                        f"deflated[{inp_idx}][{arg_idx}]",
                        inflate_helper_fn_name,
                        skip_size_check=skip_size_check,
                    )
                    deflated_args.append(deflated)
                    parts.append(f" {inflater},")
                    # fmt_fn-style InflatableArgs come with a helper method
                    # that must be defined on the model before it is referenced.
                    if helper_definition:
                        model.define(textwrap.dedent(helper_definition))
                deflated_inputs.append(tuple(deflated_args))
                parts.append("),")
            parts.append("")
            expr = "\n".join(parts)

            # Back-channel return this expr for debugging.
            if _receive_inflate_expr is not None:
                _receive_inflate_expr.append(expr)

            setattr(model, "_bundled_inputs_deflated_{name}".format(name=function_name), deflated_inputs)

            # `expr` is substituted inside the brackets of `return [...]`, so
            # its internal indentation is irrelevant to the TorchScript parser.
            # NOTE(review): in-string indentation reconstructed — the source
            # paste had collapsed it; dedent/format order preserved.
            definition = textwrap.dedent("""
                def _generate_bundled_inputs_for_{name}(self):
                    deflated = self._bundled_inputs_deflated_{name}
                    return [
                        {expr}
                    ]
                """).format(expr=expr, name=function_name)
            model.define(definition)

        # Define get_all_bundled_inputs_for_<function_name> that caches the generated inputs.
        model.define(textwrap.dedent("""
            def get_all_bundled_inputs_for_{name}(self):
                all_inputs = self._generate_bundled_inputs_for_{name}()
                assert all_inputs is not None
                return all_inputs
            """).format(name=function_name))

        # Add to the high level helper methods
        inputs_info = repr(info[function]) if info and function in info else '[]'
        get_bundled_inputs_functions_and_info_template += """
            temp_dict : Dict[str,List[str]] = {{}}
            info: List[str] = {info}
            temp_dict['info'] = info
            temp_dict['get_inputs_function_name'] = ['get_all_bundled_inputs_for_{name}']
            all_inputs['{name}'] = temp_dict
            """.format(
            name=function_name,
            info=inputs_info,
        )

        # To ensure backwards compatibility and a streamlined api for forward these wrappers are provided
        if function_name == 'forward':
            model.define(textwrap.dedent("""
                def get_all_bundled_inputs(self):
                    return self.get_all_bundled_inputs_for_forward()
                """))
            model.define(textwrap.dedent("""
                def get_num_bundled_inputs(self):
                    return len(self.get_all_bundled_inputs_for_forward())
                """))

    # Define some high level helper methods that act on all bundled inputs
    # (format runs before dedent here: the accumulated template's lines share
    # the literal's indentation, so the combined text dedents cleanly).
    model.define(textwrap.dedent("""
        def get_bundled_inputs_functions_and_info(self):
            all_inputs : Dict[str, Dict[str,List[str]]] = {{}}
            {template}
            return all_inputs
        """.format(template=get_bundled_inputs_functions_and_info_template)))
def _inflate_expr(
    arg: T, ref: str, inflate_helper_fn_name: str, skip_size_check: bool = False
) -> Tuple[Union[T, torch.Tensor], str, Optional[str]]:
    """Deflate one bundled-input argument.

    Returns (stored_value, inflation_expression, helper_definition) where
    `stored_value` is what gets saved in the model, `inflation_expression` is
    TorchScript code rebuilding the argument from `ref`, and
    `helper_definition` is an optional method body to model.define().
    """
    # Custom inflation expressions are allowed for any object — e.g. calling
    # custom image-decoding ops — or "{}" as the format string to skip size limits.
    if isinstance(arg, InflatableArg):
        if not arg.fmt_fn:
            return arg.value, arg.fmt.format(ref), None
        # fmt_fn and a non-trivial fmt are mutually exclusive.
        if arg.fmt not in ["{}", ""]:
            raise Exception(
                f"Bundled input argument at position '{ref}' has "
                f"both arg.fmt_fn => \n{arg.fmt_fn} "
                f"\n and arg.fmt => {arg.fmt}. "
                "Please choose `arg.fmt` if the deflater is straightforward or "
                "`arg.fmt_fn` if you need a function."
            )
        helper_definition = arg.fmt_fn.format(inflate_helper_fn_name)
        return arg.value, f"self.{inflate_helper_fn_name}({ref})", helper_definition

    if not isinstance(arg, torch.Tensor):
        # Anything else is stored verbatim and referenced directly.
        return arg, ref, None

    # Small-storage tensors can just be saved directly.
    storage_size = arg._typed_storage().size()
    if storage_size <= MAX_RAW_TENSOR_SIZE or skip_size_check:
        return arg, ref, None

    # Small contiguous tensors can be cloned to have small storage.
    # TODO: Should we do this even for non-contiguous tensors?
    if arg.is_contiguous() and arg.numel() <= MAX_RAW_TENSOR_SIZE:
        return arg.clone(), ref, None

    # Example inputs commonly come from torch.zeros, torch.ones, or torch.full.
    # These can be represented compactly: store a single-element expand and
    # re-materialize with .contiguous() on inflation.
    for memory_fmt in (torch.contiguous_format, torch.channels_last):
        if arg.is_contiguous(memory_format=memory_fmt) and (arg == arg.flatten()[0]).all().item():
            compact = arg.flatten()[0].clone().expand(*arg.size())
            return compact, f"{ref}.contiguous(memory_format={memory_fmt})", None

    # Prevent big tensors from being bundled by default.
    # TODO: Provide more useful diagnostics.
    raise Exception(
        f"Bundled input argument at position '{ref}' is "
        f"a tensor with storage size {storage_size}. "
        f"You probably don't want to bundle this as an input. "
    )
- def _get_bundled_inputs_attributes_and_methods(script_module: torch.jit.ScriptModule) -> Tuple[List[str], List[str]]:
- methods: List[str] = []
- attributes: List[str] = []
- # Has bundled inputs for forward
- if hasattr(script_module, 'get_all_bundled_inputs'):
- methods.append('get_all_bundled_inputs')
- methods.append('get_num_bundled_inputs')
- methods.append('run_on_bundled_input')
- if hasattr(script_module, 'get_bundled_inputs_functions_and_info'):
- methods.append('get_bundled_inputs_functions_and_info')
- all_info = script_module.get_bundled_inputs_functions_and_info()
- for function_name in all_info:
- methods.append("get_all_bundled_inputs_for_" + function_name)
- methods.append("_generate_bundled_inputs_for_" + function_name)
- attributes.append("_bundled_inputs_deflated_" + function_name)
- bundled_inputs_fn = getattr(
- script_module,
- f"get_all_bundled_inputs_for_{function_name}"
- )
- num_bundled_inputs: int = len(bundled_inputs_fn())
- # Check inflate helper functions for each function, argument and bundled input
- func = getattr(script_module, function_name)
- for arg_idx in range(len(func.schema.arguments) - 1):
- for input_idx in range(num_bundled_inputs):
- helper_fn_name = _get_inflate_helper_fn_name(
- arg_idx=arg_idx,
- input_idx=input_idx,
- function_name=function_name
- )
- # if the arg has an InflatableArg with fmt_fn, add the helper function name
- if hasattr(script_module, helper_fn_name):
- methods.append(helper_fn_name)
- return (methods, attributes)
- def _get_inflate_helper_fn_name(
- arg_idx: int,
- input_idx: int,
- function_name: str,
- ) -> str:
- return f"_inflate_helper_for_{function_name}_input_{input_idx}_arg_{arg_idx}"
def bundle_randn(*size, dtype=None):
    """Generate a tensor that will be inflated with torch.randn.

    Only a one-element zero stub (expanded to the target shape) is stored in
    the model; randomness is regenerated at inflation time via randn_like.
    """
    placeholder = torch.zeros(1, dtype=dtype).expand(*size)
    return InflatableArg(value=placeholder, fmt="torch.randn_like({})")
def bundle_large_tensor(t):
    """Wrap a tensor to allow bundling regardless of size.

    The identity format string takes the InflatableArg fast path in
    _inflate_expr, so the raw tensor is stored as-is with no size check.
    """
    return InflatableArg(value=t, fmt="{}")
|