torch_version.py 3.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. from typing import Any, Iterable
  2. from .version import __version__ as internal_version
  3. __all__ = ['TorchVersion', 'Version', 'InvalidVersion']
  4. class _LazyImport:
  5. """Wraps around classes lazy imported from packaging.version
  6. Output of the function v in following snippets are identical:
  7. from packaging.version import Version
  8. def v():
  9. return Version('1.2.3')
  10. and
  11. Version = _LazyImport('Version')
  12. def v():
  13. return Version('1.2.3')
  14. The difference here is that in later example imports
  15. do not happen until v is called
  16. """
  17. def __init__(self, cls_name: str) -> None:
  18. self._cls_name = cls_name
  19. def get_cls(self):
  20. try:
  21. import packaging.version # type: ignore[import]
  22. except ImportError:
  23. # If packaging isn't installed, try and use the vendored copy
  24. # in pkg_resources
  25. from pkg_resources import packaging # type: ignore[attr-defined, no-redef]
  26. return getattr(packaging.version, self._cls_name)
  27. def __call__(self, *args, **kwargs):
  28. return self.get_cls()(*args, **kwargs)
  29. def __instancecheck__(self, obj):
  30. return isinstance(obj, self.get_cls())
  31. Version = _LazyImport("Version")
  32. InvalidVersion = _LazyImport("InvalidVersion")
  33. class TorchVersion(str):
  34. """A string with magic powers to compare to both Version and iterables!
  35. Prior to 1.10.0 torch.__version__ was stored as a str and so many did
  36. comparisons against torch.__version__ as if it were a str. In order to not
  37. break them we have TorchVersion which masquerades as a str while also
  38. having the ability to compare against both packaging.version.Version as
  39. well as tuples of values, eg. (1, 2, 1)
  40. Examples:
  41. Comparing a TorchVersion object to a Version object
  42. TorchVersion('1.10.0a') > Version('1.10.0a')
  43. Comparing a TorchVersion object to a Tuple object
  44. TorchVersion('1.10.0a') > (1, 2) # 1.2
  45. TorchVersion('1.10.0a') > (1, 2, 1) # 1.2.1
  46. Comparing a TorchVersion object against a string
  47. TorchVersion('1.10.0a') > '1.2'
  48. TorchVersion('1.10.0a') > '1.2.1'
  49. """
  50. # fully qualified type names here to appease mypy
  51. def _convert_to_version(self, inp: Any) -> Any:
  52. if isinstance(inp, Version.get_cls()):
  53. return inp
  54. elif isinstance(inp, str):
  55. return Version(inp)
  56. elif isinstance(inp, Iterable):
  57. # Ideally this should work for most cases by attempting to group
  58. # the version tuple, assuming the tuple looks (MAJOR, MINOR, ?PATCH)
  59. # Examples:
  60. # * (1) -> Version("1")
  61. # * (1, 20) -> Version("1.20")
  62. # * (1, 20, 1) -> Version("1.20.1")
  63. return Version('.'.join((str(item) for item in inp)))
  64. else:
  65. raise InvalidVersion(inp)
  66. def _cmp_wrapper(self, cmp: Any, method: str) -> bool:
  67. try:
  68. return getattr(Version(self), method)(self._convert_to_version(cmp))
  69. except BaseException as e:
  70. if not isinstance(e, InvalidVersion.get_cls()):
  71. raise
  72. # Fall back to regular string comparison if dealing with an invalid
  73. # version like 'parrot'
  74. return getattr(super(), method)(cmp)
  75. for cmp_method in ["__gt__", "__lt__", "__eq__", "__ge__", "__le__"]:
  76. setattr(TorchVersion, cmp_method, lambda x, y, method=cmp_method: x._cmp_wrapper(y, method))
  77. __version__ = TorchVersion(internal_version)