# transforms.py

import functools
import math
import numbers
import operator
import weakref
from typing import List

import torch
import torch.nn.functional as F
from torch.distributions import constraints
from torch.distributions.utils import (_sum_rightmost, broadcast_all,
                                        lazy_property, tril_matrix_to_vec,
                                        vec_to_tril_matrix)
from torch.nn.functional import pad
from torch.nn.functional import softplus

__all__ = [
    'AbsTransform',
    'AffineTransform',
    'CatTransform',
    'ComposeTransform',
    'CorrCholeskyTransform',
    'CumulativeDistributionTransform',
    'ExpTransform',
    'IndependentTransform',
    'LowerCholeskyTransform',
    'PositiveDefiniteTransform',
    'PowerTransform',
    'ReshapeTransform',
    'SigmoidTransform',
    'SoftplusTransform',
    'TanhTransform',
    'SoftmaxTransform',
    'StackTransform',
    'StickBreakingTransform',
    'Transform',
    'identity_transform',
]


class Transform:
    """
    Abstract class for invertible transformations with computable log
    det jacobians. They are primarily used in
    :class:`torch.distributions.TransformedDistribution`.

    Caching is useful for transforms whose inverses are either expensive or
    numerically unstable. Note that care must be taken with memoized values
    since the autograd graph may be reversed. For example while the following
    works with or without caching::

        y = t(x)
        t.log_abs_det_jacobian(x, y).backward()  # x will receive gradients.

    However the following will error when caching due to dependency reversal::

        y = t(x)
        z = t.inv(y)
        grad(z.sum(), [y])  # error because z is x

    Derived classes should implement one or both of :meth:`_call` or
    :meth:`_inverse`. Derived classes that set `bijective=True` should also
    implement :meth:`log_abs_det_jacobian`.

    Args:
        cache_size (int): Size of cache. If zero, no caching is done. If one,
            the latest single value is cached. Only 0 and 1 are supported.

    Attributes:
        domain (:class:`~torch.distributions.constraints.Constraint`):
            The constraint representing valid inputs to this transform.
        codomain (:class:`~torch.distributions.constraints.Constraint`):
            The constraint representing valid outputs to this transform
            which are inputs to the inverse transform.
        bijective (bool): Whether this transform is bijective. A transform
            ``t`` is bijective iff ``t.inv(t(x)) == x`` and
            ``t(t.inv(y)) == y`` for every ``x`` in the domain and ``y`` in
            the codomain. Transforms that are not bijective should at least
            maintain the weaker pseudoinverse properties
            ``t(t.inv(t(x))) == t(x)`` and ``t.inv(t(t.inv(y))) == t.inv(y)``.
        sign (int or Tensor): For bijective univariate transforms, this
            should be +1 or -1 depending on whether transform is monotone
            increasing or decreasing.
    """
    bijective = False
    domain: constraints.Constraint
    codomain: constraints.Constraint

    def __init__(self, cache_size=0):
        self._cache_size = cache_size
        self._inv = None
        if cache_size == 0:
            pass  # default behavior
        elif cache_size == 1:
            self._cached_x_y = None, None
        else:
            raise ValueError('cache_size must be 0 or 1')
        super().__init__()

    def __getstate__(self):
        state = self.__dict__.copy()
        state["_inv"] = None
        return state

    @property
    def event_dim(self):
        if self.domain.event_dim == self.codomain.event_dim:
            return self.domain.event_dim
        raise ValueError("Please use either .domain.event_dim or .codomain.event_dim")

    @property
    def inv(self):
        """
        Returns the inverse :class:`Transform` of this transform.
        This should satisfy ``t.inv.inv is t``.
        """
        inv = None
        if self._inv is not None:
            inv = self._inv()
        if inv is None:
            inv = _InverseTransform(self)
            self._inv = weakref.ref(inv)
        return inv

    @property
    def sign(self):
        """
        Returns the sign of the determinant of the Jacobian, if applicable.
        In general this only makes sense for bijective transforms.
        """
        raise NotImplementedError

    def with_cache(self, cache_size=1):
        if self._cache_size == cache_size:
            return self
        if type(self).__init__ is Transform.__init__:
            return type(self)(cache_size=cache_size)
        raise NotImplementedError("{}.with_cache is not implemented".format(type(self)))

    def __eq__(self, other):
        return self is other

    def __ne__(self, other):
        # Necessary for Python2
        return not self.__eq__(other)

    def __call__(self, x):
        """
        Computes the transform `x => y`.
        """
        if self._cache_size == 0:
            return self._call(x)
        x_old, y_old = self._cached_x_y
        if x is x_old:
            return y_old
        y = self._call(x)
        self._cached_x_y = x, y
        return y

    def _inv_call(self, y):
        """
        Inverts the transform `y => x`.
        """
        if self._cache_size == 0:
            return self._inverse(y)
        x_old, y_old = self._cached_x_y
        if y is y_old:
            return x_old
        x = self._inverse(y)
        self._cached_x_y = x, y
        return x

    def _call(self, x):
        """
        Abstract method to compute forward transformation.
        """
        raise NotImplementedError

    def _inverse(self, y):
        """
        Abstract method to compute inverse transformation.
        """
        raise NotImplementedError

    def log_abs_det_jacobian(self, x, y):
        """
        Computes the log det jacobian `log |dy/dx|` given input and output.
        """
        raise NotImplementedError

    def __repr__(self):
        return self.__class__.__name__ + '()'

    def forward_shape(self, shape):
        """
        Infers the shape of the forward computation, given the input shape.
        Defaults to preserving shape.
        """
        return shape

    def inverse_shape(self, shape):
        """
        Infers the shapes of the inverse computation, given the output shape.
        Defaults to preserving shape.
        """
        return shape
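

# Illustrative sketch (not part of the upstream module): the docstring above
# describes the subclassing contract (``_call``/``_inverse``/
# ``log_abs_det_jacobian``) and the ``cache_size=1`` memoization. The helper
# below is hypothetical, uses only ``torch`` and names defined in this file,
# and is never executed on import.
def _example_transform_caching():
    # A minimal bijective subclass implementing the documented contract.
    class _Exp(Transform):
        bijective = True
        domain = constraints.real
        codomain = constraints.positive
        sign = +1

        def _call(self, x):
            return x.exp()

        def _inverse(self, y):
            return y.log()

        def log_abs_det_jacobian(self, x, y):
            return x  # log |d exp(x)/dx| = x

    t = _Exp(cache_size=1)
    x = torch.randn(3, requires_grad=True)
    y = t(x)              # forward value is computed and cached
    assert t.inv(y) is x  # the cached input is returned; no log() is evaluated
    t.log_abs_det_jacobian(x, y).sum().backward()  # x receives gradients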


class _InverseTransform(Transform):
    """
    Inverts a single :class:`Transform`.
    This class is private; please instead use the ``Transform.inv`` property.
    """
    def __init__(self, transform: Transform):
        super().__init__(cache_size=transform._cache_size)
        self._inv: Transform = transform

    @constraints.dependent_property(is_discrete=False)
    def domain(self):
        assert self._inv is not None
        return self._inv.codomain

    @constraints.dependent_property(is_discrete=False)
    def codomain(self):
        assert self._inv is not None
        return self._inv.domain

    @property
    def bijective(self):
        assert self._inv is not None
        return self._inv.bijective

    @property
    def sign(self):
        assert self._inv is not None
        return self._inv.sign

    @property
    def inv(self):
        return self._inv

    def with_cache(self, cache_size=1):
        assert self._inv is not None
        return self.inv.with_cache(cache_size).inv

    def __eq__(self, other):
        if not isinstance(other, _InverseTransform):
            return False
        assert self._inv is not None
        return self._inv == other._inv

    def __repr__(self):
        return f"{self.__class__.__name__}({repr(self._inv)})"

    def __call__(self, x):
        assert self._inv is not None
        return self._inv._inv_call(x)

    def log_abs_det_jacobian(self, x, y):
        assert self._inv is not None
        return -self._inv.log_abs_det_jacobian(y, x)

    def forward_shape(self, shape):
        return self._inv.inverse_shape(shape)

    def inverse_shape(self, shape):
        return self._inv.forward_shape(shape)


class ComposeTransform(Transform):
    """
    Composes multiple transforms in a chain.
    The transforms being composed are responsible for caching.

    Args:
        parts (list of :class:`Transform`): A list of transforms to compose.
        cache_size (int): Size of cache. If zero, no caching is done. If one,
            the latest single value is cached. Only 0 and 1 are supported.
    """
    def __init__(self, parts: List[Transform], cache_size=0):
        if cache_size:
            parts = [part.with_cache(cache_size) for part in parts]
        super().__init__(cache_size=cache_size)
        self.parts = parts

    def __eq__(self, other):
        if not isinstance(other, ComposeTransform):
            return False
        return self.parts == other.parts

    @constraints.dependent_property(is_discrete=False)
    def domain(self):
        if not self.parts:
            return constraints.real
        domain = self.parts[0].domain
        # Adjust event_dim to be maximum among all parts.
        event_dim = self.parts[-1].codomain.event_dim
        for part in reversed(self.parts):
            event_dim += part.domain.event_dim - part.codomain.event_dim
            event_dim = max(event_dim, part.domain.event_dim)
        assert event_dim >= domain.event_dim
        if event_dim > domain.event_dim:
            domain = constraints.independent(domain, event_dim - domain.event_dim)
        return domain

    @constraints.dependent_property(is_discrete=False)
    def codomain(self):
        if not self.parts:
            return constraints.real
        codomain = self.parts[-1].codomain
        # Adjust event_dim to be maximum among all parts.
        event_dim = self.parts[0].domain.event_dim
        for part in self.parts:
            event_dim += part.codomain.event_dim - part.domain.event_dim
            event_dim = max(event_dim, part.codomain.event_dim)
        assert event_dim >= codomain.event_dim
        if event_dim > codomain.event_dim:
            codomain = constraints.independent(codomain, event_dim - codomain.event_dim)
        return codomain

    @lazy_property
    def bijective(self):
        return all(p.bijective for p in self.parts)

    @lazy_property
    def sign(self):
        sign = 1
        for p in self.parts:
            sign = sign * p.sign
        return sign

    @property
    def inv(self):
        inv = None
        if self._inv is not None:
            inv = self._inv()
        if inv is None:
            inv = ComposeTransform([p.inv for p in reversed(self.parts)])
            self._inv = weakref.ref(inv)
            inv._inv = weakref.ref(self)
        return inv

    def with_cache(self, cache_size=1):
        if self._cache_size == cache_size:
            return self
        return ComposeTransform(self.parts, cache_size=cache_size)

    def __call__(self, x):
        for part in self.parts:
            x = part(x)
        return x

    def log_abs_det_jacobian(self, x, y):
        if not self.parts:
            return torch.zeros_like(x)

        # Compute intermediates. This will be free if parts[:-1] are all cached.
        xs = [x]
        for part in self.parts[:-1]:
            xs.append(part(xs[-1]))
        xs.append(y)

        terms = []
        event_dim = self.domain.event_dim
        for part, x, y in zip(self.parts, xs[:-1], xs[1:]):
            terms.append(_sum_rightmost(part.log_abs_det_jacobian(x, y),
                                        event_dim - part.domain.event_dim))
            event_dim += part.codomain.event_dim - part.domain.event_dim
        return functools.reduce(operator.add, terms)

    def forward_shape(self, shape):
        for part in self.parts:
            shape = part.forward_shape(shape)
        return shape

    def inverse_shape(self, shape):
        for part in reversed(self.parts):
            shape = part.inverse_shape(shape)
        return shape

    def __repr__(self):
        fmt_string = self.__class__.__name__ + '(\n    '
        fmt_string += ',\n    '.join([p.__repr__() for p in self.parts])
        fmt_string += '\n)'
        return fmt_string


identity_transform = ComposeTransform([])
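

# Illustrative sketch (not part of the upstream module): composing an affine
# map with ``exp`` yields a log-normal-style transform. ``AffineTransform`` and
# ``ExpTransform`` are defined later in this file and are only looked up when
# this hypothetical helper is actually called.
def _example_compose_transform():
    t = ComposeTransform([AffineTransform(loc=1.0, scale=0.5), ExpTransform()])
    x = torch.randn(4)
    y = t(x)                                         # y = exp(1.0 + 0.5 * x)
    assert torch.allclose(t.inv(y), x, atol=1e-6)    # parts are inverted in reverse order
    # The log|det J| of a composition is the sum of the parts' contributions:
    expected = torch.full((4,), math.log(0.5)) + (1.0 + 0.5 * x)
    assert torch.allclose(t.log_abs_det_jacobian(x, y), expected, atol=1e-6)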


class IndependentTransform(Transform):
    """
    Wrapper around another transform to treat
    ``reinterpreted_batch_ndims``-many extra of the rightmost dimensions as
    dependent. This has no effect on the forward or backward transforms, but
    does sum out ``reinterpreted_batch_ndims``-many of the rightmost dimensions
    in :meth:`log_abs_det_jacobian`.

    Args:
        base_transform (:class:`Transform`): A base transform.
        reinterpreted_batch_ndims (int): The number of extra rightmost
            dimensions to treat as dependent.
    """
    def __init__(self, base_transform, reinterpreted_batch_ndims, cache_size=0):
        super().__init__(cache_size=cache_size)
        self.base_transform = base_transform.with_cache(cache_size)
        self.reinterpreted_batch_ndims = reinterpreted_batch_ndims

    def with_cache(self, cache_size=1):
        if self._cache_size == cache_size:
            return self
        return IndependentTransform(self.base_transform,
                                    self.reinterpreted_batch_ndims,
                                    cache_size=cache_size)

    @constraints.dependent_property(is_discrete=False)
    def domain(self):
        return constraints.independent(self.base_transform.domain,
                                       self.reinterpreted_batch_ndims)

    @constraints.dependent_property(is_discrete=False)
    def codomain(self):
        return constraints.independent(self.base_transform.codomain,
                                       self.reinterpreted_batch_ndims)

    @property
    def bijective(self):
        return self.base_transform.bijective

    @property
    def sign(self):
        return self.base_transform.sign

    def _call(self, x):
        if x.dim() < self.domain.event_dim:
            raise ValueError("Too few dimensions on input")
        return self.base_transform(x)

    def _inverse(self, y):
        if y.dim() < self.codomain.event_dim:
            raise ValueError("Too few dimensions on input")
        return self.base_transform.inv(y)

    def log_abs_det_jacobian(self, x, y):
        result = self.base_transform.log_abs_det_jacobian(x, y)
        result = _sum_rightmost(result, self.reinterpreted_batch_ndims)
        return result

    def __repr__(self):
        return f"{self.__class__.__name__}({repr(self.base_transform)}, {self.reinterpreted_batch_ndims})"

    def forward_shape(self, shape):
        return self.base_transform.forward_shape(shape)

    def inverse_shape(self, shape):
        return self.base_transform.inverse_shape(shape)
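

# Illustrative sketch (not part of the upstream module): wrapping a transform
# in IndependentTransform leaves the forward values unchanged and only sums the
# rightmost ``reinterpreted_batch_ndims`` dimensions of the log-det-Jacobian.
# ``ExpTransform`` is defined later in this file; the helper name is hypothetical.
def _example_independent_transform():
    base = ExpTransform()
    t = IndependentTransform(base, reinterpreted_batch_ndims=1)
    x = torch.randn(2, 3)
    y = t(x)
    assert torch.equal(y, base(x))                        # same forward values
    assert base.log_abs_det_jacobian(x, y).shape == (2, 3)
    assert t.log_abs_det_jacobian(x, y).shape == (2,)     # rightmost dim summed out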


class ReshapeTransform(Transform):
    """
    Unit Jacobian transform to reshape the rightmost part of a tensor.

    Note that ``in_shape`` and ``out_shape`` must have the same number of
    elements, just as for :meth:`torch.Tensor.reshape`.

    Arguments:
        in_shape (torch.Size): The input event shape.
        out_shape (torch.Size): The output event shape.
    """
    bijective = True

    def __init__(self, in_shape, out_shape, cache_size=0):
        self.in_shape = torch.Size(in_shape)
        self.out_shape = torch.Size(out_shape)
        if self.in_shape.numel() != self.out_shape.numel():
            raise ValueError("in_shape, out_shape have different numbers of elements")
        super().__init__(cache_size=cache_size)

    @constraints.dependent_property
    def domain(self):
        return constraints.independent(constraints.real, len(self.in_shape))

    @constraints.dependent_property
    def codomain(self):
        return constraints.independent(constraints.real, len(self.out_shape))

    def with_cache(self, cache_size=1):
        if self._cache_size == cache_size:
            return self
        return ReshapeTransform(self.in_shape, self.out_shape, cache_size=cache_size)

    def _call(self, x):
        batch_shape = x.shape[:x.dim() - len(self.in_shape)]
        return x.reshape(batch_shape + self.out_shape)

    def _inverse(self, y):
        batch_shape = y.shape[:y.dim() - len(self.out_shape)]
        return y.reshape(batch_shape + self.in_shape)

    def log_abs_det_jacobian(self, x, y):
        batch_shape = x.shape[:x.dim() - len(self.in_shape)]
        return x.new_zeros(batch_shape)

    def forward_shape(self, shape):
        if len(shape) < len(self.in_shape):
            raise ValueError("Too few dimensions on input")
        cut = len(shape) - len(self.in_shape)
        if shape[cut:] != self.in_shape:
            raise ValueError("Shape mismatch: expected {} but got {}".format(shape[cut:], self.in_shape))
        return shape[:cut] + self.out_shape

    def inverse_shape(self, shape):
        if len(shape) < len(self.out_shape):
            raise ValueError("Too few dimensions on input")
        cut = len(shape) - len(self.out_shape)
        if shape[cut:] != self.out_shape:
            raise ValueError("Shape mismatch: expected {} but got {}".format(shape[cut:], self.out_shape))
        return shape[:cut] + self.in_shape
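

# Illustrative sketch (not part of the upstream module): reshaping the event
# part (6,) -> (2, 3) while leaving the batch dimension untouched. The helper
# name is hypothetical.
def _example_reshape_transform():
    t = ReshapeTransform(in_shape=(6,), out_shape=(2, 3))
    x = torch.randn(5, 6)                     # batch_shape=(5,), event_shape=(6,)
    y = t(x)
    assert y.shape == (5, 2, 3)
    assert torch.equal(t.inv(y), x)           # a pure reshape round-trips exactly
    ladj = t.log_abs_det_jacobian(x, y)
    assert ladj.shape == (5,) and ladj.eq(0).all()   # unit Jacobian per batch element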


class ExpTransform(Transform):
    r"""
    Transform via the mapping :math:`y = \exp(x)`.
    """
    domain = constraints.real
    codomain = constraints.positive
    bijective = True
    sign = +1

    def __eq__(self, other):
        return isinstance(other, ExpTransform)

    def _call(self, x):
        return x.exp()

    def _inverse(self, y):
        return y.log()

    def log_abs_det_jacobian(self, x, y):
        return x


class PowerTransform(Transform):
    r"""
    Transform via the mapping :math:`y = x^{\text{exponent}}`.
    """
    domain = constraints.positive
    codomain = constraints.positive
    bijective = True
    sign = +1

    def __init__(self, exponent, cache_size=0):
        super().__init__(cache_size=cache_size)
        self.exponent, = broadcast_all(exponent)

    def with_cache(self, cache_size=1):
        if self._cache_size == cache_size:
            return self
        return PowerTransform(self.exponent, cache_size=cache_size)

    def __eq__(self, other):
        if not isinstance(other, PowerTransform):
            return False
        return self.exponent.eq(other.exponent).all().item()

    def _call(self, x):
        return x.pow(self.exponent)

    def _inverse(self, y):
        return y.pow(1 / self.exponent)

    def log_abs_det_jacobian(self, x, y):
        return (self.exponent * y / x).abs().log()

    def forward_shape(self, shape):
        return torch.broadcast_shapes(shape, getattr(self.exponent, "shape", ()))

    def inverse_shape(self, shape):
        return torch.broadcast_shapes(shape, getattr(self.exponent, "shape", ()))


def _clipped_sigmoid(x):
    finfo = torch.finfo(x.dtype)
    return torch.clamp(torch.sigmoid(x), min=finfo.tiny, max=1. - finfo.eps)


class SigmoidTransform(Transform):
    r"""
    Transform via the mapping :math:`y = \frac{1}{1 + \exp(-x)}` and :math:`x = \text{logit}(y)`.
    """
    domain = constraints.real
    codomain = constraints.unit_interval
    bijective = True
    sign = +1

    def __eq__(self, other):
        return isinstance(other, SigmoidTransform)

    def _call(self, x):
        return _clipped_sigmoid(x)

    def _inverse(self, y):
        finfo = torch.finfo(y.dtype)
        y = y.clamp(min=finfo.tiny, max=1. - finfo.eps)
        return y.log() - (-y).log1p()

    def log_abs_det_jacobian(self, x, y):
        return -F.softplus(-x) - F.softplus(x)
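

# Illustrative sketch (not part of the upstream module): the sigmoid/logit pair
# round-trips up to the clipping applied by ``_clipped_sigmoid``, and the
# log-det-Jacobian equals ``log sigma(x) + log sigma(-x)``. The helper name is
# hypothetical.
def _example_sigmoid_transform():
    t = SigmoidTransform()
    x = torch.randn(4)
    y = t(x)                                     # values in (0, 1)
    assert torch.allclose(t.inv(y), x, atol=1e-5)
    expected = F.logsigmoid(x) + F.logsigmoid(-x)
    assert torch.allclose(t.log_abs_det_jacobian(x, y), expected, atol=1e-6)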


class SoftplusTransform(Transform):
    r"""
    Transform via the mapping :math:`\text{Softplus}(x) = \log(1 + \exp(x))`.
    The implementation reverts to the linear function when :math:`x > 20`.
    """
    domain = constraints.real
    codomain = constraints.positive
    bijective = True
    sign = +1

    def __eq__(self, other):
        return isinstance(other, SoftplusTransform)

    def _call(self, x):
        return softplus(x)

    def _inverse(self, y):
        return (-y).expm1().neg().log() + y

    def log_abs_det_jacobian(self, x, y):
        return -softplus(-x)


class TanhTransform(Transform):
    r"""
    Transform via the mapping :math:`y = \tanh(x)`.

    It is equivalent to
    ```
    ComposeTransform([AffineTransform(0., 2.), SigmoidTransform(), AffineTransform(-1., 2.)])
    ```
    However this might not be numerically stable, thus it is recommended to use `TanhTransform`
    instead.

    Note that one should use `cache_size=1` when the inverse may otherwise produce
    `NaN`/`Inf` values at the boundary.
    """
    domain = constraints.real
    codomain = constraints.interval(-1.0, 1.0)
    bijective = True
    sign = +1

    def __eq__(self, other):
        return isinstance(other, TanhTransform)

    def _call(self, x):
        return x.tanh()

    def _inverse(self, y):
        # We do not clamp to the boundary here as it may degrade the performance of certain algorithms.
        # One should use `cache_size=1` instead.
        return torch.atanh(y)

    def log_abs_det_jacobian(self, x, y):
        # We use a formula that is more numerically stable, see details in the following link
        # https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/bijectors/tanh.py#L69-L80
        return 2. * (math.log(2.) - x - softplus(-2. * x))
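

# Illustrative sketch (not part of the upstream module): numerically checks the
# equivalence stated in the docstring above, i.e. tanh(x) equals the composed
# affine-sigmoid-affine transform 2 * sigmoid(2x) - 1 for moderate inputs. The
# helper name is hypothetical.
def _example_tanh_equivalence():
    t = TanhTransform()
    s = ComposeTransform([AffineTransform(0., 2.), SigmoidTransform(), AffineTransform(-1., 2.)])
    x = torch.linspace(-3, 3, steps=7)
    assert torch.allclose(t(x), s(x), atol=1e-6)
    # The dedicated implementation stays finite further into the tails, which is
    # why the docstring recommends TanhTransform over the composition.
    assert torch.allclose(t.log_abs_det_jacobian(x, t(x)),
                          s.log_abs_det_jacobian(x, s(x)), atol=1e-5)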


class AbsTransform(Transform):
    r"""
    Transform via the mapping :math:`y = |x|`.
    """
    domain = constraints.real
    codomain = constraints.positive

    def __eq__(self, other):
        return isinstance(other, AbsTransform)

    def _call(self, x):
        return x.abs()

    def _inverse(self, y):
        return y


class AffineTransform(Transform):
    r"""
    Transform via the pointwise affine mapping :math:`y = \text{loc} + \text{scale} \times x`.

    Args:
        loc (Tensor or float): Location parameter.
        scale (Tensor or float): Scale parameter.
        event_dim (int): Optional size of `event_shape`. This should be zero
            for univariate random variables, 1 for distributions over vectors,
            2 for distributions over matrices, etc.
    """
    bijective = True

    def __init__(self, loc, scale, event_dim=0, cache_size=0):
        super().__init__(cache_size=cache_size)
        self.loc = loc
        self.scale = scale
        self._event_dim = event_dim

    @property
    def event_dim(self):
        return self._event_dim

    @constraints.dependent_property(is_discrete=False)
    def domain(self):
        if self.event_dim == 0:
            return constraints.real
        return constraints.independent(constraints.real, self.event_dim)

    @constraints.dependent_property(is_discrete=False)
    def codomain(self):
        if self.event_dim == 0:
            return constraints.real
        return constraints.independent(constraints.real, self.event_dim)

    def with_cache(self, cache_size=1):
        if self._cache_size == cache_size:
            return self
        return AffineTransform(self.loc, self.scale, self.event_dim, cache_size=cache_size)

    def __eq__(self, other):
        if not isinstance(other, AffineTransform):
            return False

        if isinstance(self.loc, numbers.Number) and isinstance(other.loc, numbers.Number):
            if self.loc != other.loc:
                return False
        else:
            if not (self.loc == other.loc).all().item():
                return False

        if isinstance(self.scale, numbers.Number) and isinstance(other.scale, numbers.Number):
            if self.scale != other.scale:
                return False
        else:
            if not (self.scale == other.scale).all().item():
                return False

        return True

    @property
    def sign(self):
        if isinstance(self.scale, numbers.Real):
            return 1 if float(self.scale) > 0 else -1 if float(self.scale) < 0 else 0
        return self.scale.sign()

    def _call(self, x):
        return self.loc + self.scale * x

    def _inverse(self, y):
        return (y - self.loc) / self.scale

    def log_abs_det_jacobian(self, x, y):
        shape = x.shape
        scale = self.scale
        if isinstance(scale, numbers.Real):
            result = torch.full_like(x, math.log(abs(scale)))
        else:
            result = torch.abs(scale).log()
        if self.event_dim:
            result_size = result.size()[:-self.event_dim] + (-1,)
            result = result.view(result_size).sum(-1)
            shape = shape[:-self.event_dim]
        return result.expand(shape)

    def forward_shape(self, shape):
        return torch.broadcast_shapes(shape,
                                      getattr(self.loc, "shape", ()),
                                      getattr(self.scale, "shape", ()))

    def inverse_shape(self, shape):
        return torch.broadcast_shapes(shape,
                                      getattr(self.loc, "shape", ()),
                                      getattr(self.scale, "shape", ()))
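

# Illustrative sketch (not part of the upstream module): with ``event_dim=1``
# the per-coordinate term ``log |scale|`` is summed over the event dimension,
# so the log-det-Jacobian keeps only the batch shape. The helper name is
# hypothetical.
def _example_affine_event_dim():
    t = AffineTransform(loc=0., scale=2., event_dim=1)
    x = torch.randn(3, 4)                     # batch_shape=(3,), event_shape=(4,)
    y = t(x)
    ladj = t.log_abs_det_jacobian(x, y)
    assert ladj.shape == (3,)
    assert torch.allclose(ladj, torch.full((3,), 4 * math.log(2.)))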


class CorrCholeskyTransform(Transform):
    r"""
    Transforms an unconstrained real vector :math:`x` with length :math:`D*(D-1)/2` into the
    Cholesky factor of a D-dimensional correlation matrix. This Cholesky factor is a lower
    triangular matrix with positive diagonals and unit Euclidean norm for each row.
    The transform is processed as follows:

        1. First we convert x into a lower triangular matrix in row order.
        2. For each row :math:`X_i` of the lower triangular part, we apply a *signed* version of
           class :class:`StickBreakingTransform` to transform :math:`X_i` into a
           unit Euclidean length vector using the following steps:
           - Scales into the interval :math:`(-1, 1)` domain: :math:`r_i = \tanh(X_i)`.
           - Transforms into an unsigned domain: :math:`z_i = r_i^2`.
           - Applies :math:`s_i = StickBreakingTransform(z_i)`.
           - Transforms back into signed domain: :math:`y_i = sign(r_i) * \sqrt{s_i}`.
    """
    domain = constraints.real_vector
    codomain = constraints.corr_cholesky
    bijective = True

    def _call(self, x):
        x = torch.tanh(x)
        eps = torch.finfo(x.dtype).eps
        x = x.clamp(min=-1 + eps, max=1 - eps)
        r = vec_to_tril_matrix(x, diag=-1)
        # apply stick-breaking on the squared values
        # Note that y = sign(r) * sqrt(z * z1m_cumprod)
        #             = (sign(r) * sqrt(z)) * sqrt(z1m_cumprod) = r * sqrt(z1m_cumprod)
        z = r ** 2
        z1m_cumprod_sqrt = (1 - z).sqrt().cumprod(-1)
        # Diagonal elements must be 1.
        r = r + torch.eye(r.shape[-1], dtype=r.dtype, device=r.device)
        y = r * pad(z1m_cumprod_sqrt[..., :-1], [1, 0], value=1)
        return y

    def _inverse(self, y):
        # inverse stick-breaking
        # See: https://mc-stan.org/docs/2_18/reference-manual/cholesky-factors-of-correlation-matrices-1.html
        y_cumsum = 1 - torch.cumsum(y * y, dim=-1)
        y_cumsum_shifted = pad(y_cumsum[..., :-1], [1, 0], value=1)
        y_vec = tril_matrix_to_vec(y, diag=-1)
        y_cumsum_vec = tril_matrix_to_vec(y_cumsum_shifted, diag=-1)
        t = y_vec / (y_cumsum_vec).sqrt()
        # inverse of tanh
        x = (t.log1p() - t.neg().log1p()) / 2
        return x

    def log_abs_det_jacobian(self, x, y, intermediates=None):
        # Because domain and codomain are two spaces with different dimensions, determinant of
        # Jacobian is not well-defined. We return `log_abs_det_jacobian` of `x` and the
        # flattened lower triangular part of `y`.
        # See: https://mc-stan.org/docs/2_18/reference-manual/cholesky-factors-of-correlation-matrices-1.html
        y1m_cumsum = 1 - (y * y).cumsum(dim=-1)
        # by taking diagonal=-2, we don't need to shift z_cumprod to the right
        # also works for 2 x 2 matrix
        y1m_cumsum_tril = tril_matrix_to_vec(y1m_cumsum, diag=-2)
        stick_breaking_logdet = 0.5 * (y1m_cumsum_tril).log().sum(-1)
        tanh_logdet = -2 * (x + softplus(-2 * x) - math.log(2.)).sum(dim=-1)
        return stick_breaking_logdet + tanh_logdet

    def forward_shape(self, shape):
        # Reshape from (..., N) to (..., D, D).
        if len(shape) < 1:
            raise ValueError("Too few dimensions on input")
        N = shape[-1]
        D = round((0.25 + 2 * N) ** 0.5 + 0.5)
        if D * (D - 1) // 2 != N:
            raise ValueError("Input is not a flattened lower-diagonal number")
        return shape[:-1] + (D, D)

    def inverse_shape(self, shape):
        # Reshape from (..., D, D) to (..., N).
        if len(shape) < 2:
            raise ValueError("Too few dimensions on input")
        if shape[-2] != shape[-1]:
            raise ValueError("Input is not square")
        D = shape[-1]
        N = D * (D - 1) // 2
        return shape[:-2] + (N,)
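

# Illustrative sketch (not part of the upstream module): a vector of length
# D*(D-1)/2 = 6 maps to the Cholesky factor of a 4 x 4 correlation matrix,
# i.e. a lower-triangular matrix with positive diagonal and unit-norm rows.
# The helper name is hypothetical.
def _example_corr_cholesky():
    t = CorrCholeskyTransform()
    x = torch.linspace(-1.5, 1.5, steps=6)
    L = t(x)
    assert L.shape == (4, 4)
    assert torch.equal(L, L.tril())                        # lower triangular
    assert (L.diagonal(dim1=-2, dim2=-1) > 0).all()        # positive diagonal
    assert torch.allclose(L.norm(dim=-1), torch.ones(4), atol=1e-5)  # unit-norm rows
    assert torch.allclose(t.inv(L), x, atol=1e-4)          # round trip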


class SoftmaxTransform(Transform):
    r"""
    Transform from unconstrained space to the simplex via :math:`y = \exp(x)` then
    normalizing.

    This is not bijective and cannot be used for HMC. However this acts mostly
    coordinate-wise (except for the final normalization), and thus is
    appropriate for coordinate-wise optimization algorithms.
    """
    domain = constraints.real_vector
    codomain = constraints.simplex

    def __eq__(self, other):
        return isinstance(other, SoftmaxTransform)

    def _call(self, x):
        logprobs = x
        probs = (logprobs - logprobs.max(-1, True)[0]).exp()
        return probs / probs.sum(-1, True)

    def _inverse(self, y):
        probs = y
        return probs.log()

    def forward_shape(self, shape):
        if len(shape) < 1:
            raise ValueError("Too few dimensions on input")
        return shape

    def inverse_shape(self, shape):
        if len(shape) < 1:
            raise ValueError("Too few dimensions on input")
        return shape


class StickBreakingTransform(Transform):
    """
    Transform from unconstrained space to the simplex of one additional
    dimension via a stick-breaking process.

    This transform arises as an iterated sigmoid transform in a stick-breaking
    construction of the `Dirichlet` distribution: the first logit is
    transformed via sigmoid to the first probability and the probability of
    everything else, and then the process recurses.

    This is bijective and appropriate for use in HMC; however it mixes
    coordinates together and is less appropriate for optimization.
    """
    domain = constraints.real_vector
    codomain = constraints.simplex
    bijective = True

    def __eq__(self, other):
        return isinstance(other, StickBreakingTransform)

    def _call(self, x):
        offset = x.shape[-1] + 1 - x.new_ones(x.shape[-1]).cumsum(-1)
        z = _clipped_sigmoid(x - offset.log())
        z_cumprod = (1 - z).cumprod(-1)
        y = pad(z, [0, 1], value=1) * pad(z_cumprod, [1, 0], value=1)
        return y

    def _inverse(self, y):
        y_crop = y[..., :-1]
        offset = y.shape[-1] - y.new_ones(y_crop.shape[-1]).cumsum(-1)
        sf = 1 - y_crop.cumsum(-1)
        # we clamp to make sure that sf is positive which sometimes does not
        # happen when y[-1] ~ 0 or y[:-1].sum() ~ 1
        sf = torch.clamp(sf, min=torch.finfo(y.dtype).tiny)
        x = y_crop.log() - sf.log() + offset.log()
        return x

    def log_abs_det_jacobian(self, x, y):
        offset = x.shape[-1] + 1 - x.new_ones(x.shape[-1]).cumsum(-1)
        x = x - offset.log()
        # use the identity 1 - sigmoid(x) = exp(-x) * sigmoid(x)
        detJ = (-x + F.logsigmoid(x) + y[..., :-1].log()).sum(-1)
        return detJ

    def forward_shape(self, shape):
        if len(shape) < 1:
            raise ValueError("Too few dimensions on input")
        return shape[:-1] + (shape[-1] + 1,)

    def inverse_shape(self, shape):
        if len(shape) < 1:
            raise ValueError("Too few dimensions on input")
        return shape[:-1] + (shape[-1] - 1,)
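

# Illustrative sketch (not part of the upstream module): a K-dimensional
# unconstrained vector maps to a point on the simplex with K+1 coordinates,
# and the transform is invertible. The helper name is hypothetical.
def _example_stick_breaking():
    t = StickBreakingTransform()
    x = torch.tensor([0.5, -1.0, 2.0, 0.0])
    y = t(x)
    assert y.shape == (5,)                                 # one extra coordinate
    assert (y > 0).all()
    assert torch.allclose(y.sum(-1), torch.tensor(1.0), atol=1e-6)
    assert torch.allclose(t.inv(y), x, atol=1e-4)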


class LowerCholeskyTransform(Transform):
    """
    Transform from unconstrained matrices to lower-triangular matrices with
    nonnegative diagonal entries.

    This is useful for parameterizing positive definite matrices in terms of
    their Cholesky factorization.
    """
    domain = constraints.independent(constraints.real, 2)
    codomain = constraints.lower_cholesky

    def __eq__(self, other):
        return isinstance(other, LowerCholeskyTransform)

    def _call(self, x):
        return x.tril(-1) + x.diagonal(dim1=-2, dim2=-1).exp().diag_embed()

    def _inverse(self, y):
        return y.tril(-1) + y.diagonal(dim1=-2, dim2=-1).log().diag_embed()


class PositiveDefiniteTransform(Transform):
    """
    Transform from unconstrained matrices to positive-definite matrices.
    """
    domain = constraints.independent(constraints.real, 2)
    codomain = constraints.positive_definite  # type: ignore[assignment]

    def __eq__(self, other):
        return isinstance(other, PositiveDefiniteTransform)

    def _call(self, x):
        x = LowerCholeskyTransform()(x)
        return x @ x.mT

    def _inverse(self, y):
        y = torch.linalg.cholesky(y)
        return LowerCholeskyTransform().inv(y)
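

# Illustrative sketch (not part of the upstream module): an arbitrary square
# matrix maps to a symmetric positive-definite matrix via the Cholesky-style
# parameterization above, and the mapping round-trips. The helper name is
# hypothetical.
def _example_positive_definite():
    t = PositiveDefiniteTransform()
    x = torch.randn(3, 3) * 0.5
    y = t(x)
    assert torch.allclose(y, y.mT)                   # symmetric
    assert (torch.linalg.eigvalsh(y) > 0).all()      # positive definite
    assert torch.allclose(t.inv(y), x, atol=1e-4)    # round trip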


class CatTransform(Transform):
    """
    Transform functor that applies a sequence of transforms `tseq`
    component-wise to each submatrix at `dim`, of length `lengths[dim]`,
    in a way compatible with :func:`torch.cat`.

    Example::

       x0 = torch.cat([torch.range(1, 10), torch.range(1, 10)], dim=0)
       x = torch.cat([x0, x0], dim=0)
       t0 = CatTransform([ExpTransform(), identity_transform], dim=0, lengths=[10, 10])
       t = CatTransform([t0, t0], dim=0, lengths=[20, 20])
       y = t(x)
    """
    transforms: List[Transform]

    def __init__(self, tseq, dim=0, lengths=None, cache_size=0):
        assert all(isinstance(t, Transform) for t in tseq)
        if cache_size:
            tseq = [t.with_cache(cache_size) for t in tseq]
        super().__init__(cache_size=cache_size)
        self.transforms = list(tseq)
        if lengths is None:
            lengths = [1] * len(self.transforms)
        self.lengths = list(lengths)
        assert len(self.lengths) == len(self.transforms)
        self.dim = dim

    @lazy_property
    def event_dim(self):
        return max(t.event_dim for t in self.transforms)

    @lazy_property
    def length(self):
        return sum(self.lengths)

    def with_cache(self, cache_size=1):
        if self._cache_size == cache_size:
            return self
        return CatTransform(self.transforms, self.dim, self.lengths, cache_size)

    def _call(self, x):
        assert -x.dim() <= self.dim < x.dim()
        assert x.size(self.dim) == self.length
        yslices = []
        start = 0
        for trans, length in zip(self.transforms, self.lengths):
            xslice = x.narrow(self.dim, start, length)
            yslices.append(trans(xslice))
            start = start + length  # avoid += for jit compat
        return torch.cat(yslices, dim=self.dim)

    def _inverse(self, y):
        assert -y.dim() <= self.dim < y.dim()
        assert y.size(self.dim) == self.length
        xslices = []
        start = 0
        for trans, length in zip(self.transforms, self.lengths):
            yslice = y.narrow(self.dim, start, length)
            xslices.append(trans.inv(yslice))
            start = start + length  # avoid += for jit compat
        return torch.cat(xslices, dim=self.dim)

    def log_abs_det_jacobian(self, x, y):
        assert -x.dim() <= self.dim < x.dim()
        assert x.size(self.dim) == self.length
        assert -y.dim() <= self.dim < y.dim()
        assert y.size(self.dim) == self.length
        logdetjacs = []
        start = 0
        for trans, length in zip(self.transforms, self.lengths):
            xslice = x.narrow(self.dim, start, length)
            yslice = y.narrow(self.dim, start, length)
            logdetjac = trans.log_abs_det_jacobian(xslice, yslice)
            if trans.event_dim < self.event_dim:
                logdetjac = _sum_rightmost(logdetjac, self.event_dim - trans.event_dim)
            logdetjacs.append(logdetjac)
            start = start + length  # avoid += for jit compat
        # Decide whether to concatenate or sum.
        dim = self.dim
        if dim >= 0:
            dim = dim - x.dim()
        dim = dim + self.event_dim
        if dim < 0:
            return torch.cat(logdetjacs, dim=dim)
        else:
            return sum(logdetjacs)

    @property
    def bijective(self):
        return all(t.bijective for t in self.transforms)

    @constraints.dependent_property
    def domain(self):
        return constraints.cat([t.domain for t in self.transforms],
                               self.dim, self.lengths)

    @constraints.dependent_property
    def codomain(self):
        return constraints.cat([t.codomain for t in self.transforms],
                               self.dim, self.lengths)


class StackTransform(Transform):
    """
    Transform functor that applies a sequence of transforms `tseq`
    component-wise to each submatrix at `dim`
    in a way compatible with :func:`torch.stack`.

    Example::

       x = torch.stack([torch.range(1, 10), torch.range(1, 10)], dim=1)
       t = StackTransform([ExpTransform(), identity_transform], dim=1)
       y = t(x)
    """
    transforms: List[Transform]

    def __init__(self, tseq, dim=0, cache_size=0):
        assert all(isinstance(t, Transform) for t in tseq)
        if cache_size:
            tseq = [t.with_cache(cache_size) for t in tseq]
        super().__init__(cache_size=cache_size)
        self.transforms = list(tseq)
        self.dim = dim

    def with_cache(self, cache_size=1):
        if self._cache_size == cache_size:
            return self
        return StackTransform(self.transforms, self.dim, cache_size)

    def _slice(self, z):
        return [z.select(self.dim, i) for i in range(z.size(self.dim))]

    def _call(self, x):
        assert -x.dim() <= self.dim < x.dim()
        assert x.size(self.dim) == len(self.transforms)
        yslices = []
        for xslice, trans in zip(self._slice(x), self.transforms):
            yslices.append(trans(xslice))
        return torch.stack(yslices, dim=self.dim)

    def _inverse(self, y):
        assert -y.dim() <= self.dim < y.dim()
        assert y.size(self.dim) == len(self.transforms)
        xslices = []
        for yslice, trans in zip(self._slice(y), self.transforms):
            xslices.append(trans.inv(yslice))
        return torch.stack(xslices, dim=self.dim)

    def log_abs_det_jacobian(self, x, y):
        assert -x.dim() <= self.dim < x.dim()
        assert x.size(self.dim) == len(self.transforms)
        assert -y.dim() <= self.dim < y.dim()
        assert y.size(self.dim) == len(self.transforms)
        logdetjacs = []
        yslices = self._slice(y)
        xslices = self._slice(x)
        for xslice, yslice, trans in zip(xslices, yslices, self.transforms):
            logdetjacs.append(trans.log_abs_det_jacobian(xslice, yslice))
        return torch.stack(logdetjacs, dim=self.dim)

    @property
    def bijective(self):
        return all(t.bijective for t in self.transforms)

    @constraints.dependent_property
    def domain(self):
        return constraints.stack([t.domain for t in self.transforms], self.dim)

    @constraints.dependent_property
    def codomain(self):
        return constraints.stack([t.codomain for t in self.transforms], self.dim)


class CumulativeDistributionTransform(Transform):
    """
    Transform via the cumulative distribution function of a probability distribution.

    Args:
        distribution (Distribution): Distribution whose cumulative distribution function to use for
            the transformation.

    Example::

        # Construct a Gaussian copula from a multivariate normal.
        base_dist = MultivariateNormal(
            loc=torch.zeros(2),
            scale_tril=LKJCholesky(2).sample(),
        )
        transform = CumulativeDistributionTransform(Normal(0, 1))
        copula = TransformedDistribution(base_dist, [transform])
    """
    bijective = True
    codomain = constraints.unit_interval
    sign = +1

    def __init__(self, distribution, cache_size=0):
        super().__init__(cache_size=cache_size)
        self.distribution = distribution

    @property
    def domain(self):
        return self.distribution.support

    def _call(self, x):
        return self.distribution.cdf(x)

    def _inverse(self, y):
        return self.distribution.icdf(y)

    def log_abs_det_jacobian(self, x, y):
        return self.distribution.log_prob(x)

    def with_cache(self, cache_size=1):
        if self._cache_size == cache_size:
            return self
        return CumulativeDistributionTransform(self.distribution, cache_size=cache_size)