utils.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. import inspect
  2. from typing import Any, Callable, Dict, Mapping, Tuple
  3. from torch.onnx._internal import _beartype
  4. from torch.onnx._internal.diagnostics.infra import _infra, formatter
  5. @_beartype.beartype
  6. def python_frame(frame: inspect.FrameInfo) -> _infra.StackFrame:
  7. """Returns a StackFrame for the given inspect.FrameInfo."""
  8. snippet = (
  9. frame.code_context[frame.index].strip()
  10. if frame.code_context is not None and frame.index is not None
  11. else None
  12. )
  13. return _infra.StackFrame(
  14. location=_infra.Location(
  15. uri=frame.filename,
  16. line=frame.lineno,
  17. snippet=snippet,
  18. function=frame.function,
  19. message=snippet,
  20. )
  21. )
  22. @_beartype.beartype
  23. def python_call_stack(frames_to_skip: int = 0, frames_to_log: int = 16) -> _infra.Stack:
  24. """Returns the current Python call stack."""
  25. if frames_to_skip < 0:
  26. raise ValueError("frames_to_skip must be non-negative")
  27. if frames_to_log < 0:
  28. raise ValueError("frames_to_log must be non-negative")
  29. frames_to_skip += 2 # Skip this function and beartype.
  30. stack = _infra.Stack()
  31. stack.frames = [
  32. python_frame(frame)
  33. # TODO(bowbao): Rewrite with 'traceback' to speedup performance.
  34. # Reference code: `torch/fx/proxy.py`.
  35. # `inspect.stack(0)` will speedup the call greatly, but loses line snippet.
  36. for frame in inspect.stack()[frames_to_skip : frames_to_skip + frames_to_log]
  37. ]
  38. stack.message = "Python call stack"
  39. return stack
  40. @_beartype.beartype
  41. def function_location(fn: Callable) -> _infra.Location:
  42. """Returns a Location for the given function."""
  43. source_lines, lineno = inspect.getsourcelines(fn)
  44. snippet = source_lines[0].strip() if len(source_lines) > 0 else "<unknown>"
  45. return _infra.Location(
  46. uri=inspect.getsourcefile(fn),
  47. line=lineno,
  48. snippet=snippet,
  49. message=formatter.display_name(fn),
  50. )
  51. @_beartype.beartype
  52. def function_state(
  53. fn: Callable, args: Tuple[Any, ...], kwargs: Dict[str, Any]
  54. ) -> Mapping[str, Any]:
  55. bind = inspect.signature(fn).bind(*args, **kwargs)
  56. return bind.arguments