exc.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. import os
  2. import tempfile
  3. import textwrap
  4. from functools import lru_cache
  5. if os.environ.get("TORCHINDUCTOR_WRITE_MISSING_OPS") == "1":
  6. @lru_cache(None)
  7. def _record_missing_op(target):
  8. with open(f"{tempfile.gettempdir()}/missing_ops.txt", "a") as fd:
  9. fd.write(str(target) + "\n")
  10. else:
  11. def _record_missing_op(target):
  12. pass
  13. class OperatorIssue(RuntimeError):
  14. @staticmethod
  15. def operator_str(target, args, kwargs):
  16. lines = [f"target: {target}"] + [
  17. f"args[{i}]: {arg}" for i, arg in enumerate(args)
  18. ]
  19. if kwargs:
  20. lines.append(f"kwargs: {kwargs}")
  21. return textwrap.indent("\n".join(lines), " ")
  22. class MissingOperatorWithoutDecomp(OperatorIssue):
  23. def __init__(self, target, args, kwargs):
  24. _record_missing_op(target)
  25. super().__init__(f"missing lowering\n{self.operator_str(target, args, kwargs)}")
  26. class MissingOperatorWithDecomp(OperatorIssue):
  27. def __init__(self, target, args, kwargs):
  28. _record_missing_op(target)
  29. super().__init__(
  30. f"missing decomposition\n{self.operator_str(target, args, kwargs)}"
  31. + textwrap.dedent(
  32. f"""
  33. There is a decomposition available for {target} in
  34. torch._decomp.get_decompositions(). Please add this operator to the
  35. `decompositions` list in torch._inductor.decompositions
  36. """
  37. )
  38. )
  39. class LoweringException(OperatorIssue):
  40. def __init__(self, exc, target, args, kwargs):
  41. super().__init__(
  42. f"{type(exc).__name__}: {exc}\n{self.operator_str(target, args, kwargs)}"
  43. )
  44. class InvalidCxxCompiler(RuntimeError):
  45. def __init__(self):
  46. from . import config
  47. super().__init__(
  48. f"No working C++ compiler found in {config.__name__}.cpp.cxx: {config.cpp.cxx}"
  49. )
  50. class CppCompileError(RuntimeError):
  51. def __init__(self, cmd, output):
  52. super().__init__(
  53. textwrap.dedent(
  54. """
  55. C++ compile error
  56. Command:
  57. {cmd}
  58. Output:
  59. {output}
  60. """
  61. )
  62. .strip()
  63. .format(cmd=" ".join(cmd), output=output.decode("utf-8"))
  64. )