lll.py 3.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. from __future__ import annotations
  2. from math import floor as mfloor
  3. from sympy.polys.domains import ZZ, QQ
  4. from sympy.polys.matrices.exceptions import DMRankError, DMShapeError, DMValueError, DMDomainError
  def _ddm_lll(x, delta=QQ(3, 4), return_transform=False):
      """Reduce the rows of ``x`` with the LLL lattice-basis reduction algorithm.

      Parameters
      ==========

      x
          Dense matrix over ``ZZ`` of shape ``(m, n)`` with ``m <= n``; its rows
          are the basis vectors and must be linearly independent. Only its
          ``shape``, ``domain``, ``copy()``, ``zeros()``, ``eye()`` and row
          indexing are used. The reduction is carried out on ``x.copy()``.
      delta
          Lovász parameter, an element of ``QQ`` with ``1/4 < delta < 1``.
          Larger values make the Lovász condition stricter.
      return_transform
          When true, also accumulate ``T``: the ``m x m`` identity with the
          same integer row additions and row swaps applied to it as to the
          basis, so that ``T * x == y``.

      Returns
      =======

      ``(y, T)`` where ``y`` is the reduced basis over ``ZZ`` and ``T`` is the
      transform described above, or ``None`` when ``return_transform`` is false.

      Raises
      ======

      DMValueError
          ``delta`` is not strictly between 1/4 and 1.
      DMShapeError
          ``x`` has more rows than columns.
      DMDomainError
          The domain of ``x`` is not ``ZZ``.
      DMRankError
          The rows of ``x`` are linearly dependent (a Gram-Schmidt norm is 0).
      """
      if QQ(1, 4) >= delta or delta >= QQ(1, 1):
          raise DMValueError("delta must lie in range (0.25, 1)")
      if x.shape[0] > x.shape[1]:
          raise DMShapeError("input matrix must have shape (m, n) with m <= n")
      if x.domain != ZZ:
          raise DMDomainError("input matrix domain must be ZZ")
      m = x.shape[0]
      n = x.shape[1]
      k = 1
      y = x.copy()
      y_star = x.zeros((m, n), QQ)
      mu = x.zeros((m, m), QQ)
      g_star = [QQ(0, 1) for _ in range(m)]
      half = QQ(1, 2)
      T = x.eye(m, ZZ) if return_transform else None
      linear_dependent_error = "input matrix contains linearly dependent rows"

      def closest_integer(q):
          # Round to nearest integer, halves rounded up: floor(q + 1/2).
          return ZZ(mfloor(q + half))

      def lovasz_condition(idx: int) -> bool:
          return g_star[idx] >= ((delta - mu[idx][idx - 1] ** 2) * g_star[idx - 1])

      def mu_small(row: int, col: int) -> bool:
          # Size-reduction test: |mu[row][col]| <= 1/2.
          return abs(mu[row][col]) <= half

      def dot_rows(a, b, rows: tuple[int, int]):
          # Inner product of row rows[0] of `a` with row rows[1] of `b`.
          i, j = rows
          return sum(a[i][z] * b[j][z] for z in range(a.shape[1]))

      def reduce_row(T, mu, y, rows: tuple[int, int]):
          # Subtract round(mu[tgt][src]) copies of row src from row tgt, keeping
          # mu (and T, when tracked) consistent with the updated basis.
          tgt, src = rows
          r = closest_integer(mu[tgt][src])
          y[tgt] = [y[tgt][z] - r * y[src][z] for z in range(n)]
          mu[tgt][:src] = [mu[tgt][z] - r * mu[src][z] for z in range(src)]
          mu[tgt][src] -= r
          if return_transform:
              T[tgt] = [T[tgt][z] - r * T[src][z] for z in range(m)]

      # Gram-Schmidt orthogonalisation over QQ: y_star[i] is y[i] minus its
      # projections onto y_star[0..i-1], mu[i][j] are the projection
      # coefficients and g_star[i] is the squared norm of y_star[i].
      for i in range(m):
          y_star[i] = [QQ.convert_from(z, ZZ) for z in y[i]]
          for j in range(i):
              # g_star[j] != 0 is guaranteed by the check on row j below.
              mu[i][j] = dot_rows(y, y_star, (i, j)) / g_star[j]
              y_star[i] = [y_star[i][z] - mu[i][j] * y_star[j][z] for z in range(n)]
          g_star[i] = dot_rows(y_star, y_star, (i, i))
          # A zero norm means row i lies in the span of the earlier rows.
          # Checking every row here also covers the last row (and the
          # single-row case), whose norm is never used as a divisor above.
          if g_star[i] == 0:
              raise DMRankError(linear_dependent_error)

      while k < m:
          if not mu_small(k, k - 1):
              reduce_row(T, mu, y, (k, k - 1))
          if lovasz_condition(k):
              # Finish size-reducing row k against the remaining earlier rows.
              for col in range(k - 2, -1, -1):
                  if not mu_small(k, col):
                      reduce_row(T, mu, y, (k, col))
              k += 1
          else:
              # Lovász condition fails: swap rows k-1 and k and update mu and
              # g_star in place instead of redoing Gram-Schmidt.
              nu = mu[k][k - 1]
              # All g_star entries are positive after the rank check, and this
              # update keeps them positive, so alpha > 0 here.
              alpha = g_star[k] + nu ** 2 * g_star[k - 1]
              beta = g_star[k - 1] / alpha
              mu[k][k - 1] = nu * beta
              g_star[k] = g_star[k] * beta
              g_star[k - 1] = alpha
              y[k], y[k - 1] = y[k - 1], y[k]
              mu[k][:k - 1], mu[k - 1][:k - 1] = mu[k - 1][:k - 1], mu[k][:k - 1]
              for i in range(k + 1, m):
                  xi = mu[i][k]
                  mu[i][k] = mu[i][k - 1] - nu * xi
                  mu[i][k - 1] = mu[k][k - 1] * mu[i][k] + xi
              if return_transform:
                  T[k], T[k - 1] = T[k - 1], T[k]
              # Step back (never below 1) to re-check the newly formed pair.
              k = max(k - 1, 1)

      # Internal post-condition checks (skipped under ``python -O``).
      assert all(lovasz_condition(i) for i in range(1, m))
      assert all(mu_small(i, j) for i in range(m) for j in range(i))
      return y, T
  def ddm_lll(x, delta=QQ(3, 4)):
      """Return the LLL-reduced basis for the rows of ``x``.

      ``x`` is expected to be a ``ZZ`` matrix of shape ``(m, n)`` with
      ``m <= n`` and linearly independent rows, and ``delta`` must satisfy
      ``1/4 < delta < 1``; otherwise ``_ddm_lll`` raises a ``DM*Error``.
      No transform matrix is accumulated.
      """
      reduced, _unused_transform = _ddm_lll(x, delta=delta, return_transform=False)
      return reduced
  def ddm_lll_transform(x, delta=QQ(3, 4)):
      """Return ``(y, T)``: the LLL-reduced basis ``y`` of the rows of ``x``
      together with the ``ZZ`` transform ``T`` satisfying ``T * x == y``.

      Input requirements and raised errors are the same as for ``ddm_lll``.
      """
      reduced, transform = _ddm_lll(x, delta=delta, return_transform=True)
      return reduced, transform