trf_linear.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249
  1. """The adaptation of Trust Region Reflective algorithm for a linear
  2. least-squares problem."""
  3. import numpy as np
  4. from numpy.linalg import norm
  5. from scipy.linalg import qr, solve_triangular
  6. from scipy.sparse.linalg import lsmr
  7. from scipy.optimize import OptimizeResult
  8. from .givens_elimination import givens_elimination
  9. from .common import (
  10. EPS, step_size_to_bound, find_active_constraints, in_bounds,
  11. make_strictly_feasible, build_quadratic_1d, evaluate_quadratic,
  12. minimize_quadratic_1d, CL_scaling_vector, reflective_transformation,
  13. print_header_linear, print_iteration_linear, compute_grad,
  14. regularized_lsq_operator, right_multiplied_operator)
  15. def regularized_lsq_with_qr(m, n, R, QTb, perm, diag, copy_R=True):
  16. """Solve regularized least squares using information from QR-decomposition.
  17. The initial problem is to solve the following system in a least-squares
  18. sense::
  19. A x = b
  20. D x = 0
  21. where D is diagonal matrix. The method is based on QR decomposition
  22. of the form A P = Q R, where P is a column permutation matrix, Q is an
  23. orthogonal matrix and R is an upper triangular matrix.
  24. Parameters
  25. ----------
  26. m, n : int
  27. Initial shape of A.
  28. R : ndarray, shape (n, n)
  29. Upper triangular matrix from QR decomposition of A.
  30. QTb : ndarray, shape (n,)
  31. First n components of Q^T b.
  32. perm : ndarray, shape (n,)
  33. Array defining column permutation of A, such that ith column of
  34. P is perm[i]-th column of identity matrix.
  35. diag : ndarray, shape (n,)
  36. Array containing diagonal elements of D.
  37. Returns
  38. -------
  39. x : ndarray, shape (n,)
  40. Found least-squares solution.
  41. """
  42. if copy_R:
  43. R = R.copy()
  44. v = QTb.copy()
  45. givens_elimination(R, v, diag[perm])
  46. abs_diag_R = np.abs(np.diag(R))
  47. threshold = EPS * max(m, n) * np.max(abs_diag_R)
  48. nns, = np.nonzero(abs_diag_R > threshold)
  49. R = R[np.ix_(nns, nns)]
  50. v = v[nns]
  51. x = np.zeros(n)
  52. x[perm[nns]] = solve_triangular(R, v)
  53. return x
  54. def backtracking(A, g, x, p, theta, p_dot_g, lb, ub):
  55. """Find an appropriate step size using backtracking line search."""
  56. alpha = 1
  57. while True:
  58. x_new, _ = reflective_transformation(x + alpha * p, lb, ub)
  59. step = x_new - x
  60. cost_change = -evaluate_quadratic(A, g, step)
  61. if cost_change > -0.1 * alpha * p_dot_g:
  62. break
  63. alpha *= 0.5
  64. active = find_active_constraints(x_new, lb, ub)
  65. if np.any(active != 0):
  66. x_new, _ = reflective_transformation(x + theta * alpha * p, lb, ub)
  67. x_new = make_strictly_feasible(x_new, lb, ub, rstep=0)
  68. step = x_new - x
  69. cost_change = -evaluate_quadratic(A, g, step)
  70. return x, step, cost_change
