_validation.py 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. import numpy as np
  2. from scipy.sparse import csr_matrix, isspmatrix, isspmatrix_csc
  3. from ._tools import csgraph_to_dense, csgraph_from_dense,\
  4. csgraph_masked_from_dense, csgraph_from_masked
  5. DTYPE = np.float64
  6. def validate_graph(csgraph, directed, dtype=DTYPE,
  7. csr_output=True, dense_output=True,
  8. copy_if_dense=False, copy_if_sparse=False,
  9. null_value_in=0, null_value_out=np.inf,
  10. infinity_null=True, nan_null=True):
  11. """Routine for validation and conversion of csgraph inputs"""
  12. if not (csr_output or dense_output):
  13. raise ValueError("Internal: dense or csr output must be true")
  14. # if undirected and csc storage, then transposing in-place
  15. # is quicker than later converting to csr.
  16. if (not directed) and isspmatrix_csc(csgraph):
  17. csgraph = csgraph.T
  18. if isspmatrix(csgraph):
  19. if csr_output:
  20. csgraph = csr_matrix(csgraph, dtype=DTYPE, copy=copy_if_sparse)
  21. else:
  22. csgraph = csgraph_to_dense(csgraph, null_value=null_value_out)
  23. elif np.ma.isMaskedArray(csgraph):
  24. if dense_output:
  25. mask = csgraph.mask
  26. csgraph = np.array(csgraph.data, dtype=DTYPE, copy=copy_if_dense)
  27. csgraph[mask] = null_value_out
  28. else:
  29. csgraph = csgraph_from_masked(csgraph)
  30. else:
  31. if dense_output:
  32. csgraph = csgraph_masked_from_dense(csgraph,
  33. copy=copy_if_dense,
  34. null_value=null_value_in,
  35. nan_null=nan_null,
  36. infinity_null=infinity_null)
  37. mask = csgraph.mask
  38. csgraph = np.asarray(csgraph.data, dtype=DTYPE)
  39. csgraph[mask] = null_value_out
  40. else:
  41. csgraph = csgraph_from_dense(csgraph, null_value=null_value_in,
  42. infinity_null=infinity_null,
  43. nan_null=nan_null)
  44. if csgraph.ndim != 2:
  45. raise ValueError("compressed-sparse graph must be 2-D")
  46. if csgraph.shape[0] != csgraph.shape[1]:
  47. raise ValueError("compressed-sparse graph must be shape (N, N)")
  48. return csgraph