test_case.py 1.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. import contextlib
  2. import importlib
  3. import sys
  4. import torch
  5. import torch.testing
  6. from torch.testing._internal.common_utils import (
  7. IS_WINDOWS,
  8. TEST_WITH_CROSSREF,
  9. TEST_WITH_ROCM,
  10. TEST_WITH_TORCHDYNAMO,
  11. TestCase as TorchTestCase,
  12. )
  13. from . import config, reset, utils
  def run_tests(needs=()):
      """Run this module's tests unless the environment or a requirement says to skip.

      Args:
          needs: A single requirement name or an iterable of them. The name
              ``"cuda"`` requires ``torch.cuda.is_available()`` to be true; any
              other name is treated as a module that must be importable via
              ``importlib.import_module``. If any requirement is unmet, this
              returns without running anything.

      Tests are also skipped entirely when running under TorchDynamo, on
      Windows, in cross-ref or ROCm test modes, or on Python 3.11+.
      """
      # Aliased so the imported runner does not shadow this function's name.
      from torch.testing._internal.common_utils import run_tests as _run_tests

      if (
          TEST_WITH_TORCHDYNAMO
          or IS_WINDOWS
          or TEST_WITH_CROSSREF
          or TEST_WITH_ROCM
          or sys.version_info >= (3, 11)
      ):
          return  # skip testing

      # Accept a bare string as a single requirement rather than iterating
      # over its characters.
      if isinstance(needs, str):
          needs = (needs,)
      for need in needs:
          if need == "cuda":
              # "cuda" is a capability check, not a module name: when CUDA is
              # available we must NOT fall through to importing a module
              # literally called "cuda" (which normally does not exist).
              if not torch.cuda.is_available():
                  return
          else:
              try:
                  importlib.import_module(need)
              except ImportError:
                  return
      _run_tests()
  class TestCase(TorchTestCase):
      """Base test case that pins a strict config and isolates per-test state.

      For the lifetime of the class, ``config.patch(raise_on_ctx_manager_usage=True,
      suppress_errors=False)`` is entered in ``setUpClass`` and unwound in
      ``tearDownClass``. Before each test, ``reset()`` is called and
      ``utils.counters`` is cleared. After each test, the counters are printed
      before being reset and cleared again.
      """

      @classmethod
      def tearDownClass(cls):
          """Unwind the class-level config patch, then run the base teardown."""
          try:
              cls._exit_stack.close()
          finally:
              # Run the base-class teardown even if closing the stack raised.
              super().tearDownClass()

      @classmethod
      def setUpClass(cls):
          """Run the base setup, then enter the class-wide config patch."""
          super().setUpClass()
          # Everything entered on this stack is unwound by tearDownClass.
          cls._exit_stack = contextlib.ExitStack()
          cls._exit_stack.enter_context(
              config.patch(raise_on_ctx_manager_usage=True, suppress_errors=False),
          )

      def setUp(self):
          """Start each test from a fresh ``reset()`` and empty counters."""
          super().setUp()
          reset()
          utils.counters.clear()

      def tearDown(self):
          """Print the counters, reset state, then run the base teardown."""
          try:
              # NOTE(review): values expose most_common(), so they are presumably
              # collections.Counter instances — confirm against utils.counters.
              for name, counter in utils.counters.items():
                  print(name, counter.most_common())
              reset()
              utils.counters.clear()
          finally:
              # Always run the base teardown so its own cleanup is not skipped
              # when the dump or reset above raises.
              super().tearDown()