123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119 |
- import functools
- import weakref
- import torch.nn
- from torch.nn import Module
- from .utils import ExactWeakKeyDictionary
class MutationTracker:
    """
    Per-object bookkeeping of attribute mutations.

    Counts how many mutations have been reported for one object and holds
    weak references to the guarded code that must be invalidated when the
    next mutation is reported.
    """

    # Maps each watched object to its MutationTracker (populated by watch()).
    db = ExactWeakKeyDictionary()

    def __init__(self):
        self.mutation_count = 0
        self.watchers = []

    def on_mutation(self, name):
        """Record a mutation and invalidate every still-alive watcher.

        ``name`` is the attribute being set; it is accepted but not used.
        """
        self.mutation_count += 1
        # Detach the current watcher list before notifying, so the tracker
        # starts over with an empty list.
        pending, self.watchers = self.watchers, []
        for watcher_ref in pending:
            target = watcher_ref()
            if target is None:
                # The guarded code was already garbage collected.
                continue
            target.invalidate(watcher_ref)

    def track(self, guarded_code):
        """Register ``guarded_code`` (held weakly) for the next mutation."""
        self.watchers.append(weakref.ref(guarded_code))
def watch(obj, guarded_code):
    """Arrange for guarded_code to be invalidated when obj is mutated."""
    # Mutations are observed through a patched __setattr__ on obj's class.
    ensure_patched(type(obj))
    registry = MutationTracker.db
    if obj not in registry:
        registry[obj] = MutationTracker()
    registry[obj].track(guarded_code)
def ensure_patched(cls):
    """
    Wrap ``cls.__setattr__`` once so attribute writes notify mutation trackers.

    After patching, every ``setattr`` on an instance of ``cls`` first calls
    ``on_mutation(key)`` on the instance's entry in ``MutationTracker.db``
    (if it has one) and then delegates to the original ``__setattr__``.

    The ``___needs_mutation_patch`` flag is read with ``getattr``, so a class
    that inherits the flag from an already patched base is treated as patched.
    """
    if not getattr(cls, "___needs_mutation_patch", True):
        return
    cls.___needs_mutation_patch = False
    original_setattr = cls.__setattr__

    @functools.wraps(original_setattr)
    def custom_setattr(self, key, value):
        # Only the lookup is guarded: a KeyError raised while a watcher is
        # being invalidated is a real error and must not be mistaken for
        # "this object is not tracked".
        try:
            tracker = MutationTracker.db[self]
        except KeyError:
            # Object is not being watched; nothing to notify.
            pass
        else:
            tracker.on_mutation(key)
        return original_setattr(self, key, value)

    cls.__setattr__ = custom_setattr
class GenerationTracker:
    """
    Class-level registry that stamps objects with a generation number.

    ``generation`` is a shared counter; ``tag`` records its current value for
    an object, and ``check`` reports whether an object's recorded value equals
    the current counter.
    """

    # Current generation counter, shared by all users of this class.
    generation = 0
    # Classes explicitly marked dynamic via mark_class_dynamic().
    dynamic_classes = ExactWeakKeyDictionary()
    # Generation recorded for each tagged object.
    generation_values = ExactWeakKeyDictionary()

    @classmethod
    def tag(cls, obj):
        """Stamp ``obj`` with the current generation."""
        cls.generation_values[obj] = cls.generation

    @staticmethod
    def mark_class_dynamic(cls):
        """Mark the ``torch.nn.Module`` subclass ``cls`` as dynamic."""
        assert issubclass(cls, torch.nn.Module)
        GenerationTracker.dynamic_classes[cls] = True

    @classmethod
    def get_generation_value(cls, obj):
        """Return the generation recorded for ``obj``, or -1 if untagged."""
        if obj in cls.generation_values:
            return cls.generation_values[obj]
        return -1

    @classmethod
    def check(cls, obj):
        """Return True when ``obj`` was tagged in the current generation."""
        if obj not in cls.generation_values:
            return False
        return cls.generation_values[obj] == cls.generation
def is_dynamic_nn_module(obj):
    """Check for nn.Modules() created dynamically or mutated"""
    # An explicit per-object override wins over everything else.
    if hasattr(obj, "torchdynamo_force_dynamic"):
        return obj.torchdynamo_force_dynamic
    # Next, honour classes that were marked dynamic as a whole.
    class_flag = GenerationTracker.dynamic_classes.get(type(obj))
    if class_flag:
        return class_flag
    # Otherwise the object is dynamic iff it was tagged in the current generation.
    return GenerationTracker.check(obj)
def install_generation_tagging_init():
    """
    Patch torch.nn.Module so new and unpickled instances get generation-tagged.

    Module.__init__ and Module.__setstate__ are wrapped (only on the first
    call) so that each one tags the instance with GenerationTracker after the
    original method has run. Every call, patched or not, then advances
    GenerationTracker.generation by one.
    """
    needs_patch = getattr(Module, "___needs_generation_tag_patch", True)
    if needs_patch:
        original_init = Module.__init__

        def patched_init(self, *args, **kwargs):
            original_init(self, *args, **kwargs)
            GenerationTracker.tag(self)

        Module.__init__ = patched_init

        original_setstate = Module.__setstate__

        def patched_setstate(self, state):
            original_setstate(self, state)
            GenerationTracker.tag(self)

        Module.__setstate__ = patched_setstate

        # Remember that the wrappers are installed so they are never stacked.
        setattr(Module, "___needs_generation_tag_patch", False)

    # Objects tagged before this call no longer match the current generation.
    GenerationTracker.generation += 1
|