_sources.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. import ast
  2. import functools
  3. import inspect
  4. from textwrap import dedent
  5. from typing import Any, List, NamedTuple, Optional, Tuple
  6. from torch._C import ErrorReport
  7. from torch._C._jit_tree_views import SourceRangeFactory
  8. def get_source_lines_and_file(
  9. obj: Any,
  10. error_msg: Optional[str] = None,
  11. ) -> Tuple[List[str], int, Optional[str]]:
  12. """
  13. Wrapper around inspect.getsourcelines and inspect.getsourcefile.
  14. Returns: (sourcelines, file_lino, filename)
  15. """
  16. filename = None # in case getsourcefile throws
  17. try:
  18. filename = inspect.getsourcefile(obj)
  19. sourcelines, file_lineno = inspect.getsourcelines(obj)
  20. except OSError as e:
  21. msg = (
  22. f"Can't get source for {obj}. TorchScript requires source access in "
  23. "order to carry out compilation, make sure original .py files are "
  24. "available."
  25. )
  26. if error_msg:
  27. msg += "\n" + error_msg
  28. raise OSError(msg) from e
  29. return sourcelines, file_lineno, filename
  30. def normalize_source_lines(sourcelines: List[str]) -> List[str]:
  31. """
  32. This helper function accepts a list of source lines. It finds the
  33. indentation level of the function definition (`def`), then it indents
  34. all lines in the function body to a point at or greater than that
  35. level. This allows for comments and continued string literals that
  36. are at a lower indentation than the rest of the code.
  37. Args:
  38. sourcelines: function source code, separated into lines by
  39. the '\n' character
  40. Returns:
  41. A list of source lines that have been correctly aligned
  42. """
  43. def remove_prefix(text, prefix):
  44. return text[text.startswith(prefix) and len(prefix) :]
  45. # Find the line and line number containing the function definition
  46. idx = None
  47. for i, l in enumerate(sourcelines):
  48. if l.lstrip().startswith("def"):
  49. idx = i
  50. break
  51. # This will happen when the function is a lambda- we won't find "def" anywhere in the source
  52. # lines in that case. Currently trying to JIT compile a lambda will throw an error up in
  53. # `parse_def()`, but we might want to handle this case in the future.
  54. if idx is None:
  55. return sourcelines
  56. # Get a string representing the amount of leading whitespace
  57. fn_def = sourcelines[idx]
  58. whitespace = fn_def.split("def")[0]
  59. # Add this leading whitespace to all lines before and after the `def`
  60. aligned_prefix = [
  61. whitespace + remove_prefix(s, whitespace) for s in sourcelines[:idx]
  62. ]
  63. aligned_suffix = [
  64. whitespace + remove_prefix(s, whitespace) for s in sourcelines[idx + 1 :]
  65. ]
  66. # Put it together again
  67. aligned_prefix.append(fn_def)
  68. return aligned_prefix + aligned_suffix
  69. # Thin wrapper around SourceRangeFactory to store extra metadata
  70. # about the function-to-be-compiled.
  71. class SourceContext(SourceRangeFactory):
  72. def __init__(
  73. self,
  74. source,
  75. filename,
  76. file_lineno,
  77. leading_whitespace_len,
  78. uses_true_division=True,
  79. funcname=None,
  80. ):
  81. super().__init__(source, filename, file_lineno, leading_whitespace_len)
  82. self.uses_true_division = uses_true_division
  83. self.filename = filename
  84. self.funcname = funcname
  85. @functools.lru_cache(maxsize=None)
  86. def make_source_context(*args):
  87. return SourceContext(*args)
  88. def fake_range():
  89. return SourceContext("", None, 0, 0).make_raw_range(0, 1)
class ParsedDef(NamedTuple):
    """Bundle of results returned by parse_def for a single function."""

    # Module node produced by ast.parse on the dedented source.
    ast: ast.Module
    # Source context built for the function (see make_source_context).
    ctx: SourceContext
    # Normalized source text, before dedenting.
    source: str
    # File the source came from, or None when it is unknown.
    filename: Optional[str]
    # Line number reported by inspect.getsourcelines for the function.
    file_lineno: int
  96. def parse_def(fn):
  97. sourcelines, file_lineno, filename = get_source_lines_and_file(
  98. fn, ErrorReport.call_stack()
  99. )
  100. sourcelines = normalize_source_lines(sourcelines)
  101. source = "".join(sourcelines)
  102. dedent_src = dedent(source)
  103. py_ast = ast.parse(dedent_src)
  104. if len(py_ast.body) != 1 or not isinstance(py_ast.body[0], ast.FunctionDef):
  105. raise RuntimeError(
  106. f"Expected a single top-level function: {filename}:{file_lineno}"
  107. )
  108. leading_whitespace_len = len(source.split("\n", 1)[0]) - len(
  109. dedent_src.split("\n", 1)[0]
  110. )
  111. ctx = make_source_context(
  112. source, filename, file_lineno, leading_whitespace_len, True, fn.__name__
  113. )
  114. return ParsedDef(py_ast, ctx, source, filename, file_lineno)