_convert_np.py 856 B

12345678910111213141516171819202122232425262728293031323334353637383940
  1. """
  2. This module converts objects into numpy array.
  3. """
  4. import numpy as np
  5. import torch
  6. def make_np(x):
  7. """
  8. Args:
  9. x: An instance of torch tensor or caffe blob name
  10. Returns:
  11. numpy.array: Numpy array
  12. """
  13. if isinstance(x, np.ndarray):
  14. return x
  15. if isinstance(x, str): # Caffe2 will pass name of blob(s) to fetch
  16. return _prepare_caffe2(x)
  17. if np.isscalar(x):
  18. return np.array([x])
  19. if isinstance(x, torch.Tensor):
  20. return _prepare_pytorch(x)
  21. raise NotImplementedError(
  22. "Got {}, but numpy array, torch tensor, or caffe2 blob name are expected.".format(
  23. type(x)
  24. )
  25. )
  26. def _prepare_pytorch(x):
  27. x = x.detach().cpu().numpy()
  28. return x
  29. def _prepare_caffe2(x):
  30. from caffe2.python import workspace
  31. x = workspace.FetchBlob(x)
  32. return x