executor.py 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. from typing import Callable, Optional
  2. from torch._prims.context import NvfuserPrimsMode, TorchRefsMode
  3. from torch._prims.nvfuser_executor import nvfuser_execute, nvfuser_execute_partitioned
  4. from torch.fx import GraphModule
  5. from torch.fx.experimental.proxy_tensor import make_fx, wrapper_and_args_for_make_fx
  6. def execute(
  7. gm: GraphModule,
  8. *args,
  9. executor: str = "aten",
  10. executor_parameters: Optional[dict] = None,
  11. ):
  12. """
  13. Prototype ATen executor.
  14. Just executes the context's graph.
  15. """
  16. if executor == "aten":
  17. return gm.forward(*args)
  18. elif executor == "nvfuser":
  19. return nvfuser_execute_partitioned(
  20. gm, *args, executor_parameters=executor_parameters
  21. )
  22. elif executor == "strictly_nvfuser":
  23. return nvfuser_execute(gm, *args, executor_parameters=executor_parameters)
  24. msg = "Received unexpected value for 'executor': {0}. Allowed values are: aten, nvfuser.".format(
  25. executor
  26. )
  27. raise ValueError(msg)
  28. def make_traced(fn: Callable):
  29. """
  30. Returns a function that, when called, will
  31. trace its torch operations to prims and then
  32. execute those prims on the requested trace executor
  33. (possibly lowering them to that trace executor first).
  34. Only supports the torch operations defined in _torch_to_reference_map
  35. in context.py and operations with positional args. All args must
  36. be tensors.
  37. In the near future all these restrictions will be lifted.
  38. Example usage:
  39. def foo(a, b):
  40. return torch.add(a, b)
  41. traced_foo = make_traced(foo)
  42. a = torch.randn((1, 2, 3, 4, 5), device='cuda')
  43. b = torch.randn((1, 2, 3, 4, 5), device='cuda')
  44. result = traced_foo(a, b, executor='nvfuser')
  45. Executor may be either 'aten' or 'nvfuser'.
  46. """
  47. def _traced(*args, executor="aten", **kwargs):
  48. # TODO: caching
  49. wrapped, all_args = wrapper_and_args_for_make_fx(fn, args, kwargs)
  50. with NvfuserPrimsMode(), TorchRefsMode():
  51. gm = make_fx(wrapped)(all_args)
  52. return execute(gm, all_args, executor=executor)
  53. return _traced