from scipy.datasets._registry import registry
from scipy.datasets._fetchers import data_fetcher
from scipy.datasets._utils import _clear_cache
from scipy.datasets import ascent, face, electrocardiogram, download_all
from numpy.testing import assert_equal, assert_almost_equal
import os
import pytest

try:
    import pooch
except ImportError:
    raise ImportError("Missing optional dependency 'pooch' required "
                      "for scipy.datasets module. Please use pip or "
                      "conda to install 'pooch'.")


data_dir = data_fetcher.path  # type: ignore


def _has_hash(path, expected_hash):
    """Check if the provided path has the expected hash."""
    if not os.path.exists(path):
        return False
    return pooch.file_hash(path) == expected_hash


class TestDatasets:

    @pytest.fixture(scope='module', autouse=True)
    def test_download_all(self):
        # This fixture requires INTERNET CONNECTION

        # test_setup phase
        download_all()

        yield
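    # With ``scope='module'`` and ``autouse=True``, pytest runs this fixture
    # (and hence ``download_all()``) once before any test in this class, so
    # the tests below can assume the files are present in ``data_dir``.
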
    def test_existence_all(self):
        assert len(os.listdir(data_dir)) >= len(registry)

    def test_ascent(self):
        assert_equal(ascent().shape, (512, 512))

        # hash check
        assert _has_hash(os.path.join(data_dir, "ascent.dat"),
                         registry["ascent.dat"])

    def test_face(self):
        assert_equal(face().shape, (768, 1024, 3))

        # hash check
        assert _has_hash(os.path.join(data_dir, "face.dat"),
                         registry["face.dat"])

    def test_electrocardiogram(self):
        # Test shape, dtype and stats of signal
        ecg = electrocardiogram()
        assert_equal(ecg.dtype, float)
        assert_equal(ecg.shape, (108000,))
        assert_almost_equal(ecg.mean(), -0.16510875)
        assert_almost_equal(ecg.std(), 0.5992473991177294)

        # hash check
        assert _has_hash(os.path.join(data_dir, "ecg.dat"),
                         registry["ecg.dat"])


def test_clear_cache(tmp_path):
    # Note: `tmp_path` is a pytest fixture; it handles cleanup
    dummy_basepath = tmp_path / "dummy_cache_dir"
    dummy_basepath.mkdir()

    # Create four dummy dataset files for the dummy dataset methods
    dummy_method_map = {}
    for i in range(4):
        dummy_method_map[f"data{i}"] = [f"data{i}.dat"]
        data_filepath = dummy_basepath / f"data{i}.dat"
        data_filepath.write_text("")
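    # ``dummy_method_map`` maps each dataset method name to the list of cache
    # filenames it owns; ``_clear_cache`` only removes the files mapped to the
    # requested methods.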
    # clear the file associated with the single dataset method data0;
    # also tests passing a callable instead of a list of callables
    def data0():
        pass
    _clear_cache(datasets=data0, cache_dir=dummy_basepath,
                 method_map=dummy_method_map)
    assert not os.path.exists(dummy_basepath/"data0.dat")
    # clear files associated with the multiple dataset methods
    # "data1" and "data2"
    def data1():
        pass

    def data2():
        pass
    _clear_cache(datasets=[data1, data2], cache_dir=dummy_basepath,
                 method_map=dummy_method_map)
    assert not os.path.exists(dummy_basepath/"data1.dat")
    assert not os.path.exists(dummy_basepath/"data2.dat")
    # clear multiple dataset files "data4_0.dat" and "data4_1.dat"
    # associated with the dataset method "data4"
    def data4():
        pass
    # create files
    (dummy_basepath / "data4_0.dat").write_text("")
    (dummy_basepath / "data4_1.dat").write_text("")

    dummy_method_map["data4"] = ["data4_0.dat", "data4_1.dat"]
    _clear_cache(datasets=[data4], cache_dir=dummy_basepath,
                 method_map=dummy_method_map)
    assert not os.path.exists(dummy_basepath/"data4_0.dat")
    assert not os.path.exists(dummy_basepath/"data4_1.dat")
    # a wrong dataset method should raise a ValueError since it
    # doesn't exist in dummy_method_map
    def data5():
        pass
    with pytest.raises(ValueError):
        _clear_cache(datasets=[data5], cache_dir=dummy_basepath,
                     method_map=dummy_method_map)

    # remove the entire dataset cache
    _clear_cache(datasets=None, cache_dir=dummy_basepath)
    assert not os.path.exists(dummy_basepath)
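
# A minimal usage sketch (not exercised by the tests above): end users clear
# the cache through the public wrapper rather than the private ``_clear_cache``
# helper, e.g.
#
#     import scipy.datasets
#     scipy.datasets.clear_cache()                         # remove all cached files
#     scipy.datasets.clear_cache([scipy.datasets.ascent])  # remove one dataset's files
#
# ``clear_cache`` accepts ``None``, a single dataset callable, or a list of
# callables, mirroring the ``datasets`` argument tested here.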