# blas_compare_setup.py
  1. import collections
  2. import os
  3. import shutil
  4. import subprocess
  5. try:
  6. # no type stub for conda command line interface
  7. import conda.cli.python_api # type: ignore[import]
  8. from conda.cli.python_api import Commands as conda_commands
  9. except ImportError:
  10. # blas_compare.py will fail to import these when it's inside a conda env,
  11. # but that's fine as it only wants the constants.
  12. pass
  13. WORKING_ROOT = "/tmp/pytorch_blas_compare_environments"
  14. MKL_2020_3 = "mkl_2020_3"
  15. MKL_2020_0 = "mkl_2020_0"
  16. OPEN_BLAS = "open_blas"
  17. EIGEN = "eigen"
  18. GENERIC_ENV_VARS = ("USE_CUDA=0", "USE_ROCM=0")
  19. BASE_PKG_DEPS = (
  20. "cmake",
  21. "hypothesis",
  22. "ninja",
  23. "numpy",
  24. "pyyaml",
  25. "setuptools",
  26. "typing_extensions",
  27. )
  28. SubEnvSpec = collections.namedtuple(
  29. "SubEnvSpec", (
  30. "generic_installs",
  31. "special_installs",
  32. "environment_variables",
  33. # Validate install.
  34. "expected_blas_symbols",
  35. "expected_mkl_version",
  36. ))
  37. SUB_ENVS = {
  38. MKL_2020_3: SubEnvSpec(
  39. generic_installs=(),
  40. special_installs=("intel", ("mkl=2020.3", "mkl-include=2020.3")),
  41. environment_variables=("BLAS=MKL",) + GENERIC_ENV_VARS,
  42. expected_blas_symbols=("mkl_blas_sgemm",),
  43. expected_mkl_version="2020.0.3",
  44. ),
  45. MKL_2020_0: SubEnvSpec(
  46. generic_installs=(),
  47. special_installs=("intel", ("mkl=2020.0", "mkl-include=2020.0")),
  48. environment_variables=("BLAS=MKL",) + GENERIC_ENV_VARS,
  49. expected_blas_symbols=("mkl_blas_sgemm",),
  50. expected_mkl_version="2020.0.0",
  51. ),
  52. OPEN_BLAS: SubEnvSpec(
  53. generic_installs=("openblas",),
  54. special_installs=(),
  55. environment_variables=("BLAS=OpenBLAS",) + GENERIC_ENV_VARS,
  56. expected_blas_symbols=("exec_blas",),
  57. expected_mkl_version=None,
  58. ),
  59. # EIGEN: SubEnvSpec(
  60. # generic_installs=(),
  61. # special_installs=(),
  62. # environment_variables=("BLAS=Eigen",) + GENERIC_ENV_VARS,
  63. # expected_blas_symbols=(),
  64. # ),
  65. }
  66. def conda_run(*args):
  67. """Convenience method."""
  68. stdout, stderr, retcode = conda.cli.python_api.run_command(*args)
  69. if retcode:
  70. raise OSError(f"conda error: {str(args)} retcode: {retcode}\n{stderr}")
  71. return stdout
  72. def main():
  73. if os.path.exists(WORKING_ROOT):
  74. print("Cleaning: removing old working root.")
  75. shutil.rmtree(WORKING_ROOT)
  76. os.makedirs(WORKING_ROOT)
  77. git_root = subprocess.check_output(
  78. "git rev-parse --show-toplevel",
  79. shell=True,
  80. cwd=os.path.dirname(os.path.realpath(__file__))
  81. ).decode("utf-8").strip()
  82. for env_name, env_spec in SUB_ENVS.items():
  83. env_path = os.path.join(WORKING_ROOT, env_name)
  84. print(f"Creating env: {env_name}: ({env_path})")
  85. conda_run(
  86. conda_commands.CREATE,
  87. "--no-default-packages",
  88. "--prefix", env_path,
  89. "python=3",
  90. )
  91. print("Testing that env can be activated:")
  92. base_source = subprocess.run(
  93. f"source activate {env_path}",
  94. shell=True,
  95. capture_output=True,
  96. )
  97. if base_source.returncode:
  98. raise OSError(
  99. "Failed to source base environment:\n"
  100. f" stdout: {base_source.stdout.decode('utf-8')}\n"
  101. f" stderr: {base_source.stderr.decode('utf-8')}"
  102. )
  103. print("Installing packages:")
  104. conda_run(
  105. conda_commands.INSTALL,
  106. "--prefix", env_path,
  107. *(BASE_PKG_DEPS + env_spec.generic_installs)
  108. )
  109. if env_spec.special_installs:
  110. channel, channel_deps = env_spec.special_installs
  111. print(f"Installing packages from channel: {channel}")
  112. conda_run(
  113. conda_commands.INSTALL,
  114. "--prefix", env_path,
  115. "-c", channel, *channel_deps
  116. )
  117. if env_spec.environment_variables:
  118. print("Setting environment variables.")
  119. # This does not appear to be possible using the python API.
  120. env_set = subprocess.run(
  121. f"source activate {env_path} && "
  122. f"conda env config vars set {' '.join(env_spec.environment_variables)}",
  123. shell=True,
  124. capture_output=True,
  125. )
  126. if env_set.returncode:
  127. raise OSError(
  128. "Failed to set environment variables:\n"
  129. f" stdout: {env_set.stdout.decode('utf-8')}\n"
  130. f" stderr: {env_set.stderr.decode('utf-8')}"
  131. )
  132. # Check that they were actually set correctly.
  133. actual_env_vars = subprocess.run(
  134. f"source activate {env_path} && env",
  135. shell=True,
  136. capture_output=True,
  137. ).stdout.decode("utf-8").strip().splitlines()
  138. for e in env_spec.environment_variables:
  139. assert e in actual_env_vars, f"{e} not in envs"
  140. print(f"Building PyTorch for env: `{env_name}`")
  141. # We have to re-run during each build to pick up the new
  142. # build config settings.
  143. build_run = subprocess.run(
  144. f"source activate {env_path} && "
  145. f"cd {git_root} && "
  146. "python setup.py install --cmake",
  147. shell=True,
  148. capture_output=True,
  149. )
  150. print("Checking configuration:")
  151. check_run = subprocess.run(
  152. # Shameless abuse of `python -c ...`
  153. f"source activate {env_path} && "
  154. "python -c \""
  155. "import torch;"
  156. "from torch.utils.benchmark import Timer;"
  157. "print(torch.__config__.show());"
  158. "setup = 'x=torch.ones((128, 128));y=torch.ones((128, 128))';"
  159. "counts = Timer('torch.mm(x, y)', setup).collect_callgrind(collect_baseline=False);"
  160. "stats = counts.as_standardized().stats(inclusive=True);"
  161. "print(stats.filter(lambda l: 'blas' in l.lower()))\"",
  162. shell=True,
  163. capture_output=True,
  164. )
  165. if check_run.returncode:
  166. raise OSError(
  167. "Failed to set environment variables:\n"
  168. f" stdout: {check_run.stdout.decode('utf-8')}\n"
  169. f" stderr: {check_run.stderr.decode('utf-8')}"
  170. )
  171. check_run_stdout = check_run.stdout.decode('utf-8')
  172. print(check_run_stdout)
  173. for e in env_spec.environment_variables:
  174. if "BLAS" in e:
  175. assert e in check_run_stdout, f"PyTorch build did not respect `BLAS=...`: {e}"
  176. for s in env_spec.expected_blas_symbols:
  177. assert s in check_run_stdout
  178. if env_spec.expected_mkl_version is not None:
  179. assert f"- Intel(R) Math Kernel Library Version {env_spec.expected_mkl_version}" in check_run_stdout
  180. print(f"Build complete: {env_name}")
  181. if __name__ == "__main__":
  182. main()