from typing import Any, Dict, Optional

from torch import nn

__all__ = [
    "module_to_fqn",
    "fqn_to_module",
    "get_arg_info_from_tensor_fqn",
    "FakeSparsity",
]


def module_to_fqn(
    model: nn.Module, module: nn.Module, prefix: str = ""
) -> Optional[str]:
    """
    Returns the fqn for a module, or None if module is not a descendant of model.
    """
    if module is model:
        return ""
    for name, child in model.named_children():
        # Recurse with a "." prefix so nested names are joined as "parent.child".
        fqn = module_to_fqn(child, module, ".")
        if isinstance(fqn, str):
            return prefix + name + fqn
    return None
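
# Usage sketch (illustrative, not part of the original module), assuming a toy
# ``nn.Sequential`` model:
#
#     model = nn.Sequential(nn.Linear(4, 4))
#     module_to_fqn(model, model[0])         # -> "0"
#     module_to_fqn(model, nn.Linear(4, 4))  # -> None (not a descendant of model)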


def fqn_to_module(model: Optional[nn.Module], path: str) -> Optional[nn.Module]:
    """
    Given an fqn, returns the corresponding module or tensor, or None if the fqn
    given by `path` doesn't correspond to anything. Similar to
    model.get_submodule(path), but also works for tensors.
    """
    if path != "":
        for name in path.split("."):
            model = getattr(model, name, None)
    return model
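
# Usage sketch (illustrative, not part of the original module): unlike
# ``model.get_submodule``, this also resolves tensor attributes such as weights.
#
#     model = nn.Sequential(nn.Linear(4, 4))
#     fqn_to_module(model, "0")          # -> the nn.Linear submodule
#     fqn_to_module(model, "0.weight")   # -> its weight Parameter (a tensor)
#     fqn_to_module(model, "0.missing")  # -> None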


def get_arg_info_from_tensor_fqn(model: nn.Module, tensor_fqn: str) -> Dict[str, Any]:
    """
    Uses tensor_fqn to obtain a dict containing module_fqn, module and tensor_name.
    """
    # String manipulation to split tensor_fqn into module_fqn and tensor_name:
    # if tensor_fqn is 'weight', then module_fqn and tensor_name are '' and 'weight';
    # if tensor_fqn is 'linear.weight', then they are 'linear' and 'weight'.
    tensor_name = tensor_fqn.split(".")[-1]
    # ("." in tensor_fqn) evaluates to 0 or 1, so the slice also drops the trailing dot.
    module_fqn = tensor_fqn[: -len(tensor_name) - ("." in tensor_fqn)]
    module = fqn_to_module(model, module_fqn)
    return {
        "module_fqn": module_fqn,
        "module": module,
        "tensor_name": tensor_name,
        "tensor_fqn": tensor_fqn,
    }
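
# Usage sketch (illustrative, not part of the original module), continuing the
# toy ``nn.Sequential`` example above:
#
#     info = get_arg_info_from_tensor_fqn(model, "0.weight")
#     # info == {
#     #     "module_fqn": "0",
#     #     "module": model[0],        # the nn.Linear submodule
#     #     "tensor_name": "weight",
#     #     "tensor_fqn": "0.weight",
#     # }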


# Parametrizations
class FakeSparsity(nn.Module):
    r"""Parametrization for the weights. Should be attached to the 'weight' or
    any other parameter that requires a mask applied to it.

    Note::

        Once the mask is passed, the variable should not change its id. The
        contents of the mask can change, but the mask reference itself should
        not.
    """

    def __init__(self, mask):
        super().__init__()
        self.register_buffer("mask", mask)

    def forward(self, x):
        assert self.mask.shape == x.shape
        return self.mask * x

    def state_dict(self, *args, **kwargs):
        # We don't want the parametrizations to save the mask.
        # That way we make sure that the linear module doesn't store the masks
        # alongside their parametrizations.
        return {}
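
# Usage sketch (illustrative, not part of the original module), assuming the
# standard ``torch.nn.utils.parametrize`` API and a mask with the same shape as
# the parameter it masks:
#
#     import torch
#     from torch.nn.utils import parametrize
#
#     linear = nn.Linear(4, 4)
#     mask = torch.ones_like(linear.weight)
#     parametrize.register_parametrization(linear, "weight", FakeSparsity(mask))
#     mask[0].zero_()  # mutate the mask in place; keep the same tensor object
#     linear.weight    # -> the masked weight, with its first row zeroed out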