common.py

import functools
import logging

import torch
from torch._dynamo import eval_frame
from torch._dynamo.utils import counters
from torch._functorch.aot_autograd import aot_module_simplified
from torch._subclasses import FakeTensor
from torch.utils._python_dispatch import _disable_current_modes

log = logging.getLogger(__name__)


def aot_autograd(**kwargs):
    def compiler_fn(gm: torch.fx.GraphModule, example_inputs):
        import functorch.compile

        # Hack to get around circular import problems with aot_eager_decomp_partition
        if callable(kwargs.get("decompositions")):
            kwargs["decompositions"] = kwargs["decompositions"]()

        # TODO: stop monkeypatching here (without even cleaning up, UGH!)
        functorch.compile.config.use_functionalize = True
        functorch.compile.config.use_fake_tensor = True

        counters["aot_autograd"]["total"] += 1
        use_fallback = False

        if use_fallback:
            log.debug("Unable to use AOT Autograd because graph has mutation")
            counters["aot_autograd"]["not_ok"] += 1
            return gm

        # OK attempt to compile

        def _wrapped_bw_compiler(*args, **kwargs):
            # stop TorchDynamo from trying to compile our generated backwards pass
            return eval_frame.disable(eval_frame.disable(bw_compiler)(*args, **kwargs))

        bw_compiler = kwargs.get("bw_compiler") or kwargs["fw_compiler"]
        kwargs["bw_compiler"] = _wrapped_bw_compiler

        from torch._inductor.debug import enable_aot_logging

        try:
            # NB: NOT cloned!
            with enable_aot_logging():
                cg = aot_module_simplified(gm, example_inputs, **kwargs)
                counters["aot_autograd"]["ok"] += 1
                return eval_frame.disable(cg)
        except Exception:
            counters["aot_autograd"]["not_ok"] += 1
            raise

    return compiler_fn
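

# Illustrative usage sketch, not part of the original module: aot_autograd() is a
# backend *factory* -- calling it returns compiler_fn, which TorchDynamo invokes
# with an FX graph and example inputs.  The helper name below is hypothetical;
# `nop` is the trivial compiler exported by functorch.compile, and bw_compiler
# falls back to fw_compiler when it is not given (see above).
def _example_aot_autograd_backend():
    from functorch.compile import nop

    backend = aot_autograd(fw_compiler=nop)

    def fn(x):
        return torch.sin(x) + x

    # Assumes a PyTorch build that ships torch.compile; any callable backend works.
    opt_fn = torch.compile(fn, backend=backend)
    return opt_fn(torch.randn(4, requires_grad=True))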


def mem_efficient_fusion_kwargs(use_decomps):
    from functorch.compile import (
        default_decompositions,
        min_cut_rematerialization_partition,
        ts_compile,
    )

    kwargs = {
        # these are taken from memory_efficient_fusion()
        "fw_compiler": ts_compile,
        "bw_compiler": ts_compile,
        "partition_fn": min_cut_rematerialization_partition,
    }

    if use_decomps:
        kwargs["decompositions"] = default_decompositions

    return kwargs
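

# Illustrative sketch, not part of the original module: these kwargs are meant to
# be splatted into aot_autograd() to build a memory_efficient_fusion()-style
# backend (TorchScript compilation plus min-cut rematerialization partitioning).
# The helper name is hypothetical.
def _example_mem_efficient_fusion_backend():
    backend = aot_autograd(**mem_efficient_fusion_kwargs(use_decomps=True))
    return backend  # e.g. torch.compile(fn, backend=backend)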


def fake_tensor_unsupported(fn):
    """
    Decorator for backends that need real inputs.  We swap out fake
    tensors for zero tensors.
    """

    def defake(x):
        if not isinstance(x, FakeTensor):
            return x
        if x._has_symbolic_sizes_strides:
            size = [s.node.shape_env.size_hint(s.node.expr) for s in x.size()]
            stride = [s.node.shape_env.size_hint(s.node.expr) for s in x.stride()]
        else:
            size = x.size()
            stride = x.stride()
        y = torch.empty_strided(
            size,
            stride,
            dtype=x.dtype,
            device=x.device,
            requires_grad=x.requires_grad,
        )
        y.zero_()
        return y

    @functools.wraps(fn)
    def wrapper(model, inputs, **kwargs):
        with _disable_current_modes():
            inputs = list(map(defake, inputs))
            return fn(model, inputs, **kwargs)

    return wrapper
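

# Illustrative sketch, not part of the original module: a backend that actually
# runs the GraphModule needs real tensors, so it is decorated with
# fake_tensor_unsupported; FakeTensor example inputs are then replaced by
# zero-filled real tensors with matching size/stride/dtype/device.  The backend
# name below is hypothetical.
@fake_tensor_unsupported
def _example_eager_backend(gm: torch.fx.GraphModule, example_inputs):
    gm(*example_inputs)  # safe: inputs here are real (zeroed) tensors
    return gm.forward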


def device_from_inputs(example_inputs) -> torch.device:
    for x in example_inputs:
        if hasattr(x, "device"):
            return x.device


def dtype_from_inputs(example_inputs) -> torch.dtype:
    for x in example_inputs:
        if hasattr(x, "dtype"):
            return x.dtype
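

# Illustrative sketch, not part of the original module: backends use these helpers
# to pick a representative device/dtype from the example inputs; both fall through
# to None when no input carries the attribute.  The helper name is hypothetical.
def _example_inspect_inputs(example_inputs):
    device = device_from_inputs(example_inputs)  # e.g. torch.device("cuda", 0)
    dtype = dtype_from_inputs(example_inputs)  # e.g. torch.float16
    return device, dtype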