12345678910111213141516171819202122232425262728293031323334353637383940 |
- """
- This module converts objects into numpy array.
- """
- import numpy as np
- import torch
- def make_np(x):
- """
- Args:
- x: An instance of torch tensor or caffe blob name
- Returns:
- numpy.array: Numpy array
- """
- if isinstance(x, np.ndarray):
- return x
- if isinstance(x, str): # Caffe2 will pass name of blob(s) to fetch
- return _prepare_caffe2(x)
- if np.isscalar(x):
- return np.array([x])
- if isinstance(x, torch.Tensor):
- return _prepare_pytorch(x)
- raise NotImplementedError(
- "Got {}, but numpy array, torch tensor, or caffe2 blob name are expected.".format(
- type(x)
- )
- )
- def _prepare_pytorch(x):
- x = x.detach().cpu().numpy()
- return x
- def _prepare_caffe2(x):
- from caffe2.python import workspace
- x = workspace.FetchBlob(x)
- return x
|