def select_step(x, A_h, g_h, c_h, p, p_h, d, lb, ub, theta):
    """Select the best step according to Trust Region Reflective algorithm.

    If the full step ``x + p`` is within bounds, ``p`` is returned unchanged.
    Otherwise three candidates are compared by the value of the quadratic
    model in the scaled ("hat") variables, and the best one is returned in
    original variables (hat-step multiplied by ``d``):

    1. ``p`` truncated where it first hits a bound, then scaled by ``theta``;
    2. the direction reflected off that bound, starting from the truncated
       point, with its length chosen by a 1-D minimization restricted to
       ``[(1 - theta) * t_max, theta * t_max]``;
    3. the scaled anti-gradient ``-g_h``, minimized in 1-D over
       ``[0, theta * t_max]``.

    Parameters
    ----------
    x : ndarray
        Current point.
    A_h : matrix or operator
        Model operator in scaled variables, passed to the quadratic helpers.
    g_h : ndarray
        Gradient in scaled variables.
    c_h : ndarray
        Diagonal term of the quadratic model (``diag`` argument of the
        helpers).
    p, p_h : ndarray
        Trial step in original and scaled variables. NOTE: when ``x + p`` is
        out of bounds, both arrays are modified in place (``*=``).
    d : ndarray
        Scaling that maps scaled steps to original ones (``step = d * step_h``).
    lb, ub : ndarray
        Lower and upper bounds.
    theta : float
        Factor slightly below 1 used to keep steps off the bounds.

    Returns
    -------
    step : ndarray
        Selected step in original variables. When ``p`` is selected, the
        returned object is the caller's (possibly rescaled) ``p`` array.
    """
    # Full step is feasible: take it as is.
    if in_bounds(x + p, lb, ub):
        return p

    p_stride, hits = step_size_to_bound(x, p, lb, ub)
    # Reflected direction: flip the sign of the components flagged in `hits`.
    r_h = np.copy(p_h)
    r_h[hits.astype(bool)] *= -1
    r = d * r_h

    # Restrict step, such that it hits the bound.
    # These are in-place updates of the caller's arrays.
    p *= p_stride
    p_h *= p_stride
    x_on_bound = x + p

    # Find the step size along reflected direction.
    r_stride_u, _ = step_size_to_bound(x_on_bound, r, lb, ub)

    # Stay interior.
    r_stride_l = (1 - theta) * r_stride_u
    r_stride_u *= theta

    if r_stride_u > 0:
        # Minimize the model along p_h + t * r_h for t in [r_stride_l, r_stride_u].
        a, b, c = build_quadratic_1d(A_h, g_h, r_h, s0=p_h, diag=c_h)
        r_stride, r_value = minimize_quadratic_1d(
            a, b, r_stride_l, r_stride_u, c=c)
        r_h = p_h + r_h * r_stride
        r = d * r_h
    else:
        # No room along the reflected direction; rule this candidate out.
        r_value = np.inf

    # Now correct p_h to make it strictly interior.
    p_h *= theta
    p *= theta
    p_value = evaluate_quadratic(A_h, g_h, p_h, diag=c_h)

    # Candidate along the scaled anti-gradient, kept interior by theta.
    ag_h = -g_h
    ag = d * ag_h
    ag_stride_u, _ = step_size_to_bound(x, ag, lb, ub)
    ag_stride_u *= theta
    a, b = build_quadratic_1d(A_h, g_h, ag_h, diag=c_h)
    ag_stride, ag_value = minimize_quadratic_1d(a, b, 0, ag_stride_u)
    ag *= ag_stride

    # Return the candidate with the smallest model value. Comparisons are
    # strict, so any tie for the smallest value falls through to `ag`.
    if p_value < r_value and p_value < ag_value:
        return p
    elif r_value < p_value and r_value < ag_value:
        return r
    else:
        return ag
