_property_propagation.py 1.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041
"""
Tools to help with tensor property propagation.
This is not intended to be imported directly; please use the functionality
exposed in `torch.jit`.
"""
  6. from typing import Any, List
  7. import torch
  8. from torch import TensorType
  9. from torch._C import Graph
  10. def apply_input_props_using_example(graph: Graph, example_input: List[Any]):
  11. """
  12. Applies properties for each tensor in the graph inputs
  13. using the example supplied.
  14. """
  15. graph_inputs = list(graph.inputs())
  16. if len(graph_inputs) == 0:
  17. return
  18. # Strip self args off for methods
  19. in_0 = graph_inputs[0]
  20. if isinstance(in_0.type(), torch._C.ClassType) and in_0.debugName() == "self":
  21. graph_inputs = graph_inputs[1:]
  22. if not len(graph_inputs) == len(example_input):
  23. raise RuntimeError(
  24. "Number of inputs in graph does not match number of inputs in the example")
  25. for i, (graph_i, example_i) in enumerate(zip(graph_inputs, example_input)):
  26. if example_i is None:
  27. continue # Skip the type check
  28. if isinstance(example_i, torch.Tensor) != isinstance(graph_i.type(), TensorType):
  29. raise RuntimeError(f"Input {i} does not match type of example", graph_i, example_i)
  30. if isinstance(example_i, torch.Tensor):
  31. graph_i.setType(TensorType.create_from_tensor(example_i)) # type: ignore[arg-type]