local.py 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. import threading
  2. from contextlib import contextmanager
  3. from typing import Iterator, Optional
  4. # Simple dynamic scoping implementation. The name "parametrize" comes
  5. # from Racket.
  6. #
# WARNING WARNING: LOOKING TO EDIT THIS FILE? Think carefully about
# why you need to add a toggle to the global behavior of code
# generation. The parameters here should really only be used
# for "temporary" situations, where we need to temporarily change
# the codegen in some cases because we cannot conveniently update
# all call sites. Each parameter is slated to be eliminated once all
# of its call sites have been updated. If you don't have a plan for how
# to get there, DON'T add a new entry here.
  15. class Locals(threading.local):
  16. use_const_ref_for_mutable_tensors: Optional[bool] = None
  17. use_ilistref_for_tensor_lists: Optional[bool] = None
  18. _locals = Locals()
  19. def use_const_ref_for_mutable_tensors() -> bool:
  20. assert _locals.use_const_ref_for_mutable_tensors is not None, (
  21. "need to initialize local.use_const_ref_for_mutable_tensors with "
  22. "local.parametrize"
  23. )
  24. return _locals.use_const_ref_for_mutable_tensors
  25. def use_ilistref_for_tensor_lists() -> bool:
  26. assert _locals.use_ilistref_for_tensor_lists is not None, (
  27. "need to initialize local.use_ilistref_for_tensor_lists with "
  28. "local.parametrize"
  29. )
  30. return _locals.use_ilistref_for_tensor_lists
  31. @contextmanager
  32. def parametrize(
  33. *, use_const_ref_for_mutable_tensors: bool, use_ilistref_for_tensor_lists: bool
  34. ) -> Iterator[None]:
  35. old_use_const_ref_for_mutable_tensors = _locals.use_const_ref_for_mutable_tensors
  36. old_use_ilistref_for_tensor_lists = _locals.use_ilistref_for_tensor_lists
  37. try:
  38. _locals.use_const_ref_for_mutable_tensors = use_const_ref_for_mutable_tensors
  39. _locals.use_ilistref_for_tensor_lists = use_ilistref_for_tensor_lists
  40. yield
  41. finally:
  42. _locals.use_const_ref_for_mutable_tensors = (
  43. old_use_const_ref_for_mutable_tensors
  44. )
  45. _locals.use_ilistref_for_tensor_lists = old_use_ilistref_for_tensor_lists