mutation_guard.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. import functools
  2. import weakref
  3. import torch.nn
  4. from torch.nn import Module
  5. from .utils import ExactWeakKeyDictionary
  6. class MutationTracker:
  7. db = ExactWeakKeyDictionary()
  8. def __init__(self):
  9. self.mutation_count = 0
  10. self.watchers = []
  11. def on_mutation(self, name):
  12. self.mutation_count += 1
  13. tmp = self.watchers
  14. self.watchers = []
  15. for ref in tmp:
  16. guarded = ref()
  17. if guarded is not None:
  18. guarded.invalidate(ref)
  19. def track(self, guarded_code):
  20. self.watchers.append(weakref.ref(guarded_code))
  21. def watch(obj, guarded_code):
  22. """invalidate guarded_code when obj is mutated"""
  23. ensure_patched(type(obj))
  24. if obj not in MutationTracker.db:
  25. MutationTracker.db[obj] = MutationTracker()
  26. tracker = MutationTracker.db[obj]
  27. tracker.track(guarded_code)
  28. def ensure_patched(cls):
  29. if getattr(cls, "___needs_mutation_patch", True):
  30. cls.___needs_mutation_patch = False
  31. original_setattr = cls.__setattr__
  32. @functools.wraps(original_setattr)
  33. def custom_setattr(self, key, value):
  34. try:
  35. MutationTracker.db[self].on_mutation(key)
  36. except KeyError:
  37. pass
  38. return original_setattr(self, key, value)
  39. cls.__setattr__ = custom_setattr
  40. class GenerationTracker:
  41. generation = 0
  42. dynamic_classes = ExactWeakKeyDictionary()
  43. generation_values = ExactWeakKeyDictionary()
  44. @classmethod
  45. def tag(cls, obj):
  46. cls.generation_values[obj] = cls.generation
  47. @staticmethod
  48. def mark_class_dynamic(cls):
  49. assert issubclass(cls, torch.nn.Module)
  50. GenerationTracker.dynamic_classes[cls] = True
  51. @classmethod
  52. def get_generation_value(cls, obj):
  53. if obj not in cls.generation_values:
  54. return -1
  55. return cls.generation_values[obj]
  56. @classmethod
  57. def check(cls, obj):
  58. return (
  59. obj in cls.generation_values
  60. and cls.generation_values[obj] == cls.generation
  61. )
  62. def is_dynamic_nn_module(obj):
  63. """Check for nn.Modules() created dynamically or mutated"""
  64. if hasattr(obj, "torchdynamo_force_dynamic"):
  65. return obj.torchdynamo_force_dynamic
  66. dyn = GenerationTracker.dynamic_classes.get(type(obj)) or GenerationTracker.check(
  67. obj
  68. )
  69. return dyn
  70. def install_generation_tagging_init():
  71. """
  72. Monkey patch torch.nn.Module.__init__ and torch.nn.Module.__setstate__
  73. so we can detect nn.Module instances created dynamically inside forward methods.
  74. """
  75. if getattr(Module, "___needs_generation_tag_patch", True):
  76. init = Module.__init__
  77. def patched_init(self, *args, **kwargs):
  78. init(self, *args, **kwargs)
  79. GenerationTracker.tag(self)
  80. Module.__init__ = patched_init
  81. setstate = Module.__setstate__
  82. def patched_setstate(self, state):
  83. setstate(self, state)
  84. GenerationTracker.tag(self)
  85. Module.__setstate__ = patched_setstate
  86. Module.___needs_generation_tag_patch = False
  87. GenerationTracker.generation += 1