from typing import Any, Iterable

from .version import __version__ as internal_version

__all__ = ['TorchVersion', 'Version', 'InvalidVersion']


class _LazyImport:
    """Wraps around classes lazy imported from packaging.version

    Output of the function v in following snippets are identical:
       from packaging.version import Version
       def v():
           return Version('1.2.3')
    and
       Version = _LazyImport('Version')
       def v():
           return Version('1.2.3')
    The difference here is that in later example imports
    do not happen until v is called
    """

    def __init__(self, cls_name: str) -> None:
        # Name of the attribute to fetch from ``packaging.version``.
        self._cls_name = cls_name
        # Resolved class; filled in on the first successful ``get_cls`` call so
        # later calls skip the import machinery entirely.
        self._cls: Any = None

    def get_cls(self):
        """Return the class named ``cls_name`` from ``packaging.version``.

        The import happens on the first call only; the result is cached.

        Raises:
            ImportError: if neither ``packaging`` nor the copy exposed by
                ``pkg_resources`` can be imported.
        """
        if self._cls is None:
            try:
                import packaging.version  # type: ignore[import]
            except ImportError:
                # If packaging isn't installed, try and use the vendored copy
                # in pkg_resources
                from pkg_resources import packaging  # type: ignore[attr-defined, no-redef]
            self._cls = getattr(packaging.version, self._cls_name)
        return self._cls

    def __call__(self, *args, **kwargs):
        # Behave like the wrapped class's constructor.
        return self.get_cls()(*args, **kwargs)

    def __instancecheck__(self, obj):
        # Lets ``isinstance(obj, Version)`` work with this proxy as the class.
        return isinstance(obj, self.get_cls())


Version = _LazyImport("Version")
InvalidVersion = _LazyImport("InvalidVersion")


class TorchVersion(str):
    """A string with magic powers to compare to both Version and iterables!

    Prior to 1.10.0 torch.__version__ was stored as a str and so many did
    comparisons against torch.__version__ as if it were a str. In order to not
    break them we have TorchVersion which masquerades as a str while also
    having the ability to compare against both packaging.version.Version as
    well as tuples of values, eg. (1, 2, 1)

    All six rich comparisons (``==``, ``!=``, ``<``, ``<=``, ``>``, ``>=``)
    use version semantics when both sides parse as versions, and fall back to
    plain ``str`` comparison otherwise.

    Examples:
        Comparing a TorchVersion object to a Version object
            TorchVersion('1.10.0a') > Version('1.10.0a')
        Comparing a TorchVersion object to a Tuple object
            TorchVersion('1.10.0a') > (1, 2)    # 1.2
            TorchVersion('1.10.0a') > (1, 2, 1) # 1.2.1
        Comparing a TorchVersion object against a string
            TorchVersion('1.10.0a') > '1.2'
            TorchVersion('1.10.0a') > '1.2.1'
    """

    # Annotated with ``Any`` rather than packaging.version.Version --
    # presumably because packaging is only imported lazily here.
    def _convert_to_version(self, inp: Any) -> Any:
        """Coerce ``inp`` into a ``packaging.version.Version``.

        Args:
            inp: a Version (returned unchanged), a str (parsed), or any other
                iterable whose items are joined with '.' and parsed.

        Raises:
            InvalidVersion: if ``inp`` is none of the above, or if the
                resulting string is not a valid version.
        """
        if isinstance(inp, Version.get_cls()):
            return inp
        elif isinstance(inp, str):
            return Version(inp)
        elif isinstance(inp, Iterable):
            # Ideally this should work for most cases by attempting to group
            # the version tuple, assuming the tuple looks (MAJOR, MINOR, ?PATCH)
            # Examples:
            #   * (1,)        -> Version("1")
            #   * (1, 20)     -> Version("1.20")
            #   * (1, 20, 1)  -> Version("1.20.1")
            return Version('.'.join((str(item) for item in inp)))
        else:
            raise InvalidVersion(inp)

    def _cmp_wrapper(self, cmp: Any, method: str) -> bool:
        """Apply the comparison dunder ``method`` between ``self`` and ``cmp``.

        Both sides are converted to ``Version`` and compared with version
        semantics. If either side is not a valid version, the comparison falls
        back to the corresponding ``str`` method (which may return
        ``NotImplemented``). Any other exception propagates.
        """
        try:
            return getattr(Version(self), method)(self._convert_to_version(cmp))
        except Exception as e:
            if not isinstance(e, InvalidVersion.get_cls()):
                raise
            # Fall back to regular string comparison if dealing with an invalid
            # version like 'parrot'
            return getattr(super(), method)(cmp)


# The comparison dunders are installed after class creation on purpose.
# Defining __eq__ inside the class body would implicitly set __hash__ to None
# and make TorchVersion unhashable. Installing it with setattr keeps the
# inherited str.__hash__.
# __ne__ is included so that != goes through the same version-aware path as ==,
# instead of falling through to str.__ne__.
# The method name is bound as a lambda default argument to avoid the
# late-binding closure pitfall.
for cmp_method in ["__gt__", "__lt__", "__eq__", "__ne__", "__ge__", "__le__"]:
    setattr(TorchVersion, cmp_method,
            lambda x, y, method=cmp_method: x._cmp_wrapper(y, method))

__version__ = TorchVersion(internal_version)