# graph_manipulation.py

from typing import Any, Dict, List, NamedTuple, Optional

import torch
from torch.fx._compatibility import compatibility
from torch.fx.graph import Graph
from torch.fx.graph_module import GraphModule
from torch.fx.node import (
    map_arg,
    Node,
    Target,
)
from torch.fx.passes.shape_prop import ShapeProp

__all__ = [
    'replace_target_nodes_with',
    'size_bytes',
    'get_size_of_all_nodes',
    'get_tensor_meta',
    'get_size_of_node',
]


@compatibility(is_backward_compatible=False)
def replace_target_nodes_with(
    fx_module: GraphModule,
    old_op: str,
    old_target: Target,
    new_op: str,
    new_target: Target,
):
    """Modifies all nodes in fx_module.graph.nodes that match the specified op code and
    target, updating them to the new op code and target, and installs the rebuilt graph
    on fx_module."""
    new_graph = Graph()
    val_map: Dict[Node, Node] = {}
    for node in fx_module.graph.nodes:
        if node.op == old_op and node.target == old_target:
            # Remap args/kwargs to their already-copied counterparts in the new graph
            args = map_arg(node.args, lambda n: val_map[n])
            kwargs = map_arg(node.kwargs, lambda n: val_map[n])
            assert isinstance(args, tuple)
            assert isinstance(kwargs, dict)
            val_map[node] = new_graph.create_node(
                new_op, new_target, args, kwargs, node.name
            )
        else:
            val_map[node] = new_graph.node_copy(node, lambda n: val_map[n])
    fx_module.graph = new_graph
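

# Illustrative usage (not part of the original torch.fx source): a minimal sketch of how
# replace_target_nodes_with could rewrite every call_function node targeting torch.add
# into one targeting torch.mul. AddNet and the target choices are hypothetical examples.
def _example_replace_add_with_mul() -> torch.Tensor:
    from torch.fx import symbolic_trace

    class AddNet(torch.nn.Module):
        def forward(self, x, y):
            return torch.add(x, y)

    gm = symbolic_trace(AddNet())
    replace_target_nodes_with(
        gm, "call_function", torch.add, "call_function", torch.mul
    )
    gm.recompile()  # regenerate gm.code from the rewritten graph
    # Before the rewrite this returned 2.0 + 3.0 = 5.0; now it multiplies: 6.0
    return gm(torch.tensor(2.0), torch.tensor(3.0))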


@compatibility(is_backward_compatible=False)
class size_bytes(NamedTuple):
    """Byte sizes for a node: the output alone, and the output plus module parameters."""
    output_size: int
    total_size: int


@compatibility(is_backward_compatible=False)
def get_size_of_all_nodes(
    fx_module: GraphModule, args: Optional[List[torch.Tensor]] = None
) -> None:
    """Given an FX GraphModule, annotate each node with its total size in bytes
    (weights + bias + output) and its output size (output only). For a non-module
    node, the total size equals the output size. The result is stored on each node
    as node.size_bytes; the function itself returns None."""
    if args is not None:
        # Run shape propagation so every node carries tensor metadata (shape and dtype)
        ShapeProp(fx_module).propagate(*args)
    for node in fx_module.graph.nodes:
        if node.op == "output":
            break
        node.size_bytes = get_size_of_node(fx_module, node)
    return
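

# Illustrative usage (not part of the original torch.fx source): a minimal sketch that
# traces a small hypothetical module, passes sample inputs so ShapeProp can run, and
# then reads the size_bytes annotation that get_size_of_all_nodes attaches to each node.
def _example_annotate_node_sizes() -> None:
    from torch.fx import symbolic_trace

    class TinyNet(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(16, 32)

        def forward(self, x):
            return torch.relu(self.linear(x))

    gm = symbolic_trace(TinyNet())
    get_size_of_all_nodes(gm, [torch.randn(4, 16)])
    for node in gm.graph.nodes:
        if node.op != "output":
            # node.size_bytes is a size_bytes(output_size, total_size) tuple, in bytes
            print(node.name, node.size_bytes)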


@compatibility(is_backward_compatible=False)
def get_tensor_meta(node: Node) -> Any:
    """Return the tensor metadata recorded for a node, raising if shape propagation has
    not populated node.meta["tensor_meta"] yet."""
    tensor_meta = node.meta.get("tensor_meta")
    if not tensor_meta:
        raise RuntimeError(
            f"Node {node} has no tensor metadata associated with it! "
            f"Check that shape propagation has run."
        )
    return tensor_meta
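

# Illustrative usage (not part of the original torch.fx source): tensor metadata only
# exists after a pass such as ShapeProp has filled in node.meta["tensor_meta"], so
# get_tensor_meta is the guarded way to read a node's propagated shape and dtype.
# The traced Sequential and the input shape below are hypothetical.
def _example_read_tensor_meta() -> None:
    from torch.fx import symbolic_trace

    gm = symbolic_trace(torch.nn.Sequential(torch.nn.Linear(8, 4)))
    ShapeProp(gm).propagate(torch.randn(2, 8))
    for node in gm.graph.nodes:
        if node.op == "call_module":
            meta = get_tensor_meta(node)
            print(node.name, tuple(meta.shape), meta.dtype)  # e.g. (2, 4) torch.float32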


@compatibility(is_backward_compatible=False)
def get_size_of_node(fx_module: GraphModule, node: Node) -> size_bytes:
    """Given a node annotated with tensor metadata (shape and dtype), return its total
    size and its output size in bytes, where total_size = weights + bias + output_size.
    """
    # Total number of elements contributing to this node's size
    total_num_of_elems = 0
    # For a module, consider all of its parameters
    if node.op == "call_module":
        submodule_dict = dict(fx_module.named_modules())
        submodule = submodule_dict[node.target]
        parameters = submodule.named_parameters()
        # named_parameters() yields (name, parameter) pairs
        for name, p in parameters:
            total_num_of_elems += p.numel()
    # Don't forget the output size; the node's tensor_meta holds its output shape
    tensor_meta = get_tensor_meta(node)
    output_elem = tensor_meta.shape.numel()
    total_num_of_elems += output_elem
    # Assume for now that a quantized tensor is qint8 or quint8
    if tensor_meta.is_quantized:
        size_per_elem_bytes = torch._empty_affine_quantized(
            [], dtype=tensor_meta.dtype
        ).element_size()
    else:
        size_per_elem_bytes = torch.tensor([], dtype=tensor_meta.dtype).element_size()
    total_size = size_per_elem_bytes * total_num_of_elems
    output_size = size_per_elem_bytes * output_elem
    return size_bytes(output_size, total_size)
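

# Illustrative check (not part of the original torch.fx source): a worked example of the
# arithmetic in get_size_of_node for a hypothetical Linear(16, 32) call_module node fed a
# (4, 16) float32 input. Weight + bias contribute 16 * 32 + 32 = 544 elements, the (4, 32)
# output contributes 128 elements, and float32 is 4 bytes per element, so
# total_size = (544 + 128) * 4 = 2688 bytes and output_size = 128 * 4 = 512 bytes.
def _example_size_of_linear_node() -> None:
    from torch.fx import symbolic_trace

    gm = symbolic_trace(torch.nn.Sequential(torch.nn.Linear(16, 32)))
    ShapeProp(gm).propagate(torch.randn(4, 16))
    for node in gm.graph.nodes:
        if node.op == "call_module":
            assert get_size_of_node(gm, node) == size_bytes(512, 2688)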