constraints.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609
r"""
The following constraints are implemented:

- ``constraints.boolean``
- ``constraints.cat(cseq, dim=0, lengths=None)``
- ``constraints.corr_cholesky``
- ``constraints.dependent``
- ``constraints.greater_than(lower_bound)``
- ``constraints.greater_than_eq(lower_bound)``
- ``constraints.half_open_interval(lower_bound, upper_bound)``
- ``constraints.independent(constraint, reinterpreted_batch_ndims)``
- ``constraints.integer_interval(lower_bound, upper_bound)``
- ``constraints.interval(lower_bound, upper_bound)``
- ``constraints.less_than(upper_bound)``
- ``constraints.lower_cholesky``
- ``constraints.lower_triangular``
- ``constraints.multinomial(upper_bound)``
- ``constraints.nonnegative``
- ``constraints.nonnegative_integer``
- ``constraints.one_hot``
- ``constraints.positive_integer``
- ``constraints.positive``
- ``constraints.positive_semidefinite``
- ``constraints.positive_definite``
- ``constraints.real_vector``
- ``constraints.real``
- ``constraints.simplex``
- ``constraints.square``
- ``constraints.stack(cseq, dim=0)``
- ``constraints.symmetric``
- ``constraints.unit_interval``

Helpers:

- ``constraints.dependent_property`` -- decorator that behaves like
  ``constraints.dependent`` on a class and like ``@property`` on an instance.
- ``constraints.is_dependent(constraint)`` -- tests for a dependent constraint.
"""
  31. import torch
  32. __all__ = [
  33. 'Constraint',
  34. 'boolean',
  35. 'cat',
  36. 'corr_cholesky',
  37. 'dependent',
  38. 'dependent_property',
  39. 'greater_than',
  40. 'greater_than_eq',
  41. 'independent',
  42. 'integer_interval',
  43. 'interval',
  44. 'half_open_interval',
  45. 'is_dependent',
  46. 'less_than',
  47. 'lower_cholesky',
  48. 'lower_triangular',
  49. 'multinomial',
  50. 'nonnegative_integer',
  51. 'positive',
  52. 'positive_semidefinite',
  53. 'positive_definite',
  54. 'positive_integer',
  55. 'real',
  56. 'real_vector',
  57. 'simplex',
  58. 'square',
  59. 'stack',
  60. 'symmetric',
  61. 'unit_interval',
  62. ]
  63. class Constraint:
  64. """
  65. Abstract base class for constraints.
  66. A constraint object represents a region over which a variable is valid,
  67. e.g. within which a variable can be optimized.
  68. Attributes:
  69. is_discrete (bool): Whether constrained space is discrete.
  70. Defaults to False.
  71. event_dim (int): Number of rightmost dimensions that together define
  72. an event. The :meth:`check` method will remove this many dimensions
  73. when computing validity.
  74. """
  75. is_discrete = False # Default to continuous.
  76. event_dim = 0 # Default to univariate.
  77. def check(self, value):
  78. """
  79. Returns a byte tensor of ``sample_shape + batch_shape`` indicating
  80. whether each event in value satisfies this constraint.
  81. """
  82. raise NotImplementedError
  83. def __repr__(self):
  84. return self.__class__.__name__[1:] + '()'
  85. class _Dependent(Constraint):
  86. """
  87. Placeholder for variables whose support depends on other variables.
  88. These variables obey no simple coordinate-wise constraints.
  89. Args:
  90. is_discrete (bool): Optional value of ``.is_discrete`` in case this
  91. can be computed statically. If not provided, access to the
  92. ``.is_discrete`` attribute will raise a NotImplementedError.
  93. event_dim (int): Optional value of ``.event_dim`` in case this
  94. can be computed statically. If not provided, access to the
  95. ``.event_dim`` attribute will raise a NotImplementedError.
  96. """
  97. def __init__(self, *, is_discrete=NotImplemented, event_dim=NotImplemented):
  98. self._is_discrete = is_discrete
  99. self._event_dim = event_dim
  100. super().__init__()
  101. @property
  102. def is_discrete(self):
  103. if self._is_discrete is NotImplemented:
  104. raise NotImplementedError(".is_discrete cannot be determined statically")
  105. return self._is_discrete
  106. @property
  107. def event_dim(self):
  108. if self._event_dim is NotImplemented:
  109. raise NotImplementedError(".event_dim cannot be determined statically")
  110. return self._event_dim
  111. def __call__(self, *, is_discrete=NotImplemented, event_dim=NotImplemented):
  112. """
  113. Support for syntax to customize static attributes::
  114. constraints.dependent(is_discrete=True, event_dim=1)
  115. """
  116. if is_discrete is NotImplemented:
  117. is_discrete = self._is_discrete
  118. if event_dim is NotImplemented:
  119. event_dim = self._event_dim
  120. return _Dependent(is_discrete=is_discrete, event_dim=event_dim)
  121. def check(self, x):
  122. raise ValueError('Cannot determine validity of dependent constraint')
  123. def is_dependent(constraint):
  124. return isinstance(constraint, _Dependent)
  125. class _DependentProperty(property, _Dependent):
  126. """
  127. Decorator that extends @property to act like a `Dependent` constraint when
  128. called on a class and act like a property when called on an object.
  129. Example::
  130. class Uniform(Distribution):
  131. def __init__(self, low, high):
  132. self.low = low
  133. self.high = high
  134. @constraints.dependent_property(is_discrete=False, event_dim=0)
  135. def support(self):
  136. return constraints.interval(self.low, self.high)
  137. Args:
  138. fn (Callable): The function to be decorated.
  139. is_discrete (bool): Optional value of ``.is_discrete`` in case this
  140. can be computed statically. If not provided, access to the
  141. ``.is_discrete`` attribute will raise a NotImplementedError.
  142. event_dim (int): Optional value of ``.event_dim`` in case this
  143. can be computed statically. If not provided, access to the
  144. ``.event_dim`` attribute will raise a NotImplementedError.
  145. """
  146. def __init__(self, fn=None, *, is_discrete=NotImplemented, event_dim=NotImplemented):
  147. super().__init__(fn)
  148. self._is_discrete = is_discrete
  149. self._event_dim = event_dim
  150. def __call__(self, fn):
  151. """
  152. Support for syntax to customize static attributes::
  153. @constraints.dependent_property(is_discrete=True, event_dim=1)
  154. def support(self):
  155. ...
  156. """
  157. return _DependentProperty(fn, is_discrete=self._is_discrete, event_dim=self._event_dim)
  158. class _IndependentConstraint(Constraint):
  159. """
  160. Wraps a constraint by aggregating over ``reinterpreted_batch_ndims``-many
  161. dims in :meth:`check`, so that an event is valid only if all its
  162. independent entries are valid.
  163. """
  164. def __init__(self, base_constraint, reinterpreted_batch_ndims):
  165. assert isinstance(base_constraint, Constraint)
  166. assert isinstance(reinterpreted_batch_ndims, int)
  167. assert reinterpreted_batch_ndims >= 0
  168. self.base_constraint = base_constraint
  169. self.reinterpreted_batch_ndims = reinterpreted_batch_ndims
  170. super().__init__()
  171. @property
  172. def is_discrete(self):
  173. return self.base_constraint.is_discrete
  174. @property
  175. def event_dim(self):
  176. return self.base_constraint.event_dim + self.reinterpreted_batch_ndims
  177. def check(self, value):
  178. result = self.base_constraint.check(value)
  179. if result.dim() < self.reinterpreted_batch_ndims:
  180. expected = self.base_constraint.event_dim + self.reinterpreted_batch_ndims
  181. raise ValueError(f"Expected value.dim() >= {expected} but got {value.dim()}")
  182. result = result.reshape(result.shape[:result.dim() - self.reinterpreted_batch_ndims] + (-1,))
  183. result = result.all(-1)
  184. return result
  185. def __repr__(self):
  186. return "{}({}, {})".format(self.__class__.__name__[1:], repr(self.base_constraint),
  187. self.reinterpreted_batch_ndims)
  188. class _Boolean(Constraint):
  189. """
  190. Constrain to the two values `{0, 1}`.
  191. """
  192. is_discrete = True
  193. def check(self, value):
  194. return (value == 0) | (value == 1)
  195. class _OneHot(Constraint):
  196. """
  197. Constrain to one-hot vectors.
  198. """
  199. is_discrete = True
  200. event_dim = 1
  201. def check(self, value):
  202. is_boolean = (value == 0) | (value == 1)
  203. is_normalized = value.sum(-1).eq(1)
  204. return is_boolean.all(-1) & is_normalized
  205. class _IntegerInterval(Constraint):
  206. """
  207. Constrain to an integer interval `[lower_bound, upper_bound]`.
  208. """
  209. is_discrete = True
  210. def __init__(self, lower_bound, upper_bound):
  211. self.lower_bound = lower_bound
  212. self.upper_bound = upper_bound
  213. super().__init__()
  214. def check(self, value):
  215. return (value % 1 == 0) & (self.lower_bound <= value) & (value <= self.upper_bound)
  216. def __repr__(self):
  217. fmt_string = self.__class__.__name__[1:]
  218. fmt_string += '(lower_bound={}, upper_bound={})'.format(self.lower_bound, self.upper_bound)
  219. return fmt_string
  220. class _IntegerLessThan(Constraint):
  221. """
  222. Constrain to an integer interval `(-inf, upper_bound]`.
  223. """
  224. is_discrete = True
  225. def __init__(self, upper_bound):
  226. self.upper_bound = upper_bound
  227. super().__init__()
  228. def check(self, value):
  229. return (value % 1 == 0) & (value <= self.upper_bound)
  230. def __repr__(self):
  231. fmt_string = self.__class__.__name__[1:]
  232. fmt_string += '(upper_bound={})'.format(self.upper_bound)
  233. return fmt_string
  234. class _IntegerGreaterThan(Constraint):
  235. """
  236. Constrain to an integer interval `[lower_bound, inf)`.
  237. """
  238. is_discrete = True
  239. def __init__(self, lower_bound):
  240. self.lower_bound = lower_bound
  241. super().__init__()
  242. def check(self, value):
  243. return (value % 1 == 0) & (value >= self.lower_bound)
  244. def __repr__(self):
  245. fmt_string = self.__class__.__name__[1:]
  246. fmt_string += '(lower_bound={})'.format(self.lower_bound)
  247. return fmt_string
  248. class _Real(Constraint):
  249. """
  250. Trivially constrain to the extended real line `[-inf, inf]`.
  251. """
  252. def check(self, value):
  253. return value == value # False for NANs.
  254. class _GreaterThan(Constraint):
  255. """
  256. Constrain to a real half line `(lower_bound, inf]`.
  257. """
  258. def __init__(self, lower_bound):
  259. self.lower_bound = lower_bound
  260. super().__init__()
  261. def check(self, value):
  262. return self.lower_bound < value
  263. def __repr__(self):
  264. fmt_string = self.__class__.__name__[1:]
  265. fmt_string += '(lower_bound={})'.format(self.lower_bound)
  266. return fmt_string
  267. class _GreaterThanEq(Constraint):
  268. """
  269. Constrain to a real half line `[lower_bound, inf)`.
  270. """
  271. def __init__(self, lower_bound):
  272. self.lower_bound = lower_bound
  273. super().__init__()
  274. def check(self, value):
  275. return self.lower_bound <= value
  276. def __repr__(self):
  277. fmt_string = self.__class__.__name__[1:]
  278. fmt_string += '(lower_bound={})'.format(self.lower_bound)
  279. return fmt_string
  280. class _LessThan(Constraint):
  281. """
  282. Constrain to a real half line `[-inf, upper_bound)`.
  283. """
  284. def __init__(self, upper_bound):
  285. self.upper_bound = upper_bound
  286. super().__init__()
  287. def check(self, value):
  288. return value < self.upper_bound
  289. def __repr__(self):
  290. fmt_string = self.__class__.__name__[1:]
  291. fmt_string += '(upper_bound={})'.format(self.upper_bound)
  292. return fmt_string
  293. class _Interval(Constraint):
  294. """
  295. Constrain to a real interval `[lower_bound, upper_bound]`.
  296. """
  297. def __init__(self, lower_bound, upper_bound):
  298. self.lower_bound = lower_bound
  299. self.upper_bound = upper_bound
  300. super().__init__()
  301. def check(self, value):
  302. return (self.lower_bound <= value) & (value <= self.upper_bound)
  303. def __repr__(self):
  304. fmt_string = self.__class__.__name__[1:]
  305. fmt_string += '(lower_bound={}, upper_bound={})'.format(self.lower_bound, self.upper_bound)
  306. return fmt_string
  307. class _HalfOpenInterval(Constraint):
  308. """
  309. Constrain to a real interval `[lower_bound, upper_bound)`.
  310. """
  311. def __init__(self, lower_bound, upper_bound):
  312. self.lower_bound = lower_bound
  313. self.upper_bound = upper_bound
  314. super().__init__()
  315. def check(self, value):
  316. return (self.lower_bound <= value) & (value < self.upper_bound)
  317. def __repr__(self):
  318. fmt_string = self.__class__.__name__[1:]
  319. fmt_string += '(lower_bound={}, upper_bound={})'.format(self.lower_bound, self.upper_bound)
  320. return fmt_string
  321. class _Simplex(Constraint):
  322. """
  323. Constrain to the unit simplex in the innermost (rightmost) dimension.
  324. Specifically: `x >= 0` and `x.sum(-1) == 1`.
  325. """
  326. event_dim = 1
  327. def check(self, value):
  328. return torch.all(value >= 0, dim=-1) & ((value.sum(-1) - 1).abs() < 1e-6)
  329. class _Multinomial(Constraint):
  330. """
  331. Constrain to nonnegative integer values summing to at most an upper bound.
  332. Note due to limitations of the Multinomial distribution, this currently
  333. checks the weaker condition ``value.sum(-1) <= upper_bound``. In the future
  334. this may be strengthened to ``value.sum(-1) == upper_bound``.
  335. """
  336. is_discrete = True
  337. event_dim = 1
  338. def __init__(self, upper_bound):
  339. self.upper_bound = upper_bound
  340. def check(self, x):
  341. return (x >= 0).all(dim=-1) & (x.sum(dim=-1) <= self.upper_bound)
  342. class _LowerTriangular(Constraint):
  343. """
  344. Constrain to lower-triangular square matrices.
  345. """
  346. event_dim = 2
  347. def check(self, value):
  348. value_tril = value.tril()
  349. return (value_tril == value).view(value.shape[:-2] + (-1,)).min(-1)[0]
  350. class _LowerCholesky(Constraint):
  351. """
  352. Constrain to lower-triangular square matrices with positive diagonals.
  353. """
  354. event_dim = 2
  355. def check(self, value):
  356. value_tril = value.tril()
  357. lower_triangular = (value_tril == value).view(value.shape[:-2] + (-1,)).min(-1)[0]
  358. positive_diagonal = (value.diagonal(dim1=-2, dim2=-1) > 0).min(-1)[0]
  359. return lower_triangular & positive_diagonal
  360. class _CorrCholesky(Constraint):
  361. """
  362. Constrain to lower-triangular square matrices with positive diagonals and each
  363. row vector being of unit length.
  364. """
  365. event_dim = 2
  366. def check(self, value):
  367. tol = torch.finfo(value.dtype).eps * value.size(-1) * 10 # 10 is an adjustable fudge factor
  368. row_norm = torch.linalg.norm(value.detach(), dim=-1)
  369. unit_row_norm = (row_norm - 1.).abs().le(tol).all(dim=-1)
  370. return _LowerCholesky().check(value) & unit_row_norm
  371. class _Square(Constraint):
  372. """
  373. Constrain to square matrices.
  374. """
  375. event_dim = 2
  376. def check(self, value):
  377. return torch.full(
  378. size=value.shape[:-2],
  379. fill_value=(value.shape[-2] == value.shape[-1]),
  380. dtype=torch.bool,
  381. device=value.device
  382. )
  383. class _Symmetric(_Square):
  384. """
  385. Constrain to Symmetric square matrices.
  386. """
  387. def check(self, value):
  388. square_check = super().check(value)
  389. if not square_check.all():
  390. return square_check
  391. return torch.isclose(value, value.mT, atol=1e-6).all(-2).all(-1)
  392. class _PositiveSemidefinite(_Symmetric):
  393. """
  394. Constrain to positive-semidefinite matrices.
  395. """
  396. def check(self, value):
  397. sym_check = super().check(value)
  398. if not sym_check.all():
  399. return sym_check
  400. return torch.linalg.eigvalsh(value).ge(0).all(-1)
  401. class _PositiveDefinite(_Symmetric):
  402. """
  403. Constrain to positive-definite matrices.
  404. """
  405. def check(self, value):
  406. sym_check = super().check(value)
  407. if not sym_check.all():
  408. return sym_check
  409. return torch.linalg.cholesky_ex(value).info.eq(0)
  410. class _Cat(Constraint):
  411. """
  412. Constraint functor that applies a sequence of constraints
  413. `cseq` at the submatrices at dimension `dim`,
  414. each of size `lengths[dim]`, in a way compatible with :func:`torch.cat`.
  415. """
  416. def __init__(self, cseq, dim=0, lengths=None):
  417. assert all(isinstance(c, Constraint) for c in cseq)
  418. self.cseq = list(cseq)
  419. if lengths is None:
  420. lengths = [1] * len(self.cseq)
  421. self.lengths = list(lengths)
  422. assert len(self.lengths) == len(self.cseq)
  423. self.dim = dim
  424. super().__init__()
  425. @property
  426. def is_discrete(self):
  427. return any(c.is_discrete for c in self.cseq)
  428. @property
  429. def event_dim(self):
  430. return max(c.event_dim for c in self.cseq)
  431. def check(self, value):
  432. assert -value.dim() <= self.dim < value.dim()
  433. checks = []
  434. start = 0
  435. for constr, length in zip(self.cseq, self.lengths):
  436. v = value.narrow(self.dim, start, length)
  437. checks.append(constr.check(v))
  438. start = start + length # avoid += for jit compat
  439. return torch.cat(checks, self.dim)
  440. class _Stack(Constraint):
  441. """
  442. Constraint functor that applies a sequence of constraints
  443. `cseq` at the submatrices at dimension `dim`,
  444. in a way compatible with :func:`torch.stack`.
  445. """
  446. def __init__(self, cseq, dim=0):
  447. assert all(isinstance(c, Constraint) for c in cseq)
  448. self.cseq = list(cseq)
  449. self.dim = dim
  450. super().__init__()
  451. @property
  452. def is_discrete(self):
  453. return any(c.is_discrete for c in self.cseq)
  454. @property
  455. def event_dim(self):
  456. dim = max(c.event_dim for c in self.cseq)
  457. if self.dim + dim < 0:
  458. dim += 1
  459. return dim
  460. def check(self, value):
  461. assert -value.dim() <= self.dim < value.dim()
  462. vs = [value.select(self.dim, i) for i in range(value.size(self.dim))]
  463. return torch.stack([constr.check(v)
  464. for v, constr in zip(vs, self.cseq)], self.dim)
  465. # Public interface.
  466. dependent = _Dependent()
  467. dependent_property = _DependentProperty
  468. independent = _IndependentConstraint
  469. boolean = _Boolean()
  470. one_hot = _OneHot()
  471. nonnegative_integer = _IntegerGreaterThan(0)
  472. positive_integer = _IntegerGreaterThan(1)
  473. integer_interval = _IntegerInterval
  474. real = _Real()
  475. real_vector = independent(real, 1)
  476. positive = _GreaterThan(0.)
  477. nonnegative = _GreaterThanEq(0.)
  478. greater_than = _GreaterThan
  479. greater_than_eq = _GreaterThanEq
  480. less_than = _LessThan
  481. multinomial = _Multinomial
  482. unit_interval = _Interval(0., 1.)
  483. interval = _Interval
  484. half_open_interval = _HalfOpenInterval
  485. simplex = _Simplex()
  486. lower_triangular = _LowerTriangular()
  487. lower_cholesky = _LowerCholesky()
  488. corr_cholesky = _CorrCholesky()
  489. square = _Square()
  490. symmetric = _Symmetric()
  491. positive_semidefinite = _PositiveSemidefinite()
  492. positive_definite = _PositiveDefinite()
  493. cat = _Cat
  494. stack = _Stack