12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667 |
- import contextlib
- import importlib
- import sys
- import torch
- import torch.testing
- from torch.testing._internal.common_utils import (
- IS_WINDOWS,
- TEST_WITH_CROSSREF,
- TEST_WITH_ROCM,
- TEST_WITH_TORCHDYNAMO,
- TestCase as TorchTestCase,
- )
- from . import config, reset, utils
- def run_tests(needs=()):
- from torch.testing._internal.common_utils import run_tests
- if (
- TEST_WITH_TORCHDYNAMO
- or IS_WINDOWS
- or TEST_WITH_CROSSREF
- or TEST_WITH_ROCM
- or sys.version_info >= (3, 11)
- ):
- return # skip testing
- if isinstance(needs, str):
- needs = (needs,)
- for need in needs:
- if need == "cuda" and not torch.cuda.is_available():
- return
- else:
- try:
- importlib.import_module(need)
- except ImportError:
- return
- run_tests()
class TestCase(TorchTestCase):
    """Test base class that pins config for the class and resets state per test.

    For the lifetime of the class, ``config.patch`` is entered with
    ``raise_on_ctx_manager_usage=True`` and ``suppress_errors=False``.
    Around every test, ``reset()`` is called and ``utils.counters`` is
    cleared; the counters are printed before being cleared at teardown.
    """

    @classmethod
    def setUpClass(cls):
        """Enter the class-wide config patch on a fresh ``ExitStack``."""
        super().setUpClass()
        stack = contextlib.ExitStack()
        # Stored on the class so tearDownClass can unwind it.
        cls._exit_stack = stack
        patched_config = config.patch(
            raise_on_ctx_manager_usage=True,
            suppress_errors=False,
        )
        stack.enter_context(patched_config)

    @classmethod
    def tearDownClass(cls):
        """Unwind the class-wide config patch, then run the parent teardown."""
        stack = cls._exit_stack
        stack.close()
        super().tearDownClass()

    def setUp(self):
        """Start each test from a reset state with empty counters."""
        super().setUp()
        reset()
        counters = utils.counters
        counters.clear()

    def tearDown(self):
        """Print the collected counters, then reset state and clear them."""
        counters = utils.counters
        for name, counter in counters.items():
            print(name, counter.most_common())
        reset()
        counters.clear()
        super().tearDown()
|