import os
import sys

import torch

# add some debug printouts
debug = False

# Whether to disable a progress bar for autotuning
disable_progress = True

# Whether to enable printing the source code for each future
verbose_progress = False

# use cpp wrapper instead of python wrapper
cpp_wrapper = False

# dead code elimination
dce = False

# assume weight tensors are fixed size
static_weight_shapes = True

# put correctness assertions in generated code
size_asserts = True

# enable loop reordering based on input orders
pick_loop_orders = True

# generate inplace computations
inplace_buffers = True

# codegen benchmark harness
benchmark_harness = True

# fuse pointwise into templates
epilogue_fusion = False

# do epilogue fusions before other fusions
epilogue_fusion_first = False

# enable pattern match+replace optimizations
pattern_matcher = True

# enable reordering pass
reordering = False

# enable slow autotuning passes to select algorithms
max_autotune = os.environ.get("TORCHINDUCTOR_MAX_AUTOTUNE") == "1"
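
# Like the other TORCHINDUCTOR_* flags in this file, this can be flipped
# per-run from the shell, or set directly on the module before compiling.
# A minimal sketch (the script name is hypothetical):
#
#   TORCHINDUCTOR_MAX_AUTOTUNE=1 python train.py
#
#   # or, in Python:
#   import torch._inductor.config as config
#   config.max_autotune = True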

# control store vs recompute heuristic
# For fanouts, rematerialization can lead to exponential blowup, so use a
# smaller threshold
realize_reads_threshold = 4
realize_bytes_threshold = 2000

# Threshold to prevent excessive accumulation of ops in one buffer during lowering
realize_acc_reads_threshold = 8

# fallback to eager for random/dropout; this is slow but useful for debugging
fallback_random = False

# automatically create fallbacks when encountering an unhandled op
implicit_fallbacks = True

# benchmark to decide the best layout; currently only for aten.conv
tune_layout = False

# fuse even in cases without common reads
aggressive_fusion = False

# how many nodes to allow into a single fusion
max_fusion_size = 64

# replace small reductions with pointwise, disable with `= 1`
unroll_reductions_threshold = 8

# include originating FX node names as comments in the generated code
comment_origin = False


def is_fbcode():
    return not hasattr(torch.version, "git_version")


# warnings intended for PyTorch developers, disable for point releases
developer_warnings = is_fbcode() or "+" in torch.__version__

# number of threads to use for parallel compilation; 1 disables it
compile_threads = (
    1
    if sys.platform == "win32" or is_fbcode()
    else min(
        32,
        len(os.sched_getaffinity(0))
        if hasattr(os, "sched_getaffinity")
        else os.cpu_count(),
    )
)

# If a kernel is fused, its name is generated from the origin node op names;
# for larger kernels this limits the number of ops included
kernel_name_max_ops = 10

# Pad input tensors of matmul/bmm/addmm to leverage Tensor Cores in NVIDIA GPUs
shape_padding = os.environ.get("TORCHINDUCTOR_SHAPE_PADDING", "0") == "1"

# Fx-based linear/matmul/bmm + permute/transpose vertical fusion
permute_fusion = os.environ.get("TORCHINDUCTOR_PERMUTE_FUSION", "0") == "1"

# Mark the wrapper call in PyTorch profiler
profiler_mark_wrapper_call = False

# used for debugging to make sure config is properly set
_raise_error_for_testing = False


# config specific to codegen/cpp.py
class cpp:
    # set to torch.get_num_threads()
    threads = -1

    # Assume the number of threads is dynamic, don't specialize the thread count.
    # Kernels don't recompile on thread-count changes with this flag on.
    # For single-threaded workloads, turning it on would incur a slight
    # performance degradation.
    dynamic_threads = False

    # SIMD width to target; None picks one automatically
    simdlen = None

    # minimum amount of work per loop before it is parallelized
    min_chunk_size = 4096

    # compiler candidates to try, in order
    cxx = (
        None,  # download gcc12 from conda-forge if conda is installed
        # "g++-12",
        # "g++-11",
        # "g++-10",
        # "clang++",
        "g++",
        # "g++.par",
    )

    # Allow kernel performance profiling via PyTorch profiler
    enable_kernel_profile = False

    # enable weight prepacking for better performance; may lead to a large
    # memory footprint
    weight_prepack = True
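
# The cpp options above are plain class attributes and can be overridden in
# the same way as the module-level flags. A minimal sketch (the values shown
# are illustrative, not recommendations):
#
#   import torch._inductor.config as config
#   config.cpp.threads = 8  # instead of deriving from torch.get_num_threads()
#   config.cpp.enable_kernel_profile = True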


# config specific to codegen/triton.py
class triton:
    # Use cudagraphs on output code
    cudagraphs = False

    # Synchronize before and after every compiled graph.
    debug_sync_graph = False

    # Synchronize after every kernel launch, to help pinpoint bugs
    debug_sync_kernel = False

    # Always load full blocks (rather than broadcasting inside the block)
    dense_indexing = False

    # limit tiling dimensions
    max_tiles = 2

    # use triton.autotune for pointwise ops with complex layouts,
    # this should only be disabled for debugging/testing
    autotune_pointwise = True

    # should we stop a fusion to allow better tiling?
    tiling_prevents_pointwise_fusion = True
    tiling_prevents_reduction_fusion = True

    # should we give different names to kernels
    ordered_kernel_names = False

    # should we put op names in kernel names
    descriptive_kernel_names = False

    # use alternate codegen for smaller reductions
    persistent_reductions = False


# create a directory containing lots of debug information
class trace:
    # master switch for all debugging flags below
    enabled = os.environ.get("TORCH_COMPILE_DEBUG", "0") == "1"

    # Save python logger calls >= logging.DEBUG
    debug_log = True

    # Save python logger calls >= logging.INFO
    info_log = False

    # Save input FX graph (post decomps, pre optimization)
    fx_graph = True

    # Save FX graph after transformations
    fx_graph_transformed = True

    # Save TorchInductor IR before fusion pass
    ir_pre_fusion = True

    # Save TorchInductor IR after fusion pass
    ir_post_fusion = True

    # Copy generated code to trace dir
    output_code = True

    # SVG figure showing post-fusion graph
    graph_diagram = False

    # Store cProfile (see snakeviz to view)
    compile_profile = False

    # Upload the .tar.gz file
    # Needs to be overridden based on specific environment needs
    upload_tar = None
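
# When the master switch is on, the artifacts selected above (fx_graph,
# ir_pre_fusion/ir_post_fusion, output_code, ...) are written to a debug
# directory for the run. A minimal sketch (the script name is hypothetical):
#
#   TORCH_COMPILE_DEBUG=1 python train.py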

from .._dynamo.config_utils import install_config_module

# adds patch, save_config, etc
install_config_module(sys.modules[__name__])
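
# After install_config_module runs, this module behaves like a config object
# with patching and snapshotting helpers. A minimal sketch of the added API;
# the exact semantics live in torch._dynamo.config_utils:
#
#   import torch
#   import torch._inductor.config as config
#
#   # temporarily override a flag for the duration of a block
#   with config.patch("max_autotune", True):
#       compiled_model = torch.compile(model)  # `model` is a stand-in nn.Module
#
#   # snapshot the current settings so they can be restored later
#   saved = config.save_config()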