123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114 |
- from . import allowed_functions, convert_frame, eval_frame, resume_execution
- from .backends.registry import list_backends, register_backend
- from .convert_frame import replay
- from .eval_frame import (
- assume_constant_result,
- disable,
- explain,
- export,
- optimize,
- optimize_assert,
- OptimizedModule,
- reset_code,
- run,
- skip,
- )
- from .external_utils import is_compiling
- from .utils import compilation_metrics, guard_failures, orig_code_map, reset_frame_count
# Names re-exported as the public torch._dynamo API; anything not listed
# here should be treated as internal and subject to change.
__all__ = [
    "allow_in_graph",
    "assume_constant_result",
    "disallow_in_graph",
    "graph_break",
    "optimize",
    "optimize_assert",
    "export",
    "explain",
    "run",
    "replay",
    "disable",
    "reset",
    "skip",
    "OptimizedModule",
    "is_compiling",
    "register_backend",
    "list_backends",
]
def reset():
    """Clear all compile caches and restore initial state"""
    # Walk every code object TorchDynamo has seen (both inputs and the
    # compiled outputs) and strip its installed optimizations.  The entries
    # are weak references, so dereference and skip anything already collected.
    for code_registry in (convert_frame.input_codes, convert_frame.output_codes):
        for weak_code in code_registry.seen:
            code = weak_code()
            if code:
                reset_code(code)
    # Drop all bookkeeping so the next compile starts from scratch.
    convert_frame.input_codes.clear()
    convert_frame.output_codes.clear()
    orig_code_map.clear()
    guard_failures.clear()
    resume_execution.ContinueExecutionCache.cache.clear()
    eval_frame.most_recent_backend = None
    compilation_metrics.clear()
    reset_frame_count()
def allow_in_graph(fn):
    """
    Customize which functions TorchDynamo will include in the generated
    graph. Similar to `torch.fx.wrap()`.
    ::
        torch._dynamo.allow_in_graph(my_custom_function)

        @torch._dynamo.optimize(...)
        def fn(a):
            x = torch.add(a, 1)
            x = my_custom_function(x)
            x = torch.add(x, 1)
            return x

        fn(...)

    Will capture a single graph containing `my_custom_function()`.

    Accepts a single callable or a list/tuple of callables (applied to each).
    """
    # A collection of callables is handled element-wise via recursion.
    if isinstance(fn, (list, tuple)):
        return [allow_in_graph(item) for item in fn]
    assert callable(fn), "allow_in_graph expects a callable"
    fn_id = id(fn)
    # Move the function into the allowed set and make sure it is no longer
    # marked as disallowed.
    allowed_functions._allowed_function_ids.add(fn_id)
    allowed_functions._disallowed_function_ids.remove(fn_id)
    return fn
def disallow_in_graph(fn):
    """
    Customize which functions TorchDynamo will exclude in the generated
    graph and force a graph break on.
    ::
        torch._dynamo.disallow_in_graph(torch.sub)

        @torch._dynamo.optimize(...)
        def fn(a):
            x = torch.add(a, 1)
            x = torch.sub(x, 1)
            x = torch.add(x, 1)
            return x

        fn(...)

    Will break the graph on `torch.sub`, and give two graphs each with a
    single `torch.add()` op.

    Accepts a single callable or a list/tuple of callables (applied to each).
    """
    # A collection of callables is handled element-wise via recursion.
    if isinstance(fn, (list, tuple)):
        return [disallow_in_graph(item) for item in fn]
    assert callable(fn), "disallow_in_graph expects a callable"
    fn_id = id(fn)
    # Move the function into the disallowed set and make sure it is no
    # longer marked as allowed.
    allowed_functions._allowed_function_ids.remove(fn_id)
    allowed_functions._disallowed_function_ids.add(fn_id)
    return fn
@disallow_in_graph
def graph_break():
    """Force a graph break at this call site.

    The function body is intentionally empty; the `disallow_in_graph`
    decorator is what makes TorchDynamo split the graph here.
    """
|