1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071 |
- from typing import Callable, Optional
- from torch._prims.context import NvfuserPrimsMode, TorchRefsMode
- from torch._prims.nvfuser_executor import nvfuser_execute, nvfuser_execute_partitioned
- from torch.fx import GraphModule
- from torch.fx.experimental.proxy_tensor import make_fx, wrapper_and_args_for_make_fx
def execute(
    gm: GraphModule,
    *args,
    executor: str = "aten",
    executor_parameters: Optional[dict] = None,
):
    """
    Prototype executor: run a traced graph on the requested backend.

    Args:
        gm: The graph module to execute.
        *args: Positional inputs, forwarded unchanged to the chosen backend.
        executor: Name of the backend to use:
            - "aten": call ``gm.forward(*args)`` directly.
            - "nvfuser": dispatch to ``nvfuser_execute_partitioned``.
            - "strictly_nvfuser": dispatch to ``nvfuser_execute``.
        executor_parameters: Optional dict forwarded to the two nvfuser
            entry points. It is not used by the "aten" path.

    Returns:
        Whatever the selected backend returns for ``gm`` and ``args``.

    Raises:
        ValueError: If ``executor`` is not one of the names listed above.
    """
    if executor == "aten":
        return gm.forward(*args)

    if executor == "nvfuser":
        return nvfuser_execute_partitioned(
            gm, *args, executor_parameters=executor_parameters
        )

    if executor == "strictly_nvfuser":
        return nvfuser_execute(gm, *args, executor_parameters=executor_parameters)

    # Keep this list in sync with the branches above so the error message
    # never under-reports what the function actually accepts.
    allowed = "aten, nvfuser, strictly_nvfuser"
    raise ValueError(
        f"Received unexpected value for 'executor': {executor}. "
        f"Allowed values are: {allowed}."
    )
def make_traced(fn: Callable):
    """
    Returns a function that, when called, will
    trace its torch operations to prims and then
    execute those prims on the requested trace executor
    (possibly lowering them to that trace executor first).

    Only supports the torch operations defined in _torch_to_reference_map
    in context.py and operations with positional args. All args must
    be tensors.
    In the near future all these restrictions will be lifted.

    The returned callable accepts an extra keyword argument ``executor``
    (default "aten") which is consumed by the wrapper and passed on to
    ``execute``; it is not forwarded to ``fn``. All other positional and
    keyword arguments are handed to ``wrapper_and_args_for_make_fx``
    together with ``fn``.

    Example usage:

    def foo(a, b):
      return torch.add(a, b)

    traced_foo = make_traced(foo)

    a = torch.randn((1, 2, 3, 4, 5), device='cuda')
    b = torch.randn((1, 2, 3, 4, 5), device='cuda')
    result = traced_foo(a, b, executor='nvfuser')

    Executor may be 'aten', 'nvfuser' or 'strictly_nvfuser'.
    """
    # Imported locally so this block does not depend on a module-level import.
    import functools

    # Preserve fn's __name__/__doc__ on the wrapper for debugging and
    # introspection.
    @functools.wraps(fn)
    def _traced(*args, executor="aten", **kwargs):
        # TODO: caching
        # Pack args/kwargs into a single argument structure and get a
        # wrapper of fn that accepts that structure.
        wrapped, all_args = wrapper_and_args_for_make_fx(fn, args, kwargs)

        # Trace under both modes; only the tracing happens inside the
        # context managers, execution happens after they exit.
        with NvfuserPrimsMode(), TorchRefsMode():
            gm = make_fx(wrapped)(all_args)

        # all_args is passed as one positional argument, mirroring how the
        # graph was traced above.
        return execute(gm, all_args, executor=executor)

    return _traced
|