123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137 |
- import ast
- import functools
- import inspect
- from textwrap import dedent
- from typing import Any, List, NamedTuple, Optional, Tuple
- from torch._C import ErrorReport
- from torch._C._jit_tree_views import SourceRangeFactory
def get_source_lines_and_file(
    obj: Any,
    error_msg: Optional[str] = None,
) -> Tuple[List[str], int, Optional[str]]:
    """
    Fetch the source of ``obj`` via inspect.getsourcefile / inspect.getsourcelines.

    Args:
        obj: the object whose source is requested.
        error_msg: optional extra text appended (on its own line) to the
            error message when the source cannot be retrieved.

    Returns:
        (sourcelines, file_lineno, filename)

    Raises:
        OSError: when inspect raises OSError; the original error is chained
            as the cause.
    """
    try:
        source_file = inspect.getsourcefile(obj)
        lines, first_lineno = inspect.getsourcelines(obj)
    except OSError as exc:
        message = (
            f"Can't get source for {obj}. TorchScript requires source access in "
            "order to carry out compilation, make sure original .py files are "
            "available."
        )
        # Only append the caller's context when it is non-empty.
        if error_msg:
            message = message + "\n" + error_msg
        raise OSError(message) from exc
    return lines, first_lineno, source_file
def normalize_source_lines(sourcelines: List[str]) -> List[str]:
    """
    This helper function accepts a list of source lines. It finds the
    indentation level of the function definition (`def`), then it indents
    all lines in the function body to a point at or greater than that
    level. This allows for comments and continued string literals that
    are at a lower indentation than the rest of the code.

    Only a line whose first token is the ``def`` keyword is treated as the
    function definition. Lines that merely start with the letters "def"
    (for example a ``default=...`` keyword argument inside a multi-line
    decorator call) are not mistaken for it.

    Args:
        sourcelines: function source code, one entry per line (each entry
            typically keeps its trailing newline character)
    Returns:
        A list of source lines that have been correctly aligned. If no
        ``def`` line is found, the input list itself is returned unchanged.
    """

    def remove_prefix(text: str, prefix: str) -> str:
        # Same as str.removeprefix (Python 3.9+), spelled out explicitly.
        return text[len(prefix) :] if text.startswith(prefix) else text

    # Find the line containing the function definition. ``def`` must be
    # followed by whitespace so identifiers like ``default`` or ``define``
    # at the start of a line do not match.
    idx = None
    for i, line in enumerate(sourcelines):
        stripped = line.lstrip()
        if stripped.startswith("def") and stripped[3:4].isspace():
            idx = i
            break

    # This will happen when the function is a lambda- we won't find "def" anywhere in the source
    # lines in that case. Currently trying to JIT compile a lambda will throw an error up in
    # `parse_def()`, but we might want to handle this case in the future.
    if idx is None:
        return sourcelines

    # Leading whitespace of the `def` line (everything before the keyword).
    fn_def = sourcelines[idx]
    whitespace = fn_def[: len(fn_def) - len(fn_def.lstrip())]

    def align(line: str) -> str:
        # Re-prefix every other line so it sits at or beyond the def's indent.
        return whitespace + remove_prefix(line, whitespace)

    aligned_prefix = [align(s) for s in sourcelines[:idx]]
    aligned_suffix = [align(s) for s in sourcelines[idx + 1 :]]
    return aligned_prefix + [fn_def] + aligned_suffix
class SourceContext(SourceRangeFactory):
    """
    Thin wrapper around SourceRangeFactory that also stores extra metadata
    about the function-to-be-compiled.

    Args:
        source: source text handed to SourceRangeFactory.
        filename: file name handed to SourceRangeFactory and kept on the
            instance as ``self.filename``.
        file_lineno: line number handed to SourceRangeFactory.
        leading_whitespace_len: leading-whitespace length handed to
            SourceRangeFactory.
        uses_true_division: stored as ``self.uses_true_division``.
        funcname: stored as ``self.funcname``.
    """

    def __init__(
        self,
        source,
        filename,
        file_lineno,
        leading_whitespace_len,
        uses_true_division=True,
        funcname=None,
    ):
        # Initialize the native base first; the Python-level metadata is
        # attached afterwards, in the same order as before.
        super().__init__(source, filename, file_lineno, leading_whitespace_len)
        self.uses_true_division, self.filename, self.funcname = (
            uses_true_division,
            filename,
            funcname,
        )
@functools.lru_cache(maxsize=None)
def make_source_context(*args):
    """
    Memoized SourceContext constructor.

    Calls with identical positional arguments return the same SourceContext
    instance. The cache is unbounded (``maxsize=None``), so every distinct
    argument tuple, including the source text, stays referenced for as long
    as the cache lives.
    """
    context = SourceContext(*args)
    return context
def fake_range():
    """
    Build a placeholder range from an empty-source SourceContext.

    Useful where a source range is required but no real source exists.
    NOTE(review): the exact semantics of make_raw_range(0, 1) come from the
    native SourceRangeFactory base class; confirm there if relevant.
    """
    empty_context = SourceContext("", None, 0, 0)
    return empty_context.make_raw_range(0, 1)
class ParsedDef(NamedTuple):
    """Bundle of a parsed function definition and the location it was read from."""

    # Module AST produced by ast.parse on the dedented function source.
    ast: ast.Module
    # SourceContext describing the (non-dedented) source text.
    ctx: SourceContext
    # Normalized source text of the function, before dedenting.
    source: str
    # File the source was read from, if one could be determined.
    filename: Optional[str]
    # Line number reported for the start of the source lines.
    file_lineno: int
def parse_def(fn):
    """
    Read, normalize and parse the source of ``fn`` into a ParsedDef.

    Raises:
        OSError: propagated from get_source_lines_and_file when the source
            is unavailable.
        RuntimeError: if the dedented source is not exactly one top-level
            (non-async) function definition.
    """
    lines, lineno, fname = get_source_lines_and_file(fn, ErrorReport.call_stack())
    src = "".join(normalize_source_lines(lines))
    dedented = dedent(src)
    tree = ast.parse(dedented)

    top_level = tree.body
    is_single_function = len(top_level) == 1 and isinstance(
        top_level[0], ast.FunctionDef
    )
    if not is_single_function:
        raise RuntimeError(f"Expected a single top-level function: {fname}:{lineno}")

    # Number of columns dedent removed, measured on the first line.
    first_original = src.split("\n", 1)[0]
    first_dedented = dedented.split("\n", 1)[0]
    stripped_indent = len(first_original) - len(first_dedented)

    # Arguments stay positional so the make_source_context cache key is unchanged.
    ctx = make_source_context(src, fname, lineno, stripped_indent, True, fn.__name__)
    return ParsedDef(tree, ctx, src, fname, lineno)
|