convert_parameters.py

import torch
from typing import Iterable, Optional


def parameters_to_vector(parameters: Iterable[torch.Tensor]) -> torch.Tensor:
    r"""Convert parameters to one vector.

    Args:
        parameters (Iterable[Tensor]): an iterable of Tensors that are the
            parameters of a model.

    Returns:
        The parameters represented by a single vector.
    """
    # Flag for the device where the parameters are located
    param_device = None

    vec = []
    for param in parameters:
        # Ensure the parameters are located on the same device
        param_device = _check_param_device(param, param_device)
        vec.append(param.view(-1))
    return torch.cat(vec)
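

# A minimal usage sketch (illustrative, not part of the original module):
# `parameters_to_vector` flattens every parameter in iteration order, so the
# result has as many elements as all parameters combined. The `nn.Linear`
# model here is an arbitrary assumption for the example.
#
#     import torch.nn as nn
#     model = nn.Linear(3, 2)                          # 6 weights + 2 biases
#     flat = parameters_to_vector(model.parameters())
#     assert flat.numel() == 8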


def vector_to_parameters(vec: torch.Tensor, parameters: Iterable[torch.Tensor]) -> None:
    r"""Convert one vector back to the parameters.

    Args:
        vec (Tensor): a single vector representing the parameters of a model.
        parameters (Iterable[Tensor]): an iterable of Tensors that are the
            parameters of a model.
    """
    # Ensure vec is of type Tensor
    if not isinstance(vec, torch.Tensor):
        raise TypeError('expected torch.Tensor, but got: {}'
                        .format(torch.typename(vec)))
    # Flag for the device where the parameters are located
    param_device = None

    # Pointer for slicing the vector for each parameter
    pointer = 0
    for param in parameters:
        # Ensure the parameters are located on the same device
        param_device = _check_param_device(param, param_device)

        # The length of the parameter
        num_param = param.numel()
        # Slice the vector, reshape it, and replace the old data of the parameter
        param.data = vec[pointer:pointer + num_param].view_as(param).data

        # Increment the pointer
        pointer += num_param
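

# A round-trip sketch (illustrative, not part of the original module): a
# modified vector is written back, in place, into the parameters it came
# from. The model and the `* 2` update are arbitrary assumptions.
#
#     import torch.nn as nn
#     model = nn.Linear(3, 2)
#     flat = parameters_to_vector(model.parameters())
#     vector_to_parameters(flat * 2, model.parameters())  # doubles each parameter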


def _check_param_device(param: torch.Tensor, old_param_device: Optional[int]) -> int:
    r"""Check that successive parameters are located on the same device.

    Currently, the conversion between model parameters and a single vector
    is not supported for multiple allocations, e.g. parameters on different
    GPUs or a mixture of CPU/GPU.

    Args:
        param (Tensor): a Tensor that is a parameter of a model.
        old_param_device (int): the device where the first parameter of the
            model is allocated, or None when visiting the first parameter.

    Returns:
        old_param_device (int): the device index of the first parameter
        (-1 for CPU).
    """
    # Meet the first parameter
    if old_param_device is None:
        old_param_device = param.get_device() if param.is_cuda else -1
    else:
        warn = False
        if param.is_cuda:  # Check if on the same GPU
            warn = (param.get_device() != old_param_device)
        else:  # Check if on CPU
            warn = (old_param_device != -1)
        if warn:
            raise TypeError('Found two parameters on different devices, '
                            'this is currently not supported.')
    return old_param_device
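

# A self-contained demo, added for illustration only (a sketch, not part of
# the original module). It exercises the round trip on CPU with plain tensors
# standing in for model parameters.
if __name__ == "__main__":
    params = [torch.zeros(2, 3), torch.zeros(4)]  # hypothetical "parameters"
    vec = parameters_to_vector(params)
    assert vec.shape == (10,)

    vector_to_parameters(torch.arange(10, dtype=torch.float), params)
    assert params[0][1, 2].item() == 5.0  # row-major element 5 of the vector

    # Mixing devices raises TypeError via _check_param_device, e.g.:
    # parameters_to_vector([torch.zeros(1), torch.zeros(1, device="cuda")])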