utils.py

from typing import Any, Dict, Optional

from torch import nn

__all__ = [
    "module_to_fqn",
    "fqn_to_module",
    "get_arg_info_from_tensor_fqn",
    "FakeSparsity",
]


def module_to_fqn(model: nn.Module, module: nn.Module, prefix: str = "") -> Optional[str]:
    """
    Returns the fqn for a module, or None if `module` is not a descendant of `model`.
    """
    if module is model:
        return ""
    for name, child in model.named_children():
        # Pass "." as the prefix so that nested names are joined with a dot;
        # the top-level call uses the default "" and so has no leading dot.
        fqn = module_to_fqn(child, module, ".")
        if isinstance(fqn, str):
            return prefix + name + fqn
    return None
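

# Illustrative example (a hypothetical toy model, reused in the comment examples
# below; ``OrderedDict`` comes from the standard-library ``collections`` module):
#
#     model = nn.Sequential(OrderedDict([("linear", nn.Linear(4, 4))]))
#     module_to_fqn(model, model.linear)     # -> "linear"
#     module_to_fqn(model, nn.Linear(4, 4))  # -> None (not a descendant of model)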


def fqn_to_module(model: Optional[nn.Module], path: str) -> Optional[nn.Module]:
    """
    Given an fqn, returns the corresponding module or tensor, or None if the fqn given by `path`
    doesn't correspond to anything. Similar to model.get_submodule(path), but also works for tensors.
    """
    if path != "":
        for name in path.split("."):
            model = getattr(model, name, None)
    return model
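

# Illustrative example (same hypothetical toy model as above):
#
#     fqn_to_module(model, "linear")         # -> the nn.Linear submodule
#     fqn_to_module(model, "linear.weight")  # -> the weight tensor
#     fqn_to_module(model, "")               # -> model itself
#     fqn_to_module(model, "missing.path")   # -> None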


def get_arg_info_from_tensor_fqn(model: nn.Module, tensor_fqn: str) -> Dict[str, Any]:
    """
    Uses tensor_fqn to obtain a dict containing module_fqn, module and tensor_name.
    """
    # string manipulation to split tensor_fqn into module_fqn and tensor_name:
    # if tensor_fqn is 'weight', then module_fqn and tensor_name are '' and 'weight'
    # if tensor_fqn is 'linear.weight', then module_fqn and tensor_name are 'linear' and 'weight'
    tensor_name = tensor_fqn.split(".")[-1]
    # ("." in tensor_fqn) evaluates to 1 or 0, which also strips the trailing "."
    # from module_fqn when the tensor lives inside a submodule
    module_fqn = tensor_fqn[: -len(tensor_name) - ("." in tensor_fqn)]
    module = fqn_to_module(model, module_fqn)
    return {
        "module_fqn": module_fqn,
        "module": module,
        "tensor_name": tensor_name,
        "tensor_fqn": tensor_fqn,
    }
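

# Illustrative example (same hypothetical toy model as above):
#
#     get_arg_info_from_tensor_fqn(model, "linear.weight")
#     # -> {"module_fqn": "linear", "module": model.linear,
#     #     "tensor_name": "weight", "tensor_fqn": "linear.weight"}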


# Parametrizations
class FakeSparsity(nn.Module):
    r"""Parametrization for the weights. Should be attached to the 'weight' or
    any other parameter that requires a mask to be applied to it.

    Note::
        Once the mask is passed in, the variable should not change its id. The
        contents of the mask can change, but the mask reference itself should
        not.
    """

    def __init__(self, mask):
        super().__init__()
        self.register_buffer("mask", mask)

    def forward(self, x):
        assert self.mask.shape == x.shape
        return self.mask * x

    def state_dict(self, *args, **kwargs):
        # We don't want to let the parametrizations save the mask.
        # That way we make sure that the linear module doesn't store the masks
        # alongside their parametrizations.
        return {}
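

# Minimal usage sketch: FakeSparsity is a parametrization module, so one way to
# attach it is via torch.nn.utils.parametrize.register_parametrization. The toy
# model and the all-ones mask below are assumptions chosen only for illustration.
if __name__ == "__main__":
    from collections import OrderedDict

    import torch
    from torch.nn.utils import parametrize

    model = nn.Sequential(OrderedDict([("linear", nn.Linear(4, 4))]))

    # Attach a mask-based parametrization to the weight of the linear layer.
    mask = torch.ones_like(model.linear.weight)
    parametrize.register_parametrization(model.linear, "weight", FakeSparsity(mask))

    # Zero out the first row of the weight by editing the mask in place; the mask
    # object keeps the same id, as required by the FakeSparsity docstring.
    mask[0] = 0
    print(model.linear.weight[0])  # prints a row of zeros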