# config.py

import os
import sys
import tempfile
from os.path import abspath, dirname

import torch

from . import external_utils
from .logging import get_loggers_level, set_loggers_level

# log level (each level prints its own messages + those of all levels listed below it)
# logging.DEBUG: print full traces <-- lowest level; also prints the tracing of every instruction
# logging.INFO: print the steps that dynamo is running and, optionally, compiled functions + graphs
# logging.WARN: print warnings (including graph breaks)
# logging.ERROR: print exceptions (and what user code was being processed when they occurred)
log_level = property(
    lambda _: get_loggers_level(), lambda _, lvl: set_loggers_level(lvl)
)
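# Usage sketch (assuming this module is exposed as torch._dynamo.config, which
# the relative imports above suggest; adjust the path to your install):
#   import logging
#   import torch._dynamo
#   torch._dynamo.config.log_level = logging.INFO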
# log compiled function + graphs at level INFO
output_code = False
# the name of a file to write the logs to
log_file_name = None
# Verbose will print full stack traces on warnings and errors
verbose = False
# If true, traced graph outputs will be output as Python GraphModule code.
# If false, traced graph outputs will be output in tabular form.
output_graph_code = False
# verify the correctness of the optimized backend
verify_correctness = False
# need this many ops to create an FX graph
minimum_call_count = 1
# turn on/off DCE pass
dead_code_elimination = True
# disable (for a function) when the cache reaches this size
cache_size_limit = 64
# specialize int/float by default
specialize_int_float = True
# Assume these functions return constants
constant_functions = {
    torch.jit.is_scripting: False,
    torch.jit.is_tracing: False,
    torch._C._get_tracing_state: None,
    torch.fx._symbolic_trace.is_fx_tracing: False,
    torch.onnx.is_in_onnx_export: False,
    external_utils.is_compiling: True,
    torch._utils.is_compiling: True,
}
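# Hedged illustration of what the table above enables (user code, not part of
# this module): during tracing, a call such as torch.jit.is_scripting() can be
# treated as the constant False instead of being traced, so code like
#   def f(x):
#       if torch.jit.is_scripting():  # assumed constant False while compiling
#           return x * 2
#       return x + 1
# simply follows the non-scripting branch.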
# don't specialize on shapes and strides and put shape ops in graph
dynamic_shapes = os.environ.get("TORCHDYNAMO_DYNAMIC_SHAPES") == "1"
# Set this to False to assume nn.Modules() contents are immutable (similar assumption as freezing)
guard_nn_modules = False
# This feature doesn't really work. We offer this flag for experimental
# purposes / if you want to help us build out support.
#
# torchdynamo has very limited support for tensor subclasses that implement
# __torch_function__. Our current support is limited to tensor subclasses
# that DO NOT store metadata on the tensor (in general, dynamo does not
# support Python code that stores extra attributes on tensors at present).
# If your tensor subclass purely changes function call behavior via
# __torch_function__, you can allow torchdynamo to trace into it by
# adding it to traceable_tensor_subclasses. We don't do any safety checks,
# so it is up to you to ensure that your subclass is well behaved. See also
# https://github.com/pytorch/torchdynamo/issues/1948
#
# We do NOT currently support __torch_dispatch__. The implementation is
# currently buggy; the main showstopper for nontrivial use is
# https://github.com/pytorch/torchdynamo/issues/1952
traceable_tensor_subclasses = set()
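# Hypothetical example of opting in a __torch_function__-only subclass (the
# class name is illustrative; remember that no safety checks are done for you):
#   class ScaledTensor(torch.Tensor):
#       @classmethod
#       def __torch_function__(cls, func, types, args=(), kwargs=None):
#           # changes call behavior only; stores no extra metadata on the tensor
#           return super().__torch_function__(func, types, args, kwargs or {})
#
#   torch._dynamo.config.traceable_tensor_subclasses.add(ScaledTensor)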
# Suppress errors in torch._dynamo.optimize, instead forcing a fallback to eager.
# This is a good way to get your model to work one way or another, but you may
# lose optimization opportunities this way. Devs, if your benchmark model fails
# this way, you should figure out why instead of suppressing it.
suppress_errors = bool(os.environ.get("TORCHDYNAMO_SUPPRESS_ERRORS", False))
# Record and write an execution record of the current frame to a file
# if an exception is encountered
replay_record_enabled = bool(os.environ.get("TORCH_COMPILE_DEBUG", False))
# Rewrite assert statements in Python with torch._assert
rewrite_assert_with_torch_assert = True
# Show a warning on every graph break
print_graph_breaks = False
# Disable dynamo
disable = os.environ.get("TORCH_COMPILE_DISABLE", False)
# If a PyTorch module is in this allowlist, torchdynamo will be allowed
# to inline objects from it or its children.
skipfiles_inline_module_allowlist = {
    torch.nn,
    torch.distributions,
    torch.testing,
    torch.ao.nn,
    torch._refs,
    torch._prims,
    torch._decomp,
}
# If a string representing a PyTorch module is in this ignorelist,
# the `allowed_functions.is_allowed` function will not consider it
# when creating a list of PyTorch functions that will appear in
# FX IR.
allowed_functions_module_string_ignorelist = {
    "torch.distributions",
    "torch.testing",
    "torch._refs",
    "torch._prims",
    "torch._decomp",
}
# Debug flag to try the minifier at different stages. Possible values are {None, "aot", "dynamo"}
# None - minifier is switched off
# dynamo - runs the minifier on the TorchDynamo-produced graphs, if compilation fails
# aot - runs the minifier on the AOT Autograd-produced graphs, if compilation fails
repro_after = os.environ.get("TORCHDYNAMO_REPRO_AFTER", None)
# Compiler compilation debug info
# 1: Dumps the original graph out to repro.py if compilation fails
# 2: Dumps a minifier_launcher.py if compilation fails.
# 3: Always dumps a minifier_launcher.py. Good for segfaults.
# 4: Dumps a minifier_launcher.py if the accuracy check fails.
repro_level = int(os.environ.get("TORCHDYNAMO_REPRO_LEVEL", 2))
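# Usage sketch: both knobs are read from the environment when this module is
# imported, so set them before launching the program (shell syntax; the script
# name is a placeholder):
#   TORCHDYNAMO_REPRO_AFTER="aot" TORCHDYNAMO_REPRO_LEVEL=4 python my_script.py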
# By default, we try to detect accuracy failure by running both forward
# and backward of a torchdynamo produced graph (if you are using repro_after
# 'dynamo'). This setting forces us to only test the forward graph and
# not the backward graph. This can be helpful if you're trying to debug
# an inference only problem, but the minifier seems to be choking on the
# backwards step.
# TODO: Detect this situation automatically so the user doesn't need
# to manually configure this
repro_forward_only = os.environ.get("TORCHDYNAMO_REPRO_FORWARD_ONLY") == "1"
# The tolerance we should use when testing if a compiled graph
# has diverged so that we should treat it as an accuracy failure
repro_tolerance = 1e-3
# Not all backends support scalars. Some calls on torch.Tensor (like .item()) return a scalar type.
# When this flag is set to False, we introduce a graph break instead of capturing.
# This requires dynamic_shapes to be True.
capture_scalar_outputs = False
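# Hypothetical user code (not part of this module) showing the flag above:
#   @torch._dynamo.optimize("eager")
#   def f(x):
#       # with capture_scalar_outputs=False this .item() triggers a graph break;
#       # capturing it also requires dynamic_shapes=True
#       return x.sum().item()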
# Should almost always be true in prod. This relaxes the requirement that cond's true_fn and
# false_fn produce code with identical guards.
enforce_cond_guards_match = True
# Automatically split model graph into pieces to match DDP bucket sizes
# to allow DDP comm/compute overlap. Disable to allow DDP models to
# run without graph-breaks, but also without comm/compute overlap.
# Set torch._dynamo.config.log_level to INFO or DEBUG for more info
# about optimize_ddp behavior.
optimize_ddp = True
# If True, raises an exception if TorchDynamo is called with a context manager
raise_on_ctx_manager_usage = True
# If True, raise when aot autograd is unsafe to use
raise_on_unsafe_aot_autograd = False
# Throw an error if the backend changes without a reset
raise_on_backend_change = False
# If true, error with a better message if we symbolically trace over a
# dynamo-optimized function. If false, silently suppress dynamo.
error_on_nested_fx_trace = True
# Disables graph breaking on rnn. YMMV with backends.
allow_rnn = False
# root folder of the project
base_dir = dirname(dirname(dirname(abspath(__file__))))


def is_fbcode():
    return not hasattr(torch.version, "git_version")


if is_fbcode():
    debug_dir_root = os.path.join(tempfile.gettempdir(), "torch_compile_debug")
else:
    debug_dir_root = os.path.join(os.getcwd(), "torch_compile_debug")

# this is to resolve an import problem in fbcode; we will be deleting
# this very shortly
DO_NOT_USE_legacy_non_fake_example_inputs = False

_save_config_ignore = {
    "repro_after",
    "repro_level",
    # workaround: "cannot pickle PyCapsule"
    "constant_functions",
    # workaround: "cannot pickle module"
    "skipfiles_inline_module_allowlist",
}

from .config_utils import install_config_module

install_config_module(sys.modules[__name__])
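# Note (hedged sketch of intent, not a guaranteed API): install_config_module is
# expected to wrap this module so the globals above act as runtime-tunable
# settings, with save/restore helpers that skip the keys in _save_config_ignore.
# Typical use is plain attribute assignment, e.g.:
#   torch._dynamo.config.verbose = True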