import copy
import functools
import getpass
import logging
import os
import shutil
import subprocess
import tempfile
import textwrap
import uuid
from collections import Counter
from importlib import import_module
from tempfile import TemporaryFile

import torch
import torch.fx as fx
from torch._prims_common import is_float_dtype

from . import config
from .backends.registry import lookup_backend, register_debug_backend
from .utils import clone_inputs, get_debug_dir

log = logging.getLogger(__name__)

inductor_config = import_module("torch._inductor.config")
use_buck = inductor_config.is_fbcode()

extra_deps = []
extra_imports = ""
if use_buck:
    extra_deps = [
        "//caffe2/fb/custom_ops/sparsenn:sparsenn-all_operators",
        "//caffe2/torch/fb/sparsenn:sparsenn_operators_gpu",
        "//caffe2/torch/fb/sparsenn:sparsenn_operators",
        "//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu",
        "//deeplearning/fbgemm/fbgemm_gpu:sparse_ops",
    ]
    extra_imports = "\n".join([f'torch.ops.load_library("{x}")' for x in extra_deps])


class BuckTargetWriter:
    def __init__(self, filename):
        self.subdir, self.py_file = os.path.split(filename)
        self.target = self.py_file.replace(".py", "")

        # Get main_module path from fbcode
        self.path = f'{self.subdir.replace("/", ".")}.{self.target}'
        self.path = self.path[self.path.find("fbcode.") :]
        self.path = self.path[7:]

        # Get cmd line path
        tmp = self.subdir
        tmp = tmp[tmp.find("fbcode/") :][7:]
        self.cmd_line_path = f"//{tmp}:{self.target}"

    def build(self):
        extra_cpp_deps = "\n".join([f'        "{x}",' for x in extra_deps])
        return textwrap.dedent(
            f"""
load("@fbcode_macros//build_defs:python_binary.bzl", "python_binary")

python_binary(
    name="{self.target}",
    srcs = ["{self.py_file}"],
    compile = False,
    deps = [
        "//caffe2:torch",
        "//caffe2/functorch:functorch",
        "//triton:triton",
    ],
    cpp_deps = [
{extra_cpp_deps}
    ],
    main_module = "{self.path}",
)
"""
        )

    def write(self, print_msg=True):
        target_file = os.path.join(self.subdir, "TARGETS")
        with open(target_file, "w") as fd:
            fd.write(self.build())
        # log.warning(f"Wrote isolation TARGETS file at {target_file}")
        cmd = ["buck2", "run", "@mode/dev-nosan", self.cmd_line_path]
        if print_msg:
            log.warning(
                f'Found an example that reproduces the error. Run this cmd to repro - {" ".join(cmd)}'
            )
        return cmd
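

# Illustrative usage sketch (hypothetical helper, not called anywhere in this
# module; only meaningful inside an fbcode checkout): write a TARGETS file
# next to a repro script and get back the buck2 command that runs it.
def _example_buck_target_writer():
    writer = BuckTargetWriter("fbcode/scratch/repro.py")
    return writer.write(print_msg=False)
    # -> ["buck2", "run", "@mode/dev-nosan", "//scratch:repro"]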


def minifier_dir():
    path = os.path.join(get_debug_dir(), "minifier")
    if path is None:
        path = f"{tempfile.gettempdir()}/minifier_{getpass.getuser()}"
    if not os.path.exists(path):
        os.makedirs(path, exist_ok=True)
    return path


class NNModuleToString:
    safe_reprs = [
        torch.nn.Linear,
        torch.nn.Conv1d,
        torch.nn.Conv2d,
        torch.nn.Conv3d,
        torch.nn.BatchNorm1d,
        torch.nn.BatchNorm2d,
        torch.nn.BatchNorm3d,
        torch.nn.LayerNorm,
        torch.nn.Dropout,
        torch.nn.Softmax,
        torch.nn.ReLU,
        torch.nn.GELU,
        torch.nn.Identity,
        torch.nn.MaxPool2d,
        torch.nn.Embedding,
        torch.nn.Tanh,
        torch.nn.ConvTranspose1d,
        torch.nn.GLU,
        torch.nn.LSTM,
        torch.nn.Flatten,
        torch.nn.AdaptiveAvgPool2d,
    ]

    @staticmethod
    def can_convert_to_string(gm):
        cant_convert = set()
        for _, module in gm.named_children():
            if type(module) not in NNModuleToString.safe_reprs:
                cant_convert.add(module)

        if len(cant_convert) > 0:
            log.warning(f"We have not tested reprs of some modules - {cant_convert}")
        # TODO - Assuming that all modules can be safely repr'd. Check if that assumption is correct.
        return True

    @staticmethod
    def convert(gm):
        from torch.nn.modules.module import _addindent

        tab = " " * 4

        model_str = textwrap.dedent(
            """
            from torch.nn import *
            class Repro(torch.nn.Module):
                def __init__(self):
                    super().__init__()
            """
        )

        for module_name, module in gm.named_children():
            module_str = f"{module.__repr__()}"
            # module should be a core torch.nn.Module, so all parameters
            # should be on the same device.
            example_param = next(module.parameters(), None)
            if example_param is not None and example_param.is_cuda:
                module_str = f"{module_str}.cuda()"
            model_str += f"{tab*2}self.{module_name} = {module_str}\n"

        for buffer_name, buffer in gm._buffers.items():
            if buffer is None:
                continue
            if torch.is_floating_point(buffer):
                tensor_str = f"torch.randn({list(buffer.shape)}, dtype={buffer.dtype})"
            else:
                tensor_str = (
                    f"torch.randint(1, size={list(buffer.shape)}, dtype={buffer.dtype})"
                )
            if buffer.is_cuda:
                tensor_str = f"{tensor_str}.cuda()"
            model_str += f"{tab*2}self.register_buffer('{buffer_name}', {tensor_str})\n"

        for param_name, param in gm._parameters.items():
            if param is None:
                continue
            tensor_str = f"torch.nn.Parameter(torch.randn({list(param.shape)}, dtype={param.dtype}))"
            if param.is_cuda:
                tensor_str = f"{tensor_str}.cuda()"
            model_str += f"{tab*2}self.{param_name} = {tensor_str}\n"

        # TODO - Keep this code for now. But, I don't think we will need this.
        # attrs = dir(gm)
        # for attr in attrs:
        #     if "_tensor_constant" in attr:
        #         val = getattr(gm, attr)
        #         model_str += f"    {attr} = {val!r}\n"

        model_str += f"{_addindent(gm.code, 4)}\n"
        return model_str
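

# Illustrative sketch (hypothetical helper, not called anywhere in this
# module): symbolically trace a module built only from `safe_reprs` entries,
# then render it as a self-contained source string.
def _example_nn_module_to_string():
    from collections import OrderedDict

    mod = torch.nn.Sequential(
        OrderedDict([("lin", torch.nn.Linear(4, 4)), ("act", torch.nn.ReLU())])
    )
    gm = fx.symbolic_trace(mod)
    if NNModuleToString.can_convert_to_string(gm):
        # Returns source text defining a self-contained `class Repro`.
        return NNModuleToString.convert(gm)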


@functools.lru_cache(None)  # subprocess is expensive
def _cuda_system_info_comment():
    if not torch.cuda.is_available():
        return "# torch.cuda.is_available()==False, no GPU info collected\n"

    model_str = "# CUDA Info: \n"
    try:
        cuda_version_out = subprocess.run(["nvcc", "--version"], stdout=subprocess.PIPE)
        cuda_version_lines = cuda_version_out.stdout.decode().split("\n")
        cuda_version_out = "".join(
            [f"# {s} \n" for s in cuda_version_lines if s not in [""]]
        )
        model_str += f"{cuda_version_out}\n"
    except FileNotFoundError:
        model_str += "# nvcc not found\n"

    gpu_names = Counter(
        torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())
    )

    model_str += "# GPU Hardware Info: \n"
    for name, count in gpu_names.items():
        model_str += f"# {name} : {count} \n"
    model_str += "\n"
    return model_str


def generate_config_string():
    import torch._functorch.config
    import torch._inductor.config

    return textwrap.dedent(
        f"""\
import torch._dynamo.config
import torch._inductor.config
import torch._functorch.config

torch._dynamo.config.load_config({repr(torch._dynamo.config.save_config())})
torch._inductor.config.load_config({repr(torch._inductor.config.save_config())})
torch._functorch.config.load_config({repr(torch._functorch.config.save_config())})
        """
    )


TEST_REPLACEABLE_COMMENT = "# REPLACEABLE COMMENT FOR TESTING PURPOSES"


def generate_compiler_repro_string(gm, args):
    model_str = textwrap.dedent(
        f"""
import torch
from torch import tensor, device
import torch.fx as fx
from torch._dynamo.testing import rand_strided
from math import inf
from torch.fx.experimental.proxy_tensor import make_fx

{generate_config_string()}

{TEST_REPLACEABLE_COMMENT}
{extra_imports}
        """
    )

    model_str += f"# torch version: {torch.version.__version__}\n"
    if hasattr(torch.version, "cuda"):
        model_str += f"# torch cuda version: {torch.version.cuda}\n"
    if hasattr(torch.version, "git_version"):
        model_str += f"# torch git version: {torch.version.git_version}\n\n\n"
    model_str += _cuda_system_info_comment()

    model_str += NNModuleToString.convert(gm)

    model_str += f"args = {[(tuple(a.shape), tuple(a.stride()), a.dtype, a.device.type) for a in args]!r}\n"
    model_str += (
        "args = [rand_strided(sh, st, dt, dev) for (sh, st, dt, dev) in args]\n"
    )

    # TODO: fake may be better for performance here
    tracing_mode = "real"
    if config.dynamic_shapes:
        tracing_mode = "symbolic"
    model_str += f"mod = make_fx(Repro(), tracing_mode={repr(tracing_mode)})(*args)\n"
    return model_str
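

# Illustrative sketch (hypothetical helper, not called anywhere in this
# module): the emitted script re-creates the module from source, fabricates
# rand_strided inputs matching each (shape, stride, dtype, device), and
# re-traces with make_fx before handing `mod` to a backend.
def _example_compiler_repro_string():
    gm = fx.symbolic_trace(torch.nn.Linear(2, 2))
    args = [torch.randn(1, 2)]
    return generate_compiler_repro_string(gm, args)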


INDUCTOR_IMPORT = """
from torch._inductor.compile_fx import compile_fx_inner
from torch._dynamo.debug_utils import same_two_models
"""

COMPILER_REPRO_OPTIONS = {
    "inductor": (INDUCTOR_IMPORT, "compile_fx_inner", "inductor_fails"),
    "inductor_accuracy": (
        INDUCTOR_IMPORT,
        "compile_fx_inner",
        "inductor_accuracy_fails",
    ),
}


def dump_compiler_graph_state(gm, args, compiler_name):
    subdir = os.path.join(minifier_dir(), "checkpoints")
    if not os.path.exists(subdir):
        os.makedirs(subdir, exist_ok=True)
    file_name = os.path.join(subdir, f"{len(gm.graph.nodes)}.py")
    log.warning(f"Writing checkpoint with {len(gm.graph.nodes)} nodes to {file_name}")
    with open(file_name, "w") as fd:
        save_graph_repro(fd, gm, args, compiler_name)
    curdir = os.getcwd()
    repro_path = os.path.join(curdir, "repro.py")
    try:
        shutil.copyfile(file_name, repro_path)
        log.warning(f"Copying repro file for convenience to {repro_path}")
        if use_buck:
            BuckTargetWriter(file_name).write()
    except OSError:
        log.warning(f"No write permissions for {repro_path}")


def save_graph_repro(fd, gm, args, compiler_name):
    sync_line = ""
    for arg in args:
        if arg.is_cuda:
            sync_line = "torch.cuda.synchronize() # Ensures that segfaults are surfaced"
            break

    if "inductor" in compiler_name:
        fd.write("import torch._inductor.overrides\n")
    fd.write(generate_compiler_repro_string(gm, args))
    fd.write(COMPILER_REPRO_OPTIONS[compiler_name][0])
    if "_accuracy" in compiler_name:
        fd.write(
            textwrap.dedent(
                f"""
                compiled = {COMPILER_REPRO_OPTIONS[compiler_name][1]}(mod, args)
                class AccuracyError(Exception):
                    pass
                if not same_two_models(mod, compiled, args, only_fwd=True):
                    raise AccuracyError("Bad accuracy detected")
                """
            )
        )
    else:
        fd.write(
            textwrap.dedent(
                f"""
                compiled = {COMPILER_REPRO_OPTIONS[compiler_name][1]}(mod, args)
                ref = compiled(args)
                {sync_line}
                """
            )
        )
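

# Illustrative sketch (hypothetical helper, not called anywhere in this
# module): capture the generated repro script in memory instead of writing it
# to a file on disk.
def _example_save_graph_repro():
    import io

    gm = fx.symbolic_trace(torch.nn.Linear(2, 2))
    args = [torch.randn(1, 2)]
    buf = io.StringIO()
    save_graph_repro(buf, gm, args, "inductor")
    return buf.getvalue()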


def isolate_fails(fx_g, args, compiler_name: str, env=None, patch_code=None):
    if env is None:
        env = {}
    subdir = os.path.join(os.getcwd(), "isolate")
    if not os.path.exists(subdir):
        os.makedirs(subdir, exist_ok=True)
    file_name = os.path.join(subdir, f"{str(uuid.uuid4())[:5]}.py")
    with open(file_name, "w") as fd:
        repro_code = generate_compiler_repro_string(fx_g, args)
        if patch_code is not None:
            repro_code = repro_code.replace(TEST_REPLACEABLE_COMMENT, patch_code)
        fd.write(repro_code)
        fail_fn = COMPILER_REPRO_OPTIONS[compiler_name][2]
        fd.write(
            textwrap.dedent(
                f"""
                from {__name__} import {fail_fn}
                """
            )
        )
        fd.write(
            textwrap.dedent(
                f"""
                if {fail_fn}(mod, args):
                    exit(1)
                else:
                    exit(0)
                """
            )
        )
    # with open(file_name, "r") as fd:
    #     print(fd.read())
    new_env = os.environ.copy()
    new_env = {**new_env, **env}
    stdout, stderr = TemporaryFile(), TemporaryFile()

    if use_buck:
        cmd = BuckTargetWriter(file_name).write(print_msg=False)
    else:
        cmd = ["python", file_name]

    p = subprocess.Popen(
        cmd,
        cwd=subdir,
        stdout=stdout,
        stderr=stderr,
        env=new_env,
    )
    p.wait()

    if p.returncode != 0:
        stdout.seek(0)
        stderr.seek(0)
        print(textwrap.indent(stdout.read().decode("utf-8"), prefix=">> "))
        print(textwrap.indent(stderr.read().decode("utf-8"), prefix=">> "))
        # print(f"Isolated test failed - {file_name}")
        return True
    return False


def inductor_fails(fx_g, args, check_str=None):
    has_cuda = False
    for arg in args:
        if arg.is_cuda:
            has_cuda = True
            break

    def sync():
        if has_cuda:
            # Ensures that segfaults are surfaced
            torch.cuda.synchronize()

    from torch._inductor.compile_fx import compile_fx_inner

    try:
        result = fx_g(*args)
        assert isinstance(result, (tuple, list))
        assert not any(isinstance(x, (tuple, list)) for x in result)
    except Exception:
        return False

    sync()

    try:
        compile_mod = compile_fx_inner(fx_g, args)
        compile_mod(args)
        sync()
    except Exception as e:
        if check_str is not None and check_str not in repr(e):
            return False
        print(repr(e))
        return True
    return False


def inductor_accuracy_fails(fx_g, args, check_str=None):
    from torch._inductor.compile_fx import compile_fx_inner

    return backend_aot_accuracy_fails(fx_g, args, compile_fx_inner)


def get_minifier_repro_path():
    return os.path.join(minifier_dir(), "minifier_launcher.py")


def helper_for_dump_minify(contents):
    minified_repro_path = get_minifier_repro_path()
    log.warning(f"Writing minified repro to {minified_repro_path}")

    if use_buck:
        BuckTargetWriter(minified_repro_path).write()
    try:
        with open(minified_repro_path, "w") as fd:
            fd.write(contents)
    except OSError as e:
        log.exception(e)
        raise NotImplementedError(f"Could not write to {minified_repro_path}") from e


def dump_to_minify(gm, args, compiler_name: str):
    favored_device = 1 if torch.cuda.device_count() >= 2 else 0

    contents = textwrap.dedent(
        f"""
isolate_fails_code_str = None

{generate_compiler_repro_string(gm, args)}

from functools import partial
from {__name__} import (
    isolate_fails,
    dump_compiler_graph_state,
)
from functorch.compile import minifier

env_variables = {{"CUDA_VISIBLE_DEVICES": "{favored_device}"}}

minifier(
    mod,
    args,
    module_fails=partial(isolate_fails, env=env_variables, compiler_name="{compiler_name}", patch_code=isolate_fails_code_str),
    dump_state=partial(dump_compiler_graph_state, compiler_name="{compiler_name}"),
)
        """
    )
    return helper_for_dump_minify(contents)


class AccuracyError(Exception):
    pass


def wrap_compiler_debug(unconfigured_compiler_fn, compiler_name: str):
    """
    Minifier for Fx Graph modules after Aot Autograd has finished. We wrap the
    forward and backward calls separately with the backend compiler_fn - like
    inductor or nvfuser. Intercepting after Aot Autograd presents a neat
    abstraction, where all the params are lifted as graph inputs, making it
    easy to save the graph as a string.
    """

    @functools.wraps(unconfigured_compiler_fn)
    def debug_wrapper(gm, example_inputs, **kwargs):
        from torch._subclasses import FakeTensorMode

        compiler_fn = functools.partial(unconfigured_compiler_fn, **kwargs)
        orig_graph = copy.deepcopy(gm.graph)
        assert config.repro_after in ("dynamo", "aot", None)
        inner_compiled_fn = None

        def deferred_for_real_inputs(real_inputs):
            """
            Aot Autograd fw_compiler and bw_compiler can have fake tensors. So,
            example_inputs can be fake tensors. We can call compiler_fn (which
            is inductor or nvfuser) with fake tensors, but the actual
            compiled_fn should be called with real tensors. Therefore, the
            actual invocation is deferred.
            """
            # Avoid re-compiling when we call the compiled function twice. This happens
            # when we run the model inference or training in a for loop like here
            # https://github.com/pytorch/torchdynamo/issues/1687#issuecomment-1280040633
            nonlocal inner_compiled_fn

            # Copy the tensor attrs like shape, stride etc by converting to Fake Tensor
            # because inductor clears the tensor list in its codegen. And example_inputs
            # are available only for the first invocation.
            fake_mode = FakeTensorMode()
            copy_tensor_attrs = [fake_mode.from_tensor(x) for x in real_inputs]

            if config.repro_level == 3:
                # Always dump the original module in case we have segfaults
                dump_to_minify(
                    fx.GraphModule(gm, orig_graph), real_inputs, compiler_name
                )

            if config.repro_level == 4:
                if compiler_name != "inductor":
                    raise NotImplementedError(
                        "Accuracy minification is supported for inductor only"
                    )
                if inner_compiled_fn is None:
                    inner_compiled_fn = compiler_fn(gm, example_inputs)
                if backend_aot_accuracy_fails(gm, real_inputs, compiler_fn):
                    log.warning("Accuracy failed for the AOT Autograd graph")
                    dump_compiler_graph_state(
                        fx.GraphModule(gm, orig_graph),
                        copy_tensor_attrs,
                        f"{compiler_name}_accuracy",
                    )
                    dump_to_minify(
                        fx.GraphModule(gm, orig_graph),
                        copy_tensor_attrs,
                        f"{compiler_name}_accuracy",
                    )
                    raise AccuracyError("Bad accuracy detected")
                else:
                    # Call the compiled function with real inputs
                    return inner_compiled_fn(real_inputs)
            else:
                try:
                    # Call the compiler_fn - which is either aot_autograd or inductor
                    # with fake inputs
                    if inner_compiled_fn is None:
                        inner_compiled_fn = compiler_fn(gm, example_inputs)
                    # Call the compiled function with real inputs
                    return inner_compiled_fn(real_inputs)
                except Exception as e:
                    if config.repro_level == 1:
                        dump_compiler_graph_state(
                            fx.GraphModule(gm, orig_graph),
                            copy_tensor_attrs,
                            compiler_name,
                        )
                    elif config.repro_level == 2:
                        dump_to_minify(
                            fx.GraphModule(gm, orig_graph),
                            copy_tensor_attrs,
                            compiler_name,
                        )
                    log.error("CompilerError")
                    raise

        if config.repro_after == "aot":
            compiled_fn = deferred_for_real_inputs
            compiled_fn._boxed_call = True
        else:
            compiled_fn = compiler_fn(gm, example_inputs)

        return compiled_fn

    return debug_wrapper


def run_fwd_maybe_bwd(gm, args, only_fwd=False):
    """
    Runs a forward and possibly backward iteration for a given mod and args.
    """
    from torch._functorch.aot_autograd import make_boxed_func

    from .testing import collect_results, reduce_to_scalar_loss, requires_bwd_pass

    gm = copy.deepcopy(gm)
    new_args = clone_inputs(args)
    # Set the requires_grad field explicitly because clone_inputs only sets
    # requires_grad for leaf tensors.
    for narg, arg in zip(new_args, args):
        narg.requires_grad_(arg.requires_grad)
    args = new_args

    if hasattr(gm, "zero_grad"):
        gm.zero_grad(True)

    # TorchInductor returned callable expects lists. So, boxing the call.
    orig_named_parameters = getattr(gm, "named_parameters", None)
    orig_named_buffers = getattr(gm, "named_buffers", None)
    if not hasattr(gm, "_boxed_call") and (
        orig_named_parameters is not None or orig_named_buffers is not None
    ):
        gm = make_boxed_func(gm)
        if orig_named_parameters is not None:
            gm.named_parameters = orig_named_parameters
        if orig_named_buffers is not None:
            gm.named_buffers = orig_named_buffers

    out = gm(args)
    if only_fwd:
        return out
    if requires_bwd_pass(out):
        loss = reduce_to_scalar_loss(out)
        loss.backward()
    return collect_results(gm, out, None, args)
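

# Illustrative sketch (hypothetical helper, not called anywhere in this
# module): one eager forward+backward step over a toy module, returning the
# collected results used for accuracy comparison.
def _example_run_fwd_maybe_bwd():
    mod = torch.nn.Linear(4, 4)
    args = [torch.randn(2, 4, requires_grad=True)]
    return run_fwd_maybe_bwd(mod, args)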


def same_two_models(gm, opt_gm, example_inputs, only_fwd=False):
    """
    Check that the two models have the same accuracy.
    """
    from .eval_frame import OptimizedModule
    from .testing import (
        named_buffers_for_optimized_module,
        named_parameters_for_optimized_module,
    )
    from .utils import same

    if isinstance(gm, OptimizedModule):
        gm.named_parameters = named_parameters_for_optimized_module(gm)
        gm.named_buffers = named_buffers_for_optimized_module(gm)

    if isinstance(opt_gm, OptimizedModule):
        opt_gm.named_parameters = named_parameters_for_optimized_module(opt_gm)
        opt_gm.named_buffers = named_buffers_for_optimized_module(opt_gm)

    ref = run_fwd_maybe_bwd(gm, example_inputs, only_fwd)

    try:
        fp64_model, fp64_examples = cast_to_fp64(
            copy.deepcopy(gm), clone_inputs(example_inputs)
        )
        fp64_ref = run_fwd_maybe_bwd(fp64_model, fp64_examples, only_fwd)
    except Exception:
        log.warning("Could not generate fp64 outputs")
        fp64_ref = None

    try:
        res = run_fwd_maybe_bwd(opt_gm, example_inputs, only_fwd)
    except Exception as e:
        # This means that the minified graph is bad/exposes a different problem.
        # As we are checking accuracy here, let's log the exception and return True.
        log.exception(
            (
                "While minifying the program in accuracy minification mode, "
                "ran into a runtime exception which is likely an unrelated issue."
                " Skipping this graph."
            )
        )
        return True

    passing = same(ref, res, fp64_ref, tol=config.repro_tolerance, equal_nan=True)
    return passing
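

# Illustrative sketch (hypothetical helper, not called anywhere in this
# module): comparing a module against an unmodified deep copy should pass.
def _example_same_two_models():
    mod = torch.nn.Linear(8, 8)
    opt_mod = copy.deepcopy(mod)  # stand-in for a compiled module
    args = [torch.randn(2, 8)]
    return same_two_models(mod, opt_mod, args, only_fwd=True)  # True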


def cast_convert_element_type_to_fp64(model):
    for node in model.graph.nodes:
        if (
            node.op == "call_function"
            and node.target == torch.ops.prims.convert_element_type.default
        ):
            assert len(node.args) == 2
            if is_float_dtype(node.args[1]) and node.args[1] != torch.float64:
                node.args = (node.args[0], torch.float64)
    model.graph.lint()
    model.recompile()
    return model


def cast_to(dtype, model, inputs):
    from torch.utils._pytree import tree_map

    model = model.to(dtype)
    if dtype == torch.float64:
        # If casting to fp64 for accuracy comparison, we need to
        # take care of convert_element_type explicitly
        model = cast_convert_element_type_to_fp64(model)

    inputs = tree_map(
        lambda x: x.to(dtype)
        if isinstance(x, torch.Tensor) and x.is_floating_point()
        else x,
        inputs,
    )
    return model, inputs


def cast_to_fp64(model, inputs):
    return cast_to(torch.float64, model, inputs)
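

# Illustrative sketch (hypothetical helper, not called anywhere in this
# module): cast a tiny make_fx graph and its inputs to fp64 for a reference
# run.
def _example_cast_to_fp64():
    from torch.fx.experimental.proxy_tensor import make_fx

    gm = make_fx(lambda x: x * 2.0)(torch.randn(3))
    fp64_gm, fp64_inputs = cast_to_fp64(gm, [torch.randn(3)])
    return fp64_gm(*fp64_inputs).dtype  # torch.float64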


def generate_dynamo_fx_repro_string(
    model_str, args, compiler_name, check_accuracy=False
):
    """
    Generate a repro string for a backend-agnostic minified version.
    """

    run_code = textwrap.dedent(
        f"""
with torch.cuda.amp.autocast(enabled={torch.is_autocast_enabled()}):
    ref = run_fwd_maybe_bwd(mod, args)
    res = run_fwd_maybe_bwd(opt_mod, args)
        """
    )

    if config.repro_level == 4 or check_accuracy:
        run_code = textwrap.dedent(
            f"""
mod.eval()
opt_mod.eval()

class AccuracyError(Exception):
    pass

with torch.cuda.amp.autocast(enabled={torch.is_autocast_enabled()}):
    assert same_two_models(mod, mod, args), "Eager itself failed"
    if not same_two_models(mod, opt_mod, args):
        raise AccuracyError("Dynamo failed")
        """
        )

    return textwrap.dedent(
        f"""
from math import inf
import torch
from torch import tensor, device
import torch.fx as fx
import torch._dynamo
from torch._dynamo.testing import rand_strided
from torch._dynamo.debug_utils import run_fwd_maybe_bwd
from torch._dynamo.debug_utils import same_two_models

{generate_config_string()}

{TEST_REPLACEABLE_COMMENT}
{extra_imports}

args = {[(tuple(a.shape), tuple(a.stride()), a.dtype, a.device.type, a.requires_grad) for a in args]}
args = [rand_strided(sh, st, dt, dev).requires_grad_(rg) for (sh, st, dt, dev, rg) in args]

{model_str}
mod = Repro()
opt_mod = torch._dynamo.optimize("{compiler_name}")(mod)

{run_code}
        """
    )


def dump_backend_repro_as_file(gm, args, compiler_name, check_accuracy=False):
    """
    Saves the repro to a repro.py file
    """
    curdir = os.getcwd()
    subdir = os.path.join(os.getcwd(), "checkpoints")
    if not os.path.exists(subdir):
        os.makedirs(subdir, exist_ok=True)
    file_name = os.path.join(subdir, f"minified_{len(gm.graph.nodes)}_nodes.py")
    log.warning(f"Writing checkpoint with {len(gm.graph.nodes)} nodes to {file_name}")

    model_str = NNModuleToString.convert(gm)
    with open(file_name, "w") as fd:
        fd.write(
            generate_dynamo_fx_repro_string(
                model_str, args, compiler_name, check_accuracy
            )
        )
    latest_repro = os.path.join(curdir, "repro.py")
    log.warning(f"Copying {file_name} to {latest_repro} for convenience")

    if use_buck:
        BuckTargetWriter(latest_repro).write()

    shutil.copyfile(file_name, latest_repro)


# TODO - Commented because we are assuming that nn.Modules can be safely repr'd
# If that does not work, we might have to bring this code back. So, keeping it
# as it is for now.
# def dump_backend_repro_as_tarfile(gm, args, compiler_name):
#     """
#     Saves the repro in repro.tar.gz, as opposed to a file. This is used for
#     cases, where we can't convert a Fx GraphModule to a string, and therefore
#     fallback to to_folder for serialization. We accompany this with a repro.py
#     script that imports the saved module, sets it up and runs the model to repro
#     the error.
#     """
#     import tarfile
#
#     subdir = os.path.join(minifier_dir(), "checkpoints")
#     if not os.path.exists(subdir):
#         os.makedirs(subdir, exist_ok=True)
#     tmp_dir = os.path.join(subdir, f"{len(gm.graph.nodes)}")
#     if os.path.exists(tmp_dir):
#         shutil.rmtree(tmp_dir)
#     os.makedirs(tmp_dir, exist_ok=True)
#     file_name = os.path.join(tmp_dir, "repro.py")
#     gm_dir = os.path.join(tmp_dir, "module")
#     if not os.path.exists(gm_dir):
#         os.makedirs(gm_dir, exist_ok=True)
#     for node in gm.graph.nodes:
#         new_kwargs = {}
#         for k, v in node.kwargs.items():
#             if isinstance(v, torch.device):
#                 v = v.type
#             new_kwargs[k] = v
#         node.kwargs = new_kwargs
#     gm.recompile()
#
#     print(f"Writing checkpoint with {len(gm.graph.nodes)} nodes to {file_name}")
#     with open(file_name, "w") as fd:
#         # TODO - Add the readable version of to_folder when available
#         gm.to_folder(gm_dir, "Repro")
#         fd.write(
#             generate_dynamo_fx_repro_string(
#                 "from module import Repro", args, compiler_name
#             )
#         )
#     local_dir = os.path.join(config.base_dir, "repro")
#     if os.path.exists(local_dir):
#         shutil.rmtree(local_dir)
#     shutil.copytree(tmp_dir, local_dir)
#     local_tar_file = os.path.join(config.base_dir, "repro.tar.gz")
#     print(f"Writing checkpoint with {len(gm.graph.nodes)} locally to {local_tar_file}")
#     with tarfile.open(local_tar_file, "w:gz") as tar:
#         tar.add(local_dir, arcname=os.path.basename(local_dir))


def dump_backend_state(gm, args, compiler_name, check_accuracy=False):
    """
    Dumps the dynamo graph to repro the issue.
    1) It tries to convert Fx GraphModule to a string. If we can, it writes to a
    repro.py file.
    2) If we can't convert Fx GraphModule to a string, we use to_folder to save
    the module and save a tar file.
    """
    assert NNModuleToString.can_convert_to_string(gm)
    return dump_backend_repro_as_file(gm, args, compiler_name, check_accuracy)
    # return dump_backend_repro_as_tarfile(gm, args, compiler_name)


def backend_accuracy_fails(gm, example_inputs, compiler_fn, only_fwd=False):
    try:
        compiled_gm = compiler_fn(copy.deepcopy(gm), clone_inputs(example_inputs))
    except Exception as e:
        # This means that the minified graph is bad/exposes a different problem.
        # As we are checking accuracy here, let's log the exception and return False.
        log.exception(
            (
                "While minifying the program in accuracy minification mode, "
                "ran into a runtime exception which is likely an unrelated issue."
                " Skipping this graph"
            )
        )
        return False

    return not same_two_models(gm, compiled_gm, example_inputs, only_fwd)


backend_aot_accuracy_fails = functools.partial(backend_accuracy_fails, only_fwd=True)


# Please see NOTE: [Real Tensors in Accuracy Evaluation]
MINIFIER_SPAWNED = False


def backend_fails(gm, example_inputs, compiler_fn, orig_failure):
    """
    Minifier uses this function to identify if the minified graph module fails
    with the same error.

    One caveat is that the minifier can potentially head in the wrong direction
    when the resulting graph module fails for a different reason. To avoid
    this, we save the string for the original exception and check the
    similarity between the new and old exceptions. They can be somewhat
    different in some cases, when the exception string depends on the failing
    node information. So, we have a loose similarity metric to guide the
    minifier path.
    """
    from difflib import SequenceMatcher

    try:
        compiled_gm = compiler_fn(gm, example_inputs)
        run_fwd_maybe_bwd(compiled_gm, clone_inputs(example_inputs))
        return False
    except Exception as e:
        new_failure = str(e)
        if SequenceMatcher(None, orig_failure, new_failure).ratio() > 0.5:
            return True
        return False
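

# Illustrative sketch (hypothetical helper, not called anywhere in this
# module): the loose similarity metric above tolerates node-name churn
# between two exception strings from the same underlying failure.
def _example_loose_failure_match():
    from difflib import SequenceMatcher

    old = "LoweringException: AssertionError at node add_3"
    new = "LoweringException: AssertionError at node add_17"
    return SequenceMatcher(None, old, new).ratio() > 0.5  # True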


def dump_to_minify_after_dynamo(gm, args, compiler_name):
    model_str = NNModuleToString.convert(gm)

    minifier_backend = "dynamo_minifier_backend"
    if config.repro_level == 4:
        minifier_backend = "dynamo_accuracy_minifier_backend"

    custom_compiler_error = (
        textwrap.dedent(
            """\
            raise RuntimeError(
                'Compiler name is None - this likely means that a custom compiler '
                'was called by torchdynamo. Please remove this error, import your '
                'custom compiler function, and replace the compiler_name="None" '
                'line below with compiler_name=<my_imported_custom_function>'
            )
            """
        )
        if compiler_name is None
        else ""
    )

    contents = textwrap.dedent(
        f"""
import os
from math import inf
import torch
from torch import tensor, device
import torch.fx as fx
import functools
import torch._dynamo
from torch._dynamo.debug_utils import run_fwd_maybe_bwd
from torch._dynamo.backends.registry import lookup_backend
from torch._dynamo.testing import rand_strided

{generate_config_string()}

{TEST_REPLACEABLE_COMMENT}
{extra_imports}

args = {[(tuple(a.shape), tuple(a.stride()), a.dtype, a.device.type, a.requires_grad) for a in args]}
args = [rand_strided(sh, st, dt, dev).requires_grad_(rg) for (sh, st, dt, dev, rg) in args]

{model_str}
mod = Repro()

# Setup debug minifier compiler
torch._dynamo.debug_utils.MINIFIER_SPAWNED = True
compiler_fn = lookup_backend("{minifier_backend}")
{custom_compiler_error}
dynamo_minifier_backend = functools.partial(
    compiler_fn,
    compiler_name="{compiler_name}",
)
opt_mod = torch._dynamo.optimize(dynamo_minifier_backend)(mod)

with torch.cuda.amp.autocast(enabled={torch.is_autocast_enabled()}):
    opt_mod(*args)
        """
    )

    helper_for_dump_minify(contents)


def wrap_backend_debug(unconfigured_compiler_fn, compiler_name: str):
    """
    A minifier decorator that wraps the TorchDynamo produced Fx graph modules.
    As opposed to wrap_compiler_debug, this wrapper intercepts at the
    TorchDynamo produced Fx Graph Module. This makes it backend-agnostic to
    some level, e.g., it is useful for minifying issues related to Aot Autograd
    tracing. If an error is found, we minify and save the minified repro in
    repro.tar.gz.
    """

    @functools.wraps(unconfigured_compiler_fn)
    def debug_wrapper(gm, example_inputs, **kwargs):
        compiler_fn = functools.partial(unconfigured_compiler_fn, **kwargs)
        assert config.repro_after in ("dynamo", "aot", None)

        if config.repro_after == "dynamo":
            if config.repro_level == 3:
                dump_to_minify_after_dynamo(gm, example_inputs, compiler_name)

            # Check for either accuracy (level 4) or other type of failures.
            if config.repro_level == 4:
                # Check Accuracy
                compiled_gm = compiler_fn(copy.deepcopy(gm), example_inputs)
                if backend_accuracy_fails(gm, example_inputs, compiler_fn):
                    log.warning(
                        "Accuracy failed for the TorchDynamo produced graph. Creating script to minify the error."
                    )
                    dump_to_minify_after_dynamo(
                        fx.GraphModule(gm, copy.deepcopy(gm.graph)),
                        example_inputs,
                        compiler_name,
                    )
                    exc = AccuracyError("Bad accuracy detected.")
                    exc.minifier_path = os.path.join(
                        minifier_dir(), "minifier_launcher.py"
                    )
                    raise exc
            else:
                try:
                    compiled_gm = compiler_fn(copy.deepcopy(gm), example_inputs)
                    run_fwd_maybe_bwd(compiled_gm, example_inputs)
                except Exception as exc:
                    log.warning(
                        "Compiled Fx GraphModule failed. Creating script to minify the error."
                    )
                    if config.repro_level == 1:
                        dump_state_fn = functools.partial(
                            dump_backend_state, compiler_name=compiler_name
                        )
                        dump_state_fn(
                            fx.GraphModule(gm, copy.deepcopy(gm.graph)), example_inputs
                        )
                    elif config.repro_level == 2:
                        dump_to_minify_after_dynamo(
                            fx.GraphModule(gm, copy.deepcopy(gm.graph)),
                            example_inputs,
                            compiler_name,
                        )
                    exc.minifier_path = os.path.join(
                        minifier_dir(), "minifier_launcher.py"
                    )
                    raise
        else:
            compiled_gm = compiler_fn(gm, example_inputs)

        return compiled_gm

    debug_wrapper._torchdynamo_orig_callable = unconfigured_compiler_fn

    return debug_wrapper
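

# Illustrative sketch (hypothetical helper, not called anywhere in this
# module): wrapping a trivial eager backend. With config.repro_after unset
# the wrapper is a pass-through; with "dynamo" it dumps/minifies on failure.
def _example_wrap_backend_debug():
    def eager_backend(gm, example_inputs):
        return gm.forward

    return wrap_backend_debug(eager_backend, "eager")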


@register_debug_backend
def dynamo_minifier_backend(gm, example_inputs, compiler_name):
    from functorch.compile import minifier

    compiler_fn = lookup_backend(compiler_name)

    try:
        compiled_gm = compiler_fn(gm, example_inputs)
        run_fwd_maybe_bwd(compiled_gm, example_inputs)
        raise ValueError("No issue was detected")
    except Exception as exc:
        orig_failure = str(exc)
        log.warning(
            "Compiled Fx GraphModule failed. Creating script to minify the error."
        )
        dump_state_fn = functools.partial(
            dump_backend_state, compiler_name=compiler_name
        )
        dump_state_fn(fx.GraphModule(gm, copy.deepcopy(gm.graph)), example_inputs)
        fails_fn = functools.partial(
            backend_fails,
            compiler_fn=compiler_fn,
            orig_failure=orig_failure,
        )
        minifier(
            gm,
            example_inputs,
            module_fails=fails_fn,
            dump_state=dump_state_fn,
        )
    return gm


@register_debug_backend
def dynamo_accuracy_minifier_backend(gm, example_inputs, compiler_name):
    from functorch.compile import minifier

    compiler_fn = lookup_backend(compiler_name)

    # Set the eval mode to remove randomness.
    gm.eval()

    # Check Accuracy
    if backend_accuracy_fails(
        gm, example_inputs, compiler_fn, only_fwd=config.repro_forward_only
    ):
        log.warning("Accuracy failed for the TorchDynamo produced graph")
        dump_state_fn = functools.partial(
            dump_backend_state, compiler_name=compiler_name, check_accuracy=True
        )
        fails_fn = functools.partial(
            backend_accuracy_fails,
            compiler_fn=compiler_fn,
            only_fwd=config.repro_forward_only,
        )
        dump_state_fn(fx.GraphModule(gm, copy.deepcopy(gm.graph)), example_inputs)
        minifier(
            gm,
            example_inputs,
            module_fails=fails_fn,
            dump_state=dump_state_fn,
        )
    else:
        log.error("Input graph does not fail accuracy testing")
    return gm