1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556 |
- import threading
- from contextlib import contextmanager
- from typing import Iterator, Optional
# Simple dynamic scoping implementation. The name "parametrize" comes
# from Racket.
#
# WARNING WARNING: LOOKING TO EDIT THIS FILE? Think carefully about
# why you need to add a toggle to the global behavior of code
# generation. The parameters here should really only be used
# for "temporary" situations, where we need to temporarily change
# the codegen in some cases because we cannot conveniently update
# all call sites, and are slated to be eliminated once all call
# sites are updated. If you don't have a plan for how to get there,
# DON'T add a new entry here.
class Locals(threading.local):
    """Per-thread storage for the codegen toggles configured by ``parametrize``.

    Because this subclasses ``threading.local``, any attribute assigned on an
    instance is visible only to the thread that assigned it. The class-level
    ``None`` defaults are what a thread observes before it has assigned a
    value of its own.
    """

    # Unset (None) until assigned for the current thread.
    use_const_ref_for_mutable_tensors: Optional[bool] = None
    # Unset (None) until assigned for the current thread.
    use_ilistref_for_tensor_lists: Optional[bool] = None


# Single module-level instance; each thread reads and writes its own copy of
# the flags through it.
_locals: Locals = Locals()
def use_const_ref_for_mutable_tensors() -> bool:
    """Return the ``use_const_ref_for_mutable_tensors`` flag for this thread.

    The value is read from the module-level ``_locals`` storage.

    Raises:
        AssertionError: if the flag is still ``None`` (unset), typically
            because the caller is not running inside ``parametrize(...)``.
    """
    value = _locals.use_const_ref_for_mutable_tensors
    # Explicit check instead of a bare ``assert``: asserts are stripped under
    # ``python -O``, which would let ``None`` escape despite the ``-> bool``
    # annotation. AssertionError is kept so callers see the same exception.
    if value is None:
        raise AssertionError(
            "need to initialize local.use_const_ref_for_mutable_tensors with "
            "local.parametrize"
        )
    return value
def use_ilistref_for_tensor_lists() -> bool:
    """Return the ``use_ilistref_for_tensor_lists`` flag for this thread.

    The value is read from the module-level ``_locals`` storage.

    Raises:
        AssertionError: if the flag is still ``None`` (unset), typically
            because the caller is not running inside ``parametrize(...)``.
    """
    value = _locals.use_ilistref_for_tensor_lists
    # Explicit check instead of a bare ``assert``: asserts are stripped under
    # ``python -O``, which would let ``None`` escape despite the ``-> bool``
    # annotation. AssertionError is kept so callers see the same exception.
    if value is None:
        raise AssertionError(
            "need to initialize local.use_ilistref_for_tensor_lists with "
            "local.parametrize"
        )
    return value
@contextmanager
def parametrize(
    *, use_const_ref_for_mutable_tensors: bool, use_ilistref_for_tensor_lists: bool
) -> Iterator[None]:
    """Temporarily set both codegen flags on ``_locals`` for this thread.

    The values in effect before entry are captured up front and written back
    when the ``with`` block exits, whether it exits normally or by raising.
    This lets ``parametrize`` blocks nest.

    Args:
        use_const_ref_for_mutable_tensors: value to install for the duration
            of the block.
        use_ilistref_for_tensor_lists: value to install for the duration of
            the block.
    """
    # Snapshot the outer values (read in the same order they are restored).
    saved = (
        _locals.use_const_ref_for_mutable_tensors,
        _locals.use_ilistref_for_tensor_lists,
    )
    try:
        _locals.use_const_ref_for_mutable_tensors = use_const_ref_for_mutable_tensors
        _locals.use_ilistref_for_tensor_lists = use_ilistref_for_tensor_lists
        yield
    finally:
        # Tuple-unpacking assigns left to right, restoring both flags.
        (
            _locals.use_const_ref_for_mutable_tensors,
            _locals.use_ilistref_for_tensor_lists,
        ) = saved
|