_stats.py 810 B

123456789101112131415161718192021
  1. # NOTE! PLEASE KEEP THIS FILE *FREE* OF TORCH DEPS! IT SHOULD BE IMPORTABLE ANYWHERE.
  2. # IF YOU FEEL AN OVERWHELMING URGE TO ADD A TORCH DEP, MAKE A TRAMPOLINE FILE A LA torch._dynamo.utils
  3. # AND SCRUB AWAY TORCH NOTIONS THERE.
  4. import collections
  5. import functools
  6. from typing import OrderedDict
  7. simple_call_counter: OrderedDict[str, int] = collections.OrderedDict()
  8. def count_label(label):
  9. prev = simple_call_counter.setdefault(label, 0)
  10. simple_call_counter[label] = prev + 1
  11. def count(fn):
  12. @functools.wraps(fn)
  13. def wrapper(*args, **kwargs):
  14. if fn.__qualname__ not in simple_call_counter:
  15. simple_call_counter[fn.__qualname__] = 0
  16. simple_call_counter[fn.__qualname__] = simple_call_counter[fn.__qualname__] + 1
  17. return fn(*args, **kwargs)
  18. return wrapper