123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160 |
- import contextlib
- import torch
- from typing import List, Tuple
@contextlib.contextmanager
def optimized_execution(should_optimize):
    """
    Context manager that temporarily sets the JIT graph executor's
    "optimize" flag to ``should_optimize``.

    The flag value that was active on entry is read first and written back
    when the ``with`` block exits, whether it exits normally or by raising.
    """
    previous = torch._C._get_graph_executor_optimize()
    torch._C._set_graph_executor_optimize(should_optimize)
    with contextlib.ExitStack() as restore:
        # Registered only after the new value has been applied, so the
        # callback always undoes a change that actually happened.
        restore.callback(torch._C._set_graph_executor_optimize, previous)
        yield
@contextlib.contextmanager
def fuser(name):
    """
    Context manager that selects which backend fuser is active inside the
    ``with`` block and puts the previous configuration back afterwards.

    Recognized names:

    * ``fuser0`` - enables only legacy fuser
    * ``fuser1`` - enables only NNC
    * ``fuser2`` - enables only nvFuser
    * ``fuser3`` - enables oneDNN Graph
    * ``none`` - turns every fuser flag off

    Raises:
        Exception: if ``name`` is not one of the names listed above.
    """
    def apply_flags(cpu, gpu, texpr, nvfuser, llga):
        # Single place that writes the five fuser switches, used both to
        # install the requested preset and to restore the saved one.
        torch._C._jit_override_can_fuse_on_cpu(cpu)
        torch._C._jit_override_can_fuse_on_gpu(gpu)
        torch._C._jit_set_texpr_fuser_enabled(texpr)
        torch._C._jit_set_nvfuser_enabled(nvfuser)
        torch._C._jit_set_llga_enabled(llga)

    # Snapshot of the current switches, in the order apply_flags expects.
    saved_flags = (
        torch._C._jit_can_fuse_on_cpu(),
        torch._C._jit_can_fuse_on_gpu(),
        torch._C._jit_texpr_fuser_enabled(),
        torch._C._jit_nvfuser_enabled(),
        torch._C._jit_llga_enabled(),
    )

    # (name, needs profiling executor, (cpu, gpu, texpr, nvfuser, llga))
    presets = (
        ('fuser0', False, (True, True, False, False, False)),   # legacy fuser
        ('fuser1', True, (True, True, True, False, False)),     # NNC
        ('fuser2', False, (False, False, False, True, False)),  # nvFuser
        ('fuser3', True, (True, False, True, False, True)),     # oneDNN Graph
        ('none', False, (False, False, False, False, False)),   # Turn Pytorch fuser off
    )
    for known_name, needs_profiling, requested_flags in presets:
        if name == known_name:
            break
    else:
        raise Exception(f"unrecognized fuser option (name: {name})")

    if needs_profiling:
        saved_profiling_executor = torch._C._jit_set_profiling_executor(True)
        # NOTE(review): `_get_graph_executor_optimize` is called with an
        # argument and its result is treated as the previous value, so it
        # presumably also sets the flag - confirm against the C++ binding.
        saved_profiling_mode = torch._C._get_graph_executor_optimize(True)
    apply_flags(*requested_flags)

    try:
        yield
    finally:
        if needs_profiling:  # NNC or oneDNN Graph
            torch._C._jit_set_profiling_executor(saved_profiling_executor)
            torch._C._get_graph_executor_optimize(saved_profiling_mode)
        # recover the previous values
        apply_flags(*saved_flags)
# Module-level alias for the ``torch._C._last_executed_optimized_graph``
# binding; ``_script_method_graph_for`` calls it on its fallback path.
last_executed_optimized_graph = torch._C._last_executed_optimized_graph
- def _get_differentiable_graph_node(node, diff_node):
- if node.kind() == 'prim::DifferentiableGraph':
- diff_node.append(node)
- else:
- for block in node.blocks():
- for n in block.nodes():
- _get_differentiable_graph_node(n, diff_node)
def _graph_for(self, *args, **kwargs):
    """
    Return the graph produced by ``_script_method_graph_for`` when ``self``
    serves both as the callable to run and as the parent whose debug state
    is inspected. Positional and keyword arguments are forwarded unchanged.
    """
    parent = self
    return _script_method_graph_for(self, parent, *args, **kwargs)
- def _script_method_graph_for(self, parent, *args, **kwargs):
- try:
- dbs = parent.get_debug_state()
- eps = list(dbs.execution_plans.values())
- assert(len(eps) == 1)
- graph = eps[0].graph.copy()
- # graph_executor_states for differentiable node
- fw_states = eps[0].code.differentiable_op_executor_states()
- diff_nodes: List[torch._C.Node] = []
- for n in graph.nodes():
- _get_differentiable_graph_node(n, diff_nodes)
- assert(len(fw_states) == len(diff_nodes))
- # swap each differentiable graph with optimized graph in their execution plan
- for n, state in zip(diff_nodes, fw_states):
- fw_execution_plans = list(state.execution_plans.values())
- # we can only update the subgraph when there's a unique execution
- # plan. Avoid assert here so we would skip the ones that can't be
- # updated while try the best effort to update other nodes.
- if len(fw_execution_plans) == 1:
- n.g_('Subgraph', fw_execution_plans[0].graph)
- return graph
- except Exception:
- # fallback approach, we just ran the graph and return the recorded optimized
- # graph
- self(*args, **kwargs)
- return last_executed_optimized_graph()
def set_fusion_strategy(strategy: List[Tuple[str, int]]):
    """
    Configure the kind and number of specializations fusion may perform.

    Usage: pass a list of ``(type, depth)`` pairs, where ``type`` is either
    "STATIC" or "DYNAMIC" and ``depth`` is an integer.

    Static vs. dynamic:
        With STATIC fusion, fused ops are compiled for fixed input shapes,
        which are determined from some initial profiling runs.
        With DYNAMIC fusion, fused ops are compiled for variable input
        shapes, so that multiple shapes are possible.

        In both cases a recompile also happens on new striding behavior,
        device, or dtype.

    Fallback functions and depth:
        When an input doesn't match the format required by the specialized
        compiled op, a fallback function runs instead. Fallback functions are
        themselves recursively compiled and specialized based on the observed
        tensor shapes. Because compilation can be slow, ``depth`` limits how
        many specializations may be compiled before giving up on recompiling
        and falling back to a completely un-fused, un-specialized
        implementation.

    The list of ``(type, depth)`` pairs controls both the type and the number
    of specializations. For example, ``[("STATIC", 2), ("DYNAMIC", 2)]`` means
    the first two specializations use static fusion, the next two use dynamic
    fusion, and any input that satisfies none of those 4 options runs an
    unfused implementation.

    NB: in the future, as more fusion backends are added, there may be more
    granular APIs for specific fusers.

    Returns:
        Whatever ``torch._C._jit_set_fusion_strategy`` returns for
        ``strategy``.
    """
    result = torch._C._jit_set_fusion_strategy(strategy)
    return result
|