test_minifier_common.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. import os
  2. import re
  3. import subprocess
  4. import tempfile
  5. import unittest
  6. import torch
  7. import torch._dynamo
  8. import torch._dynamo.test_case
  9. from torch._dynamo.debug_utils import TEST_REPLACEABLE_COMMENT
  10. class MinifierTestBase(torch._dynamo.test_case.TestCase):
  11. _debug_dir_obj = tempfile.TemporaryDirectory()
  12. DEBUG_DIR = _debug_dir_obj.name
  13. @classmethod
  14. def setUpClass(cls):
  15. super().setUpClass()
  16. cls._exit_stack.enter_context(
  17. unittest.mock.patch.object(
  18. torch._dynamo.config,
  19. "debug_dir_root",
  20. cls.DEBUG_DIR,
  21. )
  22. )
  23. os.makedirs(cls.DEBUG_DIR, exist_ok=True)
  24. @classmethod
  25. def tearDownClass(cls):
  26. cls._debug_dir_obj.cleanup()
  27. cls._exit_stack.close()
  28. # Search for the name of the first function defined in a code string.
  29. def _get_fn_name(self, code):
  30. fn_name_match = re.search(r"def (\w+)\(", code)
  31. if fn_name_match is not None:
  32. return fn_name_match.group(1)
  33. return None
  34. # Run `code` in a separate python process.
  35. # Returns the completed process state and the directory containing the
  36. # minifier launcher script, if `code` outputted it.
  37. def _run_test_code(self, code):
  38. proc = subprocess.run(
  39. ["python3", "-c", code], capture_output=True, cwd=self.DEBUG_DIR
  40. )
  41. repro_dir_match = re.search(
  42. r"(\S+)minifier_launcher.py", proc.stderr.decode("utf-8")
  43. )
  44. if repro_dir_match is not None:
  45. return proc, repro_dir_match.group(1)
  46. return proc, None
  47. # Patch generated files with testing patches
  48. def _inject_code(self, patch_code, filename):
  49. patch_code = f"""\
  50. {patch_code}
  51. torch._dynamo.config.debug_dir_root = "{self.DEBUG_DIR}"
  52. """
  53. with open(filename, "r") as f:
  54. code = f.read()
  55. code = code.replace(TEST_REPLACEABLE_COMMENT, patch_code)
  56. with open(filename, "w") as f:
  57. f.write(code)
  58. return code
  59. # Runs the minifier launcher script in `repro_dir`, patched with `patch_code`.
  60. def _run_minifier_launcher(self, patch_code, repro_dir):
  61. self.assertIsNotNone(repro_dir)
  62. launch_file = os.path.join(repro_dir, "minifier_launcher.py")
  63. self.assertTrue(os.path.exists(launch_file))
  64. launch_code = self._inject_code(patch_code, launch_file)
  65. launch_proc = subprocess.run(
  66. ["python3", launch_file],
  67. capture_output=True,
  68. cwd=repro_dir,
  69. )
  70. return launch_proc, launch_code
  71. # Runs the repro script in `repro_dir`, patched with `patch_code`
  72. def _run_repro(self, patch_code, repro_dir):
  73. self.assertIsNotNone(repro_dir)
  74. repro_file = os.path.join(repro_dir, "repro.py")
  75. self.assertTrue(os.path.exists(repro_file))
  76. repro_code = self._inject_code(patch_code, repro_file)
  77. repro_proc = subprocess.run(
  78. ["python3", repro_file], capture_output=True, cwd=repro_dir
  79. )
  80. return repro_proc, repro_code
  81. # Template for testing code.
  82. # `run_code` is the code to run for the test case.
  83. # `patch_code` is the code to be patched in every generated file.
  84. def _gen_test_code(self, run_code, repro_after, repro_level, patch_code):
  85. return f"""\
  86. import torch
  87. import torch._dynamo
  88. {patch_code}
  89. torch._dynamo.config.repro_after = "{repro_after}"
  90. torch._dynamo.config.repro_level = {repro_level}
  91. torch._dynamo.config.debug_dir_root = "{self.DEBUG_DIR}"
  92. {run_code}
  93. """
  94. # Runs a full minifier test.
  95. # Minifier tests generally consist of 3 stages:
  96. # 1. Run the problematic code (in a separate process since it could segfault)
  97. # 2. Run the generated minifier launcher script
  98. # 3. Run the generated repro script
  99. def _run_full_test(self, run_code, repro_after, repro_level, patch_code):
  100. test_code = self._gen_test_code(run_code, repro_after, repro_level, patch_code)
  101. test_proc, repro_dir = self._run_test_code(test_code)
  102. self.assertIsNotNone(repro_dir)
  103. launch_proc, launch_code = self._run_minifier_launcher(patch_code, repro_dir)
  104. repro_proc, repro_code = self._run_repro(patch_code, repro_dir)
  105. return ((test_proc, launch_proc, repro_proc), (launch_code, repro_code))