"""JIT-related state.

This module stores various pieces of Python-global state relating to the JIT.

This is not intended to be imported directly; please use the exposed
functionalities in `torch.jit`.
"""
import torch
import os
import weakref
class EnabledProxy:
    """Mutable wrapper around the JIT on/off flag.

    A bare bool cannot be shared by reference between modules, so this tiny
    object carries one instead; callers flip ``self.enabled`` in place.
    """

    def __init__(self):
        self.enabled = self.parse_env(
            "PYTORCH_JIT", True, "> Using PyTorch JIT", "> PyTorch JIT DISABLED"
        )

    def parse_env(self, name, default, true_message, false_message):
        """Interpret the environment variable *name* as a boolean.

        Unset returns *default*.  The spellings 1/true/yes and 0/false/no
        (case-insensitive) map to True/False.  The verbose forms "1v"/"0v"
        additionally print *true_message*/*false_message*.  Anything else
        raises ValueError.
        """
        setting = os.environ.get(name)
        if setting is None:
            return default
        lowered = setting.lower()
        if lowered in {"1", "true", "yes"}:
            return True
        if lowered in {"0", "false", "no"}:
            return False
        # Verbose variants: same meaning, but announce the choice on stdout.
        if setting == "1v":
            print(true_message)
            return True
        if setting == "0v":
            print(false_message)
            return False
        raise ValueError("Unknown setting of {}. Try using 0 or 1.".format(name))

    def __bool__(self):
        return self.enabled
# Module-wide JIT flag; shared by reference, toggled via enable()/disable().
_enabled = EnabledProxy()


def disable():
    """Turn the JIT off globally."""
    _enabled.enabled = False


def enable():
    """Turn the JIT on globally."""
    _enabled.enabled = True
# The Python CompilationUnit. All functions and modules defined in Python will
# live in here. It's defined in Python because doing in cpp creates static
# destruction order issues.
_python_cu = torch._C.CompilationUnit()

# python class => ScriptClass mapping
_script_classes = {}
# reverse of the above: script class qualified name => python class
_name_to_pyclass = {}
def _add_script_class(python_class, script_class):
    """Register *script_class* as the compiled form of *python_class*.

    Keeps both registries in sync: forward (python class -> script class)
    and reverse (qualified name -> python class).
    """
    qualified = script_class.qualified_name()
    _script_classes[python_class] = script_class
    _name_to_pyclass[qualified] = python_class
def _get_script_class(python_class):
    """Return the compiled ScriptClass for *python_class*, or None.

    A ``_jit_override_qualname`` attribute redirects the lookup to the
    Python class registered under that qualified name.
    """
    override = getattr(python_class, "_jit_override_qualname", None)
    lookup_key = python_class if override is None else _get_python_class(override)
    return _script_classes.get(lookup_key, None)
def _get_python_class(qualified_name):
    """Return the Python class registered under *qualified_name*, or None."""
    return _name_to_pyclass.get(qualified_name)
def _clear_class_state():
    """Empty both script-class registries."""
    for registry in (_script_classes, _name_to_pyclass):
        registry.clear()
# Caching: we currently cache compilation of free functions and overloaded functions.
# To cache free functions we hold a weak ref to the function object and
# map to the compiled fn's qualified name.
# To cache overloaded functions we hold a weak ref to the function obj and
# map to all of its overloaded compiled fns.
# In the future we could consider caching more types of objects so that
# aliasing is preserved across separate compilations of the same object.
# function object (weak key) -> qualified name of its compiled function
_jit_caching_layer: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
# function object (weak key) -> qualified names of all its compiled overloads
_jit_function_overload_caching: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
def _try_get_jit_cached_overloads(key):
    """Return the list of compiled overloads cached for *key*, or None on a miss."""
    cached_names = _jit_function_overload_caching.get(key, None)
    if not cached_names:
        return None
    # Resolve each stored qualified name back to its compiled function.
    return [_python_cu.find_function(name) for name in cached_names]
def _set_jit_overload_cache(key, compiled_fns):
    """Cache the qualified names of *compiled_fns* under the weak key *key*."""
    qual_names = [overload.qualified_name for overload in compiled_fns]
    _jit_function_overload_caching[key] = qual_names
def _try_get_jit_cached_function(key):
    """Return the compiled function cached for *key*, or None.

    Honors a per-function ``__disable_jit_function_caching__`` opt-out flag.
    """
    if getattr(key, "__disable_jit_function_caching__", False) is True:
        return None
    qual_name = _jit_caching_layer.get(key, None)
    return _python_cu.find_function(qual_name) if qual_name else None
def _set_jit_function_cache(key, value):
    """Cache the compiled ScriptFunction *value* under the weak key *key*."""
    # only free functions currently supported
    assert isinstance(value, torch.jit.ScriptFunction)
    _jit_caching_layer[key] = value.qualified_name