codecache.py

import base64
import dataclasses
import functools
import getpass
import hashlib
import json
import logging
import multiprocessing
import os
import re
import shutil
import signal
import subprocess
import sys
import sysconfig
import tempfile
import types
from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor
from ctypes import cdll
from threading import Thread
from time import sleep, time
from typing import Any, Callable, Dict, List

import torch
from torch.hub import _Faketqdm, tqdm
from torch.utils import cpp_extension

from . import config, cuda_properties, exc
from .utils import developer_warning

LOCK_TIMEOUT = 600

# timing metrics for time spent in the compilation
_cumulative_compile_time = 0
_t0 = None


def _compile_start():
    global _t0
    if _t0 is None:
        _t0 = time()


def _compile_end():
    global _cumulative_compile_time, _t0
    if _t0 is not None:
        t1 = time()
        _cumulative_compile_time += t1 - _t0
        _t0 = None
        # print("CUMULATIVE COMPILE TIME", _cumulative_compile_time)


log = logging.getLogger(__name__)
logging.getLogger("filelock").setLevel(logging.DEBUG if config.debug else logging.INFO)


@functools.lru_cache(None)
def cache_dir():
    return os.environ.get(
        "TORCHINDUCTOR_CACHE_DIR",
        f"{tempfile.gettempdir()}/torchinductor_{getpass.getuser()}",
    )


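# Illustrative sketch (not part of the original module): cache_dir() honors the
# TORCHINDUCTOR_CACHE_DIR environment variable and otherwise falls back to a
# per-user directory under the system temp dir. The helper below and its path
# are hypothetical and never called here.
def _example_cache_dir_override():
    os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/tmp/my_inductor_cache"  # hypothetical path
    cache_dir.cache_clear()  # the lru_cache must be cleared for an override to take effect
    return cache_dir()  # -> "/tmp/my_inductor_cache"

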
class DiskCache:
    @staticmethod
    @functools.lru_cache(None)
    def _subdir():
        subdir = os.path.join(cache_dir(), "cached_tunings")
        os.makedirs(subdir, exist_ok=True)
        return subdir

    @staticmethod
    @functools.lru_cache(4096)
    def _read_file(path):
        with open(path, "r") as fd:
            return json.loads(fd.read())

    def __init__(self, unique_name):
        super().__init__()
        self.unique_name = unique_name

    def lookup(self, key: Any, generate: Callable[[], Any]):
        """
        Check whether we have already generated the value for `key`; if not,
        call `generate()` to populate the cache.
        """
        path = os.path.join(self._subdir(), code_hash(self.unique_name + repr(key)))
        if not os.path.exists(path):
            value = generate()
            write_atomic(path, json.dumps(value))
        return self._read_file(path)


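# Illustrative sketch (not part of the original module): DiskCache.lookup() calls
# generate() only on a cache miss, writes the JSON result atomically, and serves
# later lookups with the same key from disk. The name and values are hypothetical.
def _example_disk_cache_usage():
    cache = DiskCache("example_autotuner")
    return cache.lookup(("matmul", 1024, 1024), generate=lambda: {"best_config": 3})

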
def get_lock_dir():
    lock_dir = os.path.join(cache_dir(), "locks")
    if not os.path.exists(lock_dir):
        os.makedirs(lock_dir, exist_ok=True)
    return lock_dir


def code_hash(code):
    return (
        "c"
        + base64.b32encode(hashlib.sha256(code.encode("utf-8")).digest())[:51]
        .decode("utf-8")
        .lower()
    )


def get_code_path(source_code, ext, extra):
    basename = code_hash(source_code + extra)
    subdir = os.path.join(cache_dir(), basename[1:3])
    path = os.path.join(subdir, f"{basename}.{ext}")
    return basename, subdir, path


def write(source_code, ext, extra=""):
    basename, subdir, path = get_code_path(source_code, ext, extra)
    if not os.path.exists(subdir):
        os.makedirs(subdir, exist_ok=True)
    if not os.path.exists(path):
        write_atomic(path, source_code)
    return basename, path


def write_atomic(path: str, source_code: str):
    # use a temp file for thread safety
    fd, tmp_path = tempfile.mkstemp(dir=os.path.dirname(path))
    with os.fdopen(fd, "w") as f:
        f.write(source_code)
    os.rename(tmp_path, path)


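# Illustrative sketch (not part of the original module): write() is content-addressed,
# so identical source text always maps to the same hash-named file, and write_atomic()
# keeps concurrent writers safe because os.rename() over a temp file in the same
# directory is atomic.
def _example_content_addressed_write():
    basename, path = write("int main() { return 0; }", "cpp")
    basename_again, path_again = write("int main() { return 0; }", "cpp")
    assert (basename, path) == (basename_again, path_again)  # same hash, same file
    return path

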
def cpp_compiler():
    if isinstance(config.cpp.cxx, (list, tuple)):
        search = tuple(config.cpp.cxx)
    else:
        search = (config.cpp.cxx,)
    return cpp_compiler_search(search)


@functools.lru_cache(1)
def cpp_compiler_search(search):
    for cxx in search:
        try:
            if cxx is None:
                # gxx package is only available for Linux
                # according to https://anaconda.org/conda-forge/gxx/
                if sys.platform != "linux":
                    continue
                # Do not install GXX by default
                if not os.getenv("TORCH_INDUCTOR_INSTALL_GXX"):
                    continue
                from filelock import FileLock

                lock_dir = get_lock_dir()
                lock = FileLock(
                    os.path.join(lock_dir, "g++.lock"), timeout=LOCK_TIMEOUT
                )
                with lock:
                    cxx = install_gcc_via_conda()
            subprocess.check_output([cxx, "--version"])
            return cxx
        except (subprocess.SubprocessError, FileNotFoundError, ImportError):
            continue
    raise exc.InvalidCxxCompiler()


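# Illustrative sketch (not part of the original module): the search tries each
# candidate compiler in order and returns the first one whose `--version` call
# succeeds; a None entry means "install g++ via conda", which only happens on
# Linux and only when TORCH_INDUCTOR_INSTALL_GXX is set. The candidate names
# below are hypothetical.
def _example_compiler_search():
    return cpp_compiler_search(("g++-12", "g++", None))

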
def install_gcc_via_conda():
    """On older systems, this is a quick way to get a modern compiler"""
    prefix = os.path.join(cache_dir(), "gcc")
    cxx_path = os.path.join(prefix, "bin", "g++")
    if not os.path.exists(cxx_path):
        log.info("Downloading GCC via conda")
        conda = os.environ.get("CONDA_EXE", "conda")
        if conda is None:
            conda = shutil.which("conda")
        if conda is not None:
            subprocess.check_call(
                [
                    conda,
                    "create",
                    f"--prefix={prefix}",
                    "--channel=conda-forge",
                    "--quiet",
                    "-y",
                    "python=3.8",
                    "gxx",
                ],
                stdout=subprocess.PIPE,
            )
    return cxx_path


def is_gcc():
    return re.search(r"(gcc|g\+\+)", cpp_compiler())


class VecISA:
    _bit_width: int
    _macro: str
    _arch_flags: str
    _dtype_nelements: Dict[torch.dtype, int]

    # TorchInductor CPU vectorization reuses PyTorch's vectorization utility functions,
    # so it depends on Sleef* to accelerate mathematical functions such as exp, pow,
    # sin, and cos. However, PyTorch and TorchInductor may build code with different
    # compilers. If PyTorch was built with gcc-7/g++-7, libtorch_cpu.so does not expose
    # the Sleef* AVX512 symbols, because gcc-7/g++-7 cannot pass the AVX512 check in
    # CMake (FindAVX.cmake), while TorchInductor installs the latest gcc/g++ by default,
    # which does support AVX512 compilation. The Sleef version available at runtime can
    # therefore conflict with what TorchInductor expects. Hence, we dry-compile the code
    # below to check whether the current hardware platform and the installed PyTorch
    # both support AVX512 or AVX2. ARM presumably needs the same logic.
    _avx_code = """
#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2)
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#endif

__attribute__((aligned(64))) float in_out_ptr0[16] = {0.0};

extern "C" void __avx_chk_kernel() {
    auto tmp0 = at::vec::Vectorized<float>(1);
    auto tmp1 = tmp0.exp();
    tmp1.store(in_out_ptr0);
}
"""

    _avx_py_load = """
import torch
from ctypes import cdll
cdll.LoadLibrary("__lib_path__")
"""

    def bit_width(self):
        return self._bit_width

    def nelements(self, dtype: torch.dtype = torch.float):
        return self._dtype_nelements[dtype]

    def build_macro(self):
        return self._macro

    def build_arch_flags(self):
        return self._arch_flags

    def __hash__(self) -> int:
        return hash(str(self))

    @functools.lru_cache(None)
    def __bool__(self):
        key, input_path = write(VecISA._avx_code, "cpp", extra="")
        from filelock import FileLock

        lock_dir = get_lock_dir()
        lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)
        with lock:
            output_path = input_path[:-3] + "so"
            build_cmd = cpp_compile_command(
                input_path, output_path, warning_all=False, vec_isa=self
            ).split(" ")
            try:
                # Check build result
                subprocess.check_output(build_cmd, stderr=subprocess.STDOUT)
                subprocess.check_call(
                    [
                        "python",
                        "-c",
                        VecISA._avx_py_load.replace("__lib_path__", output_path),
                    ],
                    stderr=subprocess.DEVNULL,
                )
            except Exception:
                return False

            return True


@dataclasses.dataclass
class VecAVX512(VecISA):
    _bit_width = 512
    _macro = "CPU_CAPABILITY_AVX512"
    _arch_flags = "-mavx512f -mavx512dq -mavx512vl -mavx512bw -mfma"
    _dtype_nelements = {torch.float: 16, torch.bfloat16: 32}

    def __str__(self) -> str:
        return "avx512"

    __hash__: Callable[[VecISA], Any] = VecISA.__hash__


@dataclasses.dataclass
class VecAVX2(VecISA):
    _bit_width = 256
    _macro = "CPU_CAPABILITY_AVX2"
    _arch_flags = "-mavx2 -mfma"
    _dtype_nelements = {torch.float: 8, torch.bfloat16: 16}

    def __str__(self) -> str:
        return "avx2"

    __hash__: Callable[[VecISA], Any] = VecISA.__hash__


class InvalidVecISA(VecISA):
    _bit_width = 0
    _macro = ""
    _arch_flags = ""
    _dtype_nelements = {}

    def __str__(self) -> str:
        return "INVALID_VEC_ISA"

    def __bool__(self):
        return False

    __hash__: Callable[[VecISA], Any] = VecISA.__hash__


invalid_vec_isa = InvalidVecISA()
supported_vec_isa_list = [VecAVX512(), VecAVX2()]


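# Illustrative sketch (not part of the original module): truth-testing a VecISA runs
# the dry-compile check described above -- it compiles _avx_code with the ISA's arch
# flags and loads the resulting shared library in a subprocess, so bool(isa) is False
# whenever the toolchain or libtorch cannot actually support that ISA.
def _example_isa_check():
    isa = VecAVX2()
    return isa.bit_width() if isa else invalid_vec_isa.bit_width()

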
# Cache the cpuinfo to avoid I/O overhead. The raw cpuinfo contains a lot of
# content that is irrelevant to the ISA check, so we only cache the key ISA
# information.
@functools.lru_cache(None)
def valid_vec_isa_list():
    if sys.platform != "linux":
        return []

    isa_list = []
    with open("/proc/cpuinfo") as _cpu_info:
        _cpu_info_content = _cpu_info.read()
        for isa in supported_vec_isa_list:
            if str(isa) in _cpu_info_content and isa:
                isa_list.append(isa)
    return isa_list


def pick_vec_isa():
    _valid_vec_isa_list: List[VecISA] = valid_vec_isa_list()
    if not _valid_vec_isa_list:
        return invalid_vec_isa

    # If simdlen is None, determine the vectorization length automatically
    if config.cpp.simdlen is None:
        assert _valid_vec_isa_list
        return _valid_vec_isa_list[0]

    for isa in _valid_vec_isa_list:
        if config.cpp.simdlen == isa.bit_width():
            return isa

    return invalid_vec_isa


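# Illustrative sketch (not part of the original module): pick_vec_isa() returns the
# widest usable ISA when config.cpp.simdlen is None, and otherwise only an ISA whose
# bit width matches simdlen exactly, falling back to invalid_vec_isa when nothing
# matches. The override below is hypothetical.
def _example_pick_vec_isa():
    config.cpp.simdlen = 256  # None would mean "pick automatically"
    return pick_vec_isa()  # VecAVX2() on an AVX2-capable Linux machine, else invalid_vec_isa

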
def get_shared(shared=True):
    return "-shared -fPIC" if shared else ""


def get_warning_all_flag(warning_all=True):
    return "-Wall" if warning_all else ""


def cpp_flags():
    return "-std=c++17 -Wno-unused-variable"


def optimization_flags():
    base_flags = "-O3 -ffast-math -fno-finite-math-only"
    if sys.platform == "darwin":
        # Per https://mac.r-project.org/openmp/, the right way to pass OpenMP flags
        # to clang on macOS is via `-Xclang`.
        # Also, `-march=native` is an unrecognized option on M1.
        base_flags += " -Xclang -fopenmp"
    else:
        base_flags += " -march=native -fopenmp"
    return base_flags


def use_custom_generated_macros():
    return "-D C10_USING_CUSTOM_GENERATED_MACROS"


def get_include_and_linking_paths(
    include_pytorch=False, vec_isa: VecISA = invalid_vec_isa
):
    if sys.platform == "linux" and (
        include_pytorch
        or vec_isa != invalid_vec_isa
        or config.cpp.enable_kernel_profile
    ):
        # Note - we include PyTorch only on Linux right now. There is more work
        # to do to enable the OMP build on darwin, where PyTorch is built with IOMP
        # and we need a way to link to what PyTorch links.
        ipaths = cpp_extension.include_paths() + [sysconfig.get_path("include")]
        lpaths = cpp_extension.library_paths() + [sysconfig.get_config_var("LIBDIR")]
        libs = ["c10", "torch", "torch_cpu", "torch_python", "gomp"]
        macros = vec_isa.build_macro()
        if macros:
            macros = f"-D{macros}"
    else:
        # Note - this is effectively a header-only inclusion. Using some header files
        # may result in "symbol not found" errors if those headers require a library.
        # For those cases, include the lpaths and libs as we do for PyTorch above.
        # This approach lets us only pay for what we use.
        ipaths = cpp_extension.include_paths() + [sysconfig.get_path("include")]
        lpaths = []
        macros = ""
        if sys.platform == "darwin":
            # GNU OpenMP is generally not available on macOS; there is either Intel
            # OpenMP (for x86) or LLVM OpenMP (for both x86 and arm64).
            libs = ["omp"]
            if os.getenv("CONDA_PREFIX") is not None:
                # OpenMP is not available via the macOS system install, but conda can
                # provide it via https://anaconda.org/anaconda/llvm-openmp
                conda_lib_path = os.path.join(os.getenv("CONDA_PREFIX"), "lib")
                ipaths.append(os.path.join(os.getenv("CONDA_PREFIX"), "include"))
                lpaths.append(conda_lib_path)
                # Prefer Intel OpenMP on x86 machines
                if os.uname().machine == "x86_64" and os.path.exists(
                    os.path.join(conda_lib_path, "libiomp5.dylib")
                ):
                    libs = ["iomp5"]
        else:
            libs = ["gomp"]
    ipaths = " ".join(["-I" + p for p in ipaths])
    lpaths = " ".join(["-L" + p for p in lpaths])
    libs = " ".join(["-l" + p for p in libs])
    return ipaths, lpaths, libs, macros


def cpp_compile_command(
    input,
    output,
    warning_all=True,
    shared=True,
    include_pytorch=False,
    vec_isa: VecISA = invalid_vec_isa,
):
    ipaths, lpaths, libs, macros = get_include_and_linking_paths(
        include_pytorch, vec_isa
    )

    return re.sub(
        r"[ \n]+",
        " ",
        f"""
            {cpp_compiler()} {input} {get_shared(shared)} {get_warning_all_flag(warning_all)} {cpp_flags()}
            {ipaths} {lpaths} {libs} {macros}
            {optimization_flags()}
            {use_custom_generated_macros()}
            -o{output}
        """,
    ).strip()


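# Illustrative sketch (not part of the original module): cpp_compile_command() only
# assembles a single whitespace-normalized compiler invocation string; callers split
# it and hand it to subprocess. The paths below are hypothetical placeholders.
def _example_compile_command():
    return cpp_compile_command("/tmp/example.cpp", "/tmp/example.so", vec_isa=pick_vec_isa())

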
class CppCodeCache:
    cache = dict()
    clear = staticmethod(cache.clear)

    @staticmethod
    def _load_library(path):
        try:
            return cdll.LoadLibrary(path)
        except OSError as e:
            if "gomp" in str(e) and os.path.exists("/usr/lib64/libgomp.so.1"):
                # hacky workaround for fbcode/buck
                global _libgomp
                _libgomp = cdll.LoadLibrary("/usr/lib64/libgomp.so.1")
                return cdll.LoadLibrary(path)
            if "failed to map segment from shared object" in str(e):
                raise OSError(
                    f"{e}. The most common reason this may occur is if the {tempfile.gettempdir()} folder "
                    "is mounted with noexec (e.g., by default Docker mounts tmp file systems "
                    f"as noexec). Please remount {tempfile.gettempdir()} with exec enabled, or set another "
                    "temporary directory with the TORCHINDUCTOR_CACHE_DIR environment variable."
                ) from e
            raise

    @classmethod
    def load(cls, source_code):
        picked_vec_isa = pick_vec_isa()
        key, input_path = write(
            source_code,
            "cpp",
            extra=cpp_compile_command("i", "o", vec_isa=picked_vec_isa),
        )
        if key not in cls.cache:
            from filelock import FileLock

            lock_dir = get_lock_dir()
            lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)
            with lock:
                output_path = input_path[:-3] + "so"
                if not os.path.exists(output_path):
                    cmd = cpp_compile_command(
                        input=input_path, output=output_path, vec_isa=picked_vec_isa
                    ).split(" ")
                    try:
                        subprocess.check_output(cmd, stderr=subprocess.STDOUT)
                    except subprocess.CalledProcessError as e:
                        raise exc.CppCompileError(cmd, e.output) from e
                cls.cache[key] = cls._load_library(output_path)
                cls.cache[key].key = key

        return cls.cache[key]


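# Illustrative sketch (not part of the original module): CppCodeCache.load() compiles
# a C++ source string into a shared library (once per content hash, guarded by a file
# lock) and returns the ctypes CDLL handle, so exported symbols are available as
# attributes. The kernel below is hypothetical.
def _example_cpp_code_cache():
    lib = CppCodeCache.load('extern "C" void kernel() {}')
    return lib.kernel

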
class PyCodeCache:
    cache = dict()
    clear = staticmethod(cache.clear)

    @classmethod
    def load(cls, source_code):
        key, path = write(source_code, "py")
        if key not in cls.cache:
            with open(path) as f:
                code = compile(f.read(), path, "exec")
                mod = types.ModuleType(f"{__name__}.{key}")
                mod.__file__ = path
                mod.key = key
                exec(code, mod.__dict__, mod.__dict__)
                # another thread might set this first
                cls.cache.setdefault(key, mod)
        return cls.cache[key]


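# Illustrative sketch (not part of the original module): PyCodeCache.load() writes the
# source to the on-disk cache, exec()s it as a fresh module keyed by its content hash,
# and memoizes the module object in-process. The source below is hypothetical.
def _example_py_code_cache():
    mod = PyCodeCache.load("def answer():\n    return 42\n")
    return mod.answer()

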
class TritonCodeCache:
    @staticmethod
    def get_name(mod):
        (name,) = [n for n in dir(mod) if n.startswith("triton_")]
        return name

    @classmethod
    def load(cls, source_code):
        mod = PyCodeCache.load(source_code)
        return getattr(mod, cls.get_name(mod))


def _worker_compile(source_code, cc, device):
    cuda_properties.set_compiler_worker_current_device(device)
    kernel = TritonCodeCache.load(source_code)
    kernel.precompile(warm_cache_only_with_cc=cc)


def _load_kernel(source_code):
    kernel = TritonCodeCache.load(source_code)
    kernel.precompile()
    return kernel


def _load_kernel_name(source_code):
    return TritonCodeCache.get_name(PyCodeCache.load(source_code))


class TritonFuture:
    def __init__(self, source_code, future):
        self.source_code = source_code
        self.future = future

    # @dynamo_utils.dynamo_timed
    def result(self):
        t0 = time()
        if hasattr(self, "kernel"):
            return self.kernel
        # If the worker failed this will throw an exception.
        self.future.result()
        kernel = self.kernel = _load_kernel(self.source_code)
        latency = time() - t0
        if latency > 50:
            name = _load_kernel_name(self.source_code)
            developer_warning(
                f"Detected long compilation time of {latency} seconds for kernel name {name}"
            )
            developer_warning(self.source_code)
        del self.source_code, self.future
        return kernel


class AsyncCompile:
    def __init__(self):
        pass

    @staticmethod
    @functools.lru_cache(1)
    def pool():
        assert config.compile_threads > 1
        return ThreadPoolExecutor(config.compile_threads)

    @staticmethod
    @functools.lru_cache(1)
    def process_pool():
        # ensure properties have been calculated before processes
        # are forked
        cuda_properties._properties()
        assert config.compile_threads > 1
        orig_ppid = os.getpid()

        # if this process dies abnormally (e.g. segfault)
        # it will not shut down the workers. Instead
        # the workers will have their parent reassigned to the
        # init process. This launches a separate thread to
        # watch for the worker getting reassigned,
        # and cleans it up in this case.
        def init():
            def run():
                while True:
                    sleep(1)
                    if orig_ppid != os.getppid():
                        os.kill(os.getpid(), signal.SIGKILL)

            global _watchdog_thread
            _watchdog_thread = Thread(target=run, daemon=True)
            _watchdog_thread.start()

        # we rely on 'fork' because we cannot control whether users
        # have an `if __name__ == '__main__'` in their main process.
        fork_context = multiprocessing.get_context("fork")
        pool = ProcessPoolExecutor(
            config.compile_threads, mp_context=fork_context, initializer=init
        )
        # when this pool is created in a subprocess object, the normal exit handler
        # doesn't run, and we need to register our own handler.
        # exitpriority has to be high, because another one of the finalizers will
        # kill the worker thread that sends the shutdown message to the workers...
        multiprocessing.util.Finalize(None, pool.shutdown, exitpriority=sys.maxsize)
        return pool

    @classmethod
    def warm_pool(cls):
        if config.compile_threads <= 1:
            return
        _compile_start()
        pool = cls.process_pool()

        # We have to fork processes for compiler workers, but the more memory and other resources that are
        # loaded, the slower the os.fork time is, quite drastically. It also holds the GIL, so we can't put
        # it on another thread.
        # Examples:
        #   A simple x + x + x script: 10ms in the middle of the program, 2ms at startup
        #   tf_efficientnet_b0 benchmark: 50ms (!) in the middle of the program, 3ms at startup
        # So we want to start the workers early, while it is still cheap, and also to allow the workers to
        # get ready before we have work for them.
        # ProcessPoolExecutor also does not launch the workers until it finds a point when all the workers
        # are idle. But if we waited until then, the fork time would be long and we would be waiting for the
        # processes to initialize.
        # We force them to start here with some YOLOing of the internal methods.
        if hasattr(pool, "_start_queue_management_thread"):
            pool._start_queue_management_thread()
        else:
            for _ in range(config.compile_threads):
                pool._adjust_process_count()
            pool._start_executor_manager_thread()
        _compile_end()

    @classmethod
    def submit(cls, task):
        if config.compile_threads <= 1:
            return task()
        return cls.pool().submit(task)

    @classmethod
    def map(cls, fn, seq):
        if config.compile_threads <= 1 or len(seq) <= 1:
            return list(map(fn, seq))
        return [t.result() for t in [cls.pool().submit(fn, x) for x in seq]]

    def triton(self, source_code):
        _compile_start()

        if config.compile_threads > 1:
            major, minor = torch.cuda.get_device_capability()
            device = torch.cuda.current_device()
            cc = major * 10 + minor
            future = self.process_pool().submit(
                _worker_compile, source_code, cc, device
            )
            return TritonFuture(source_code, future)
        else:
            return _load_kernel(source_code)

    def cpp(self, source_code):
        def task():
            return CppCodeCache.load(source_code).kernel

        return self.submit(task)

    def wait(self, scope: Dict[str, Any]):
        num_kernels = len(
            [
                value
                for key, value in scope.items()
                if isinstance(value, (Future, TritonFuture))
            ]
        )
        pbar = tqdm(
            total=num_kernels,
            desc="Inductor Compilation",
            disable=config.disable_progress,
            delay=0,
        )
        if config.compile_threads > 1:
            for key, result in scope.items():
                if config.verbose_progress and not isinstance(pbar, _Faketqdm):
                    pbar.set_postfix_str(key)
                if isinstance(result, (Future, TritonFuture)):
                    scope[key] = result.result()
                    pbar.update(1)

        _compile_end()


AsyncCompile.warm_pool()


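# Illustrative sketch (not part of the original module): typical usage of AsyncCompile
# by generated code -- submit kernels, keep the returned futures in a scope dict, then
# wait() to resolve them in place before running the compiled kernels. The source
# string below is hypothetical.
def _example_async_compile():
    async_compile = AsyncCompile()
    scope = {"cpp_kernel": async_compile.cpp('extern "C" void kernel() {}')}
    async_compile.wait(scope)  # replaces any futures in `scope` with loaded kernels
    return scope["cpp_kernel"]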