# torch/_dynamo/__init__.py — public entry points for TorchDynamo
# (optimize/export/reset and graph-inclusion controls).
  1. from . import allowed_functions, convert_frame, eval_frame, resume_execution
  2. from .backends.registry import list_backends, register_backend
  3. from .convert_frame import replay
  4. from .eval_frame import (
  5. assume_constant_result,
  6. disable,
  7. explain,
  8. export,
  9. optimize,
  10. optimize_assert,
  11. OptimizedModule,
  12. reset_code,
  13. run,
  14. skip,
  15. )
  16. from .external_utils import is_compiling
  17. from .utils import compilation_metrics, guard_failures, orig_code_map, reset_frame_count
  18. __all__ = [
  19. "allow_in_graph",
  20. "assume_constant_result",
  21. "disallow_in_graph",
  22. "graph_break",
  23. "optimize",
  24. "optimize_assert",
  25. "export",
  26. "explain",
  27. "run",
  28. "replay",
  29. "disable",
  30. "reset",
  31. "skip",
  32. "OptimizedModule",
  33. "is_compiling",
  34. "register_backend",
  35. "list_backends",
  36. ]
  37. def reset():
  38. """Clear all compile caches and restore initial state"""
  39. for weak_code in convert_frame.input_codes.seen + convert_frame.output_codes.seen:
  40. code = weak_code()
  41. if code:
  42. reset_code(code)
  43. convert_frame.input_codes.clear()
  44. convert_frame.output_codes.clear()
  45. orig_code_map.clear()
  46. guard_failures.clear()
  47. resume_execution.ContinueExecutionCache.cache.clear()
  48. eval_frame.most_recent_backend = None
  49. compilation_metrics.clear()
  50. reset_frame_count()
  51. def allow_in_graph(fn):
  52. """
  53. Customize which functions TorchDynamo will include in the generated
  54. graph. Similar to `torch.fx.wrap()`.
  55. ::
  56. torch._dynamo.allow_in_graph(my_custom_function)
  57. @torch._dynamo.optimize(...)
  58. def fn(a):
  59. x = torch.add(x, 1)
  60. x = my_custom_function(x)
  61. x = torch.add(x, 1)
  62. return x
  63. fn(...)
  64. Will capture a single graph containing `my_custom_function()`.
  65. """
  66. if isinstance(fn, (list, tuple)):
  67. return [allow_in_graph(x) for x in fn]
  68. assert callable(fn), "allow_in_graph expects a callable"
  69. allowed_functions._allowed_function_ids.add(id(fn))
  70. allowed_functions._disallowed_function_ids.remove(id(fn))
  71. return fn
  72. def disallow_in_graph(fn):
  73. """
  74. Customize which functions TorchDynamo will exclude in the generated
  75. graph and force a graph break on.
  76. ::
  77. torch._dynamo.disallow_in_graph(torch.sub)
  78. @torch._dynamo.optimize(...)
  79. def fn(a):
  80. x = torch.add(x, 1)
  81. x = torch.sub(x, 1)
  82. x = torch.add(x, 1)
  83. return x
  84. fn(...)
  85. Will break the graph on `torch.sub`, and give two graphs each with a
  86. single `torch.add()` op.
  87. """
  88. if isinstance(fn, (list, tuple)):
  89. return [disallow_in_graph(x) for x in fn]
  90. assert callable(fn), "disallow_in_graph expects a callable"
  91. allowed_functions._allowed_function_ids.remove(id(fn))
  92. allowed_functions._disallowed_function_ids.add(id(fn))
  93. return fn
  94. @disallow_in_graph
  95. def graph_break():
  96. """Force a graph break"""
  97. pass