def trf_linear(A, b, x_lsq, lb, ub, tol, lsq_solver, lsmr_tol,
               max_iter, verbose, *, lsmr_maxiter=None):
    """Solve a bound-constrained linear least-squares problem with TRF.

    Minimizes ``0.5 * ||A x - b||**2`` subject to ``lb <= x <= ub`` using the
    Trust Region Reflective iteration adapted to a linear residual.

    Parameters
    ----------
    A : matrix or operator
        Must provide ``shape`` and ``dot``. With ``lsq_solver='exact'`` it is
        factored by ``scipy.linalg.qr``, which expects a dense array.
    b : ndarray
        Right-hand side (same length as ``A.dot(x)``).
    x_lsq : ndarray
        Initial guess (by its name, presumably the unbounded least-squares
        solution). It is reflected into the bounds and made strictly
        feasible before iterating.
    lb, ub : ndarray
        Lower and upper bounds.
    tol : float
        Stop when the inf-norm of the scaled gradient falls below ``tol``,
        or when the predicted cost decrease falls below ``tol * cost``. Also
        used as ``rtol`` for ``active_mask`` and to derive the default
        ``lsmr_tol``.
    lsq_solver : {'exact', 'lsmr'}
        'exact' factors ``A`` once with pivoted QR; 'lsmr' solves each
        trust-region subproblem iteratively with ``lsmr``.
    lsmr_tol : None, 'auto' or float
        LSMR ``atol``/``btol``. None means ``1e-2 * tol``; 'auto' recomputes
        it every iteration from the scaled-gradient norm.
    max_iter : int or None
        Maximum number of iterations; None means 100.
        NOTE(review): if ``max_iter <= 0`` the loop never runs and
        ``iteration``/``g_norm`` are unbound at the return -- callers must
        pass a positive value.
    verbose : int
        When equal to 2, a header and one line per iteration are printed.
    lsmr_maxiter : int or None, optional
        Forwarded to ``lsmr`` as ``maxiter``.

    Returns
    -------
    OptimizeResult
        Fields: ``x``, ``fun`` (residual ``A x - b``), ``cost``
        (``0.5 * ||fun||**2``), ``optimality`` (scaled-gradient inf-norm from
        the start of the last loop pass), ``active_mask``, ``nit``,
        ``status`` and ``initial_cost``. ``status`` is 1 (gradient
        tolerance met), 2 (cost change below ``tol * cost``), -1 (``p`` was
        not a descent direction) or 0 (``max_iter`` reached). A -1 set
        during a pass can be overwritten by 2 in the same pass.
    """
    m, n = A.shape
    # Move the starting point inside the box: reflect it, then push it
    # strictly off the bounds.
    x, _ = reflective_transformation(x_lsq, lb, ub)
    x = make_strictly_feasible(x, lb, ub, rstep=0.1)

    if lsq_solver == 'exact':
        # Factor A once (A P = Q R, column pivoting); every iteration reuses
        # the factors.
        QT, R, perm = qr(A, mode='economic', pivoting=True)
        QT = QT.T

        # For m < n, pad R with zero rows so it is (n, n) as
        # regularized_lsq_with_qr expects.
        if m < n:
            R = np.vstack((R, np.zeros((n - m, n))))

        # Buffer for Q^T r; only its first k entries are filled each pass.
        QTr = np.zeros(n)
        k = min(m, n)
    elif lsq_solver == 'lsmr':
        # Right-hand side [r; 0] of the regularized (augmented) system.
        r_aug = np.zeros(m + n)
        auto_lsmr_tol = False
        if lsmr_tol is None:
            lsmr_tol = 1e-2 * tol
        elif lsmr_tol == 'auto':
            auto_lsmr_tol = True

    r = A.dot(x) - b
    g = compute_grad(A, r)
    cost = 0.5 * np.dot(r, r)
    initial_cost = cost

    termination_status = None
    step_norm = None
    cost_change = None

    if max_iter is None:
        max_iter = 100

    if verbose == 2:
        print_header_linear()

    for iteration in range(max_iter):
        v, dv = CL_scaling_vector(x, g, lb, ub)
        g_scaled = g * v
        g_norm = norm(g_scaled, ord=np.inf)
        if g_norm < tol:
            termination_status = 1

        # Report progress before acting on a termination status set just
        # above or during the previous pass.
        if verbose == 2:
            print_iteration_linear(iteration, cost, cost_change,
                                   step_norm, g_norm)

        if termination_status is not None:
            break

        # Problem quantities in scaled ("hat") variables; a scaled step s_h
        # maps back to s = d * s_h (cf. p = d * p_h below).
        diag_h = g * dv
        diag_root_h = diag_h ** 0.5
        d = v ** 0.5
        g_h = d * g

        A_h = right_multiplied_operator(A, d)
        if lsq_solver == 'exact':
            QTr[:k] = QT.dot(r)
            # R * d[perm] scales the columns of R, giving the R factor of
            # A diag(d) with the same Q and permutation. It is a fresh
            # array, so the solver may overwrite it (copy_R=False).
            p_h = -regularized_lsq_with_qr(m, n, R * d[perm], QTr, perm,
                                           diag_root_h, copy_R=False)
        elif lsq_solver == 'lsmr':
            lsmr_op = regularized_lsq_operator(A_h, diag_root_h)
            r_aug[:m] = r
            if auto_lsmr_tol:
                # Solve the subproblem more accurately as g_norm shrinks;
                # the tolerance is clipped to [EPS, 0.1].
                eta = 1e-2 * min(0.5, g_norm)
                lsmr_tol = max(EPS, min(0.1, eta * g_norm))
            p_h = -lsmr(lsmr_op, r_aug, maxiter=lsmr_maxiter,
                        atol=lsmr_tol, btol=lsmr_tol)[0]

        p = d * p_h

        # p is expected to be a descent direction; flag it when it is not.
        p_dot_g = np.dot(p, g)
        if p_dot_g > 0:
            termination_status = -1

        # theta slightly below 1 is used to keep steps off the bounds.
        theta = 1 - min(0.005, g_norm)
        # NOTE(review): select_step may rescale `p` in place (`p *= ...`),
        # so the `p` passed to backtracking below can differ from the one
        # used to compute p_dot_g -- confirm this is intended.
        step = select_step(x, A_h, g_h, diag_h, p, p_h, d, lb, ub, theta)
        cost_change = -evaluate_quadratic(A, g, step)

        # Perhaps almost never executed, the idea is that `p` is descent
        # direction thus we must find acceptable cost decrease using simple
        # "backtracking", otherwise the algorithm's logic would break.
        if cost_change < 0:
            x, step, cost_change = backtracking(
                A, g, x, p, theta, p_dot_g, lb, ub)
        else:
            x = make_strictly_feasible(x + step, lb, ub, rstep=0)

        step_norm = norm(step)
        r = A.dot(x) - b
        g = compute_grad(A, r)

        # Stop on a small predicted decrease relative to the current cost;
        # this may overwrite a -1 status set earlier in this pass.
        if cost_change < tol * cost:
            termination_status = 2

        cost = 0.5 * np.dot(r, r)

    if termination_status is None:
        termination_status = 0

    active_mask = find_active_constraints(x, lb, ub, rtol=tol)

    return OptimizeResult(
        x=x, fun=r, cost=cost, optimality=g_norm, active_mask=active_mask,
        nit=iteration + 1, status=termination_status,
        initial_cost=initial_cost)