utils.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. __docformat__ = "restructuredtext en"
  2. __all__ = []
  3. from numpy import asanyarray, asarray, array, zeros
  4. from scipy.sparse.linalg._interface import aslinearoperator, LinearOperator, \
  5. IdentityOperator
  6. _coerce_rules = {('f','f'):'f', ('f','d'):'d', ('f','F'):'F',
  7. ('f','D'):'D', ('d','f'):'d', ('d','d'):'d',
  8. ('d','F'):'D', ('d','D'):'D', ('F','f'):'F',
  9. ('F','d'):'D', ('F','F'):'F', ('F','D'):'D',
  10. ('D','f'):'D', ('D','d'):'D', ('D','F'):'D',
  11. ('D','D'):'D'}
  12. def coerce(x,y):
  13. if x not in 'fdFD':
  14. x = 'd'
  15. if y not in 'fdFD':
  16. y = 'd'
  17. return _coerce_rules[x,y]
  18. def id(x):
  19. return x
  20. def make_system(A, M, x0, b):
  21. """Make a linear system Ax=b
  22. Parameters
  23. ----------
  24. A : LinearOperator
  25. sparse or dense matrix (or any valid input to aslinearoperator)
  26. M : {LinearOperator, Nones}
  27. preconditioner
  28. sparse or dense matrix (or any valid input to aslinearoperator)
  29. x0 : {array_like, str, None}
  30. initial guess to iterative method.
  31. ``x0 = 'Mb'`` means using the nonzero initial guess ``M @ b``.
  32. Default is `None`, which means using the zero initial guess.
  33. b : array_like
  34. right hand side
  35. Returns
  36. -------
  37. (A, M, x, b, postprocess)
  38. A : LinearOperator
  39. matrix of the linear system
  40. M : LinearOperator
  41. preconditioner
  42. x : rank 1 ndarray
  43. initial guess
  44. b : rank 1 ndarray
  45. right hand side
  46. postprocess : function
  47. converts the solution vector to the appropriate
  48. type and dimensions (e.g. (N,1) matrix)
  49. """
  50. A_ = A
  51. A = aslinearoperator(A)
  52. if A.shape[0] != A.shape[1]:
  53. raise ValueError(f'expected square matrix, but got shape={(A.shape,)}')
  54. N = A.shape[0]
  55. b = asanyarray(b)
  56. if not (b.shape == (N,1) or b.shape == (N,)):
  57. raise ValueError(f'shapes of A {A.shape} and b {b.shape} are '
  58. 'incompatible')
  59. if b.dtype.char not in 'fdFD':
  60. b = b.astype('d') # upcast non-FP types to double
  61. def postprocess(x):
  62. return x
  63. if hasattr(A,'dtype'):
  64. xtype = A.dtype.char
  65. else:
  66. xtype = A.matvec(b).dtype.char
  67. xtype = coerce(xtype, b.dtype.char)
  68. b = asarray(b,dtype=xtype) # make b the same type as x
  69. b = b.ravel()
  70. # process preconditioner
  71. if M is None:
  72. if hasattr(A_,'psolve'):
  73. psolve = A_.psolve
  74. else:
  75. psolve = id
  76. if hasattr(A_,'rpsolve'):
  77. rpsolve = A_.rpsolve
  78. else:
  79. rpsolve = id
  80. if psolve is id and rpsolve is id:
  81. M = IdentityOperator(shape=A.shape, dtype=A.dtype)
  82. else:
  83. M = LinearOperator(A.shape, matvec=psolve, rmatvec=rpsolve,
  84. dtype=A.dtype)
  85. else:
  86. M = aslinearoperator(M)
  87. if A.shape != M.shape:
  88. raise ValueError('matrix and preconditioner have different shapes')
  89. # set initial guess
  90. if x0 is None:
  91. x = zeros(N, dtype=xtype)
  92. elif isinstance(x0, str):
  93. if x0 == 'Mb': # use nonzero initial guess ``M @ b``
  94. bCopy = b.copy()
  95. x = M.matvec(bCopy)
  96. else:
  97. x = array(x0, dtype=xtype)
  98. if not (x.shape == (N, 1) or x.shape == (N,)):
  99. raise ValueError(f'shapes of A {A.shape} and '
  100. f'x0 {x.shape} are incompatible')
  101. x = x.ravel()
  102. return A, M, x, b, postprocess