cudagraphs.py

import operator

import torch
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupport
from torch.fx.passes.tools_common import CALLABLE_NODE_OPS
from torch.utils._pytree import tree_map


class CudaGraphsSupport(OperatorSupport):
    # TODO: why is submodules passed here
    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        # Only callable nodes (call_function/call_method/call_module) are
        # candidates; placeholder, get_attr and output nodes never are.
        if node.op not in CALLABLE_NODE_OPS:
            return False

        if node.target in [torch.ops.aten.embedding_dense_backward.default]:
            return False

        # getitem merely projects one element out of a multi-output node, so
        # it is always safe to keep alongside its producer.
        if node.target in [operator.getitem]:
            return True

        found_not_cuda = False

        def meta_fk(meta):
            # Prefer the "val" meta entry; fall back to the "fake_result"
            # entry written by FakeTensorProp.
            return meta["val"] if "val" in meta else meta["fake_result"]

        def find_not_cuda(t):
            nonlocal found_not_cuda
            if isinstance(t, torch.Tensor) and t.device.type != "cuda":
                found_not_cuda = True

        # Reject the node if any of its input or output tensors live off
        # the GPU.
        for n in node.all_input_nodes:
            tree_map(find_not_cuda, meta_fk(n.meta))

        tree_map(find_not_cuda, meta_fk(node.meta))

        # NB: factory function is accounted for because the result would be
        # cpu or cuda
        return not found_not_cuda


def partition_cudagraphs(gm, inputs):
    """
    Partition an FX graph into sub-GraphModules that can be validly run under
    CUDA graphs.  For a subgraph to be runnable under CUDA graphs, all of the
    operations must involve CUDA tensors only.
    """
    # Populate node.meta with fake tensors so the device checks in
    # CudaGraphsSupport have something to inspect.
    FakeTensorProp(gm).propagate(*inputs)
    supported_ops = CudaGraphsSupport()
    # TODO: single node partition may be wrong due to the pessimization
    # from copying in and out the data.  Check in benchmarks, perhaps
    partitioner = CapabilityBasedPartitioner(
        gm, supported_ops, allows_single_node_partition=True
    )
    partitions = partitioner.propose_partitions()
    fused_graph = partitioner.fuse_partitions(partitions)
    return fused_graph
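

# A minimal usage sketch.  `ToyModel` and the input shape are hypothetical
# stand-ins; a real caller (e.g. a torchdynamo backend) would typically pass
# in an aten-level graph rather than a plain symbolic trace.
if __name__ == "__main__" and torch.cuda.is_available():

    class ToyModel(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x) + 1

    model = ToyModel().cuda()
    example_inputs = [torch.randn(8, device="cuda")]
    traced = torch.fx.symbolic_trace(model)
    fused = partition_cudagraphs(traced, example_inputs)
    # Supported nodes are grouped into fused submodules, each of which can
    # then be wrapped to capture and replay under a CUDA graph.
    print(fused.graph)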