annotate_getitem_nodes.py 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. import operator
  2. import torch
  3. def annotate_getitem_nodes(graph: torch.fx.Graph) -> None:
  4. """
  5. Annotate the type of getitem nodes, inferred from the type of sequence node.
  6. If sequence node is not annotated with a type, do nothing.
  7. Currently support getitem nodes from Tuple, List, and NamedTuple sequence node.
  8. This is helpful since annotations on local names within function are lost during FX transforms.
  9. Adding back known type annotation for getitem nodes to improve jit scriptability.
  10. Args:
  11. graph (Graph): The graph to be annotated
  12. """
  13. for node in graph.nodes:
  14. if node.target == operator.getitem:
  15. sequence_node, index_node = node.args
  16. if not sequence_node.type:
  17. continue
  18. # container types
  19. if hasattr(sequence_node.type, "_name"):
  20. parameterized_types = sequence_node.type.__args__
  21. if sequence_node.type._name == "Tuple":
  22. if len(parameterized_types) == 2 and isinstance(
  23. parameterized_types[1], type(...)
  24. ):
  25. node.type = parameterized_types[0]
  26. else:
  27. assert len(parameterized_types) > index_node
  28. node_type = parameterized_types[index_node]
  29. node.type = node_type
  30. elif sequence_node.type._name == "List":
  31. assert len(parameterized_types) == 1
  32. node.type = parameterized_types[0]
  33. # NamedTuple type
  34. elif hasattr(sequence_node.type, "__annotations__"):
  35. sequence_node_field_types = sequence_node.type.__annotations__
  36. field_name = sequence_node.type._fields[index_node]
  37. node.type = sequence_node_field_types[field_name]