context.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288
  1. import os
  2. import subprocess
  3. import contextlib
  4. import functools
  5. import tempfile
  6. import shutil
  7. import operator
  8. import warnings
  9. @contextlib.contextmanager
  10. def pushd(dir):
  11. """
  12. >>> tmp_path = getfixture('tmp_path')
  13. >>> with pushd(tmp_path):
  14. ... assert os.getcwd() == os.fspath(tmp_path)
  15. >>> assert os.getcwd() != os.fspath(tmp_path)
  16. """
  17. orig = os.getcwd()
  18. os.chdir(dir)
  19. try:
  20. yield dir
  21. finally:
  22. os.chdir(orig)
  23. @contextlib.contextmanager
  24. def tarball_context(url, target_dir=None, runner=None, pushd=pushd):
  25. """
  26. Get a tarball, extract it, change to that directory, yield, then
  27. clean up.
  28. `runner` is the function to invoke commands.
  29. `pushd` is a context manager for changing the directory.
  30. """
  31. if target_dir is None:
  32. target_dir = os.path.basename(url).replace('.tar.gz', '').replace('.tgz', '')
  33. if runner is None:
  34. runner = functools.partial(subprocess.check_call, shell=True)
  35. else:
  36. warnings.warn("runner parameter is deprecated", DeprecationWarning)
  37. # In the tar command, use --strip-components=1 to strip the first path and
  38. # then
  39. # use -C to cause the files to be extracted to {target_dir}. This ensures
  40. # that we always know where the files were extracted.
  41. runner('mkdir {target_dir}'.format(**vars()))
  42. try:
  43. getter = 'wget {url} -O -'
  44. extract = 'tar x{compression} --strip-components=1 -C {target_dir}'
  45. cmd = ' | '.join((getter, extract))
  46. runner(cmd.format(compression=infer_compression(url), **vars()))
  47. with pushd(target_dir):
  48. yield target_dir
  49. finally:
  50. runner('rm -Rf {target_dir}'.format(**vars()))
  51. def infer_compression(url):
  52. """
  53. Given a URL or filename, infer the compression code for tar.
  54. >>> infer_compression('http://foo/bar.tar.gz')
  55. 'z'
  56. >>> infer_compression('http://foo/bar.tgz')
  57. 'z'
  58. >>> infer_compression('file.bz')
  59. 'j'
  60. >>> infer_compression('file.xz')
  61. 'J'
  62. """
  63. # cheat and just assume it's the last two characters
  64. compression_indicator = url[-2:]
  65. mapping = dict(gz='z', bz='j', xz='J')
  66. # Assume 'z' (gzip) if no match
  67. return mapping.get(compression_indicator, 'z')
  68. @contextlib.contextmanager
  69. def temp_dir(remover=shutil.rmtree):
  70. """
  71. Create a temporary directory context. Pass a custom remover
  72. to override the removal behavior.
  73. >>> import pathlib
  74. >>> with temp_dir() as the_dir:
  75. ... assert os.path.isdir(the_dir)
  76. ... _ = pathlib.Path(the_dir).joinpath('somefile').write_text('contents')
  77. >>> assert not os.path.exists(the_dir)
  78. """
  79. temp_dir = tempfile.mkdtemp()
  80. try:
  81. yield temp_dir
  82. finally:
  83. remover(temp_dir)
  84. @contextlib.contextmanager
  85. def repo_context(url, branch=None, quiet=True, dest_ctx=temp_dir):
  86. """
  87. Check out the repo indicated by url.
  88. If dest_ctx is supplied, it should be a context manager
  89. to yield the target directory for the check out.
  90. """
  91. exe = 'git' if 'git' in url else 'hg'
  92. with dest_ctx() as repo_dir:
  93. cmd = [exe, 'clone', url, repo_dir]
  94. if branch:
  95. cmd.extend(['--branch', branch])
  96. devnull = open(os.path.devnull, 'w')
  97. stdout = devnull if quiet else None
  98. subprocess.check_call(cmd, stdout=stdout)
  99. yield repo_dir
  100. @contextlib.contextmanager
  101. def null():
  102. """
  103. A null context suitable to stand in for a meaningful context.
  104. >>> with null() as value:
  105. ... assert value is None
  106. """
  107. yield
  108. class ExceptionTrap:
  109. """
  110. A context manager that will catch certain exceptions and provide an
  111. indication they occurred.
  112. >>> with ExceptionTrap() as trap:
  113. ... raise Exception()
  114. >>> bool(trap)
  115. True
  116. >>> with ExceptionTrap() as trap:
  117. ... pass
  118. >>> bool(trap)
  119. False
  120. >>> with ExceptionTrap(ValueError) as trap:
  121. ... raise ValueError("1 + 1 is not 3")
  122. >>> bool(trap)
  123. True
  124. >>> trap.value
  125. ValueError('1 + 1 is not 3')
  126. >>> trap.tb
  127. <traceback object at ...>
  128. >>> with ExceptionTrap(ValueError) as trap:
  129. ... raise Exception()
  130. Traceback (most recent call last):
  131. ...
  132. Exception
  133. >>> bool(trap)
  134. False
  135. """
  136. exc_info = None, None, None
  137. def __init__(self, exceptions=(Exception,)):
  138. self.exceptions = exceptions
  139. def __enter__(self):
  140. return self
  141. @property
  142. def type(self):
  143. return self.exc_info[0]
  144. @property
  145. def value(self):
  146. return self.exc_info[1]
  147. @property
  148. def tb(self):
  149. return self.exc_info[2]
  150. def __exit__(self, *exc_info):
  151. type = exc_info[0]
  152. matches = type and issubclass(type, self.exceptions)
  153. if matches:
  154. self.exc_info = exc_info
  155. return matches
  156. def __bool__(self):
  157. return bool(self.type)
  158. def raises(self, func, *, _test=bool):
  159. """
  160. Wrap func and replace the result with the truth
  161. value of the trap (True if an exception occurred).
  162. First, give the decorator an alias to support Python 3.8
  163. Syntax.
  164. >>> raises = ExceptionTrap(ValueError).raises
  165. Now decorate a function that always fails.
  166. >>> @raises
  167. ... def fail():
  168. ... raise ValueError('failed')
  169. >>> fail()
  170. True
  171. """
  172. @functools.wraps(func)
  173. def wrapper(*args, **kwargs):
  174. with ExceptionTrap(self.exceptions) as trap:
  175. func(*args, **kwargs)
  176. return _test(trap)
  177. return wrapper
  178. def passes(self, func):
  179. """
  180. Wrap func and replace the result with the truth
  181. value of the trap (True if no exception).
  182. First, give the decorator an alias to support Python 3.8
  183. Syntax.
  184. >>> passes = ExceptionTrap(ValueError).passes
  185. Now decorate a function that always fails.
  186. >>> @passes
  187. ... def fail():
  188. ... raise ValueError('failed')
  189. >>> fail()
  190. False
  191. """
  192. return self.raises(func, _test=operator.not_)
  193. class suppress(contextlib.suppress, contextlib.ContextDecorator):
  194. """
  195. A version of contextlib.suppress with decorator support.
  196. >>> @suppress(KeyError)
  197. ... def key_error():
  198. ... {}['']
  199. >>> key_error()
  200. """
  201. class on_interrupt(contextlib.ContextDecorator):
  202. """
  203. Replace a KeyboardInterrupt with SystemExit(1)
  204. >>> def do_interrupt():
  205. ... raise KeyboardInterrupt()
  206. >>> on_interrupt('error')(do_interrupt)()
  207. Traceback (most recent call last):
  208. ...
  209. SystemExit: 1
  210. >>> on_interrupt('error', code=255)(do_interrupt)()
  211. Traceback (most recent call last):
  212. ...
  213. SystemExit: 255
  214. >>> on_interrupt('suppress')(do_interrupt)()
  215. >>> with __import__('pytest').raises(KeyboardInterrupt):
  216. ... on_interrupt('ignore')(do_interrupt)()
  217. """
  218. def __init__(
  219. self,
  220. action='error',
  221. # py3.7 compat
  222. # /,
  223. code=1,
  224. ):
  225. self.action = action
  226. self.code = code
  227. def __enter__(self):
  228. return self
  229. def __exit__(self, exctype, excinst, exctb):
  230. if exctype is not KeyboardInterrupt or self.action == 'ignore':
  231. return
  232. elif self.action == 'error':
  233. raise SystemExit(self.code) from excinst
  234. return self.action == 'suppress'