__init__.py 1.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647
  1. from contextlib import contextmanager
  2. import types
# Module-wide switch controlling whether ContextProp-backed flags may be
# assigned directly. The idea is to forbid bare assignment to
# torch.backends.<cudnn|mkldnn>.enabled and friends when running our test
# suite, where it is very easy to forget to undo such a change later.
# disable_global_flags() turns this off; __allow_nonbracketed_mutation()
# re-enables it temporarily and restores the previous value on exit.
__allow_nonbracketed_mutation_flag = True
  8. def disable_global_flags():
  9. global __allow_nonbracketed_mutation_flag
  10. __allow_nonbracketed_mutation_flag = False
  11. def flags_frozen():
  12. return not __allow_nonbracketed_mutation_flag
  13. @contextmanager
  14. def __allow_nonbracketed_mutation():
  15. global __allow_nonbracketed_mutation_flag
  16. old = __allow_nonbracketed_mutation_flag
  17. __allow_nonbracketed_mutation_flag = True
  18. try:
  19. yield
  20. finally:
  21. __allow_nonbracketed_mutation_flag = old
  22. class ContextProp:
  23. def __init__(self, getter, setter):
  24. self.getter = getter
  25. self.setter = setter
  26. def __get__(self, obj, objtype):
  27. return self.getter()
  28. def __set__(self, obj, val):
  29. if not flags_frozen():
  30. self.setter(val)
  31. else:
  32. raise RuntimeError("not allowed to set %s flags "
  33. "after disable_global_flags; please use flags() context manager instead" % obj.__name__)
  34. class PropModule(types.ModuleType):
  35. def __init__(self, m, name):
  36. super().__init__(name)
  37. self.m = m
  38. def __getattr__(self, attr):
  39. return self.m.__getattribute__(attr)