test_data.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. from scipy.datasets._registry import registry
  2. from scipy.datasets._fetchers import data_fetcher
  3. from scipy.datasets._utils import _clear_cache
  4. from scipy.datasets import ascent, face, electrocardiogram, download_all
  5. from numpy.testing import assert_equal, assert_almost_equal
  6. import os
  7. import pytest
  8. try:
  9. import pooch
  10. except ImportError:
  11. raise ImportError("Missing optional dependency 'pooch' required "
  12. "for scipy.datasets module. Please use pip or "
  13. "conda to install 'pooch'.")
# Directory used by the dataset fetcher; the tests below look up the
# downloaded ``*.dat`` files here (presumably pooch's cache path — confirm).
data_dir = data_fetcher.path # type: ignore
  15. def _has_hash(path, expected_hash):
  16. """Check if the provided path has the expected hash."""
  17. if not os.path.exists(path):
  18. return False
  19. return pooch.file_hash(path) == expected_hash
  20. class TestDatasets:
  21. @pytest.fixture(scope='module', autouse=True)
  22. def test_download_all(self):
  23. # This fixture requires INTERNET CONNECTION
  24. # test_setup phase
  25. download_all()
  26. yield
  27. def test_existence_all(self):
  28. assert len(os.listdir(data_dir)) >= len(registry)
  29. def test_ascent(self):
  30. assert_equal(ascent().shape, (512, 512))
  31. # hash check
  32. assert _has_hash(os.path.join(data_dir, "ascent.dat"),
  33. registry["ascent.dat"])
  34. def test_face(self):
  35. assert_equal(face().shape, (768, 1024, 3))
  36. # hash check
  37. assert _has_hash(os.path.join(data_dir, "face.dat"),
  38. registry["face.dat"])
  39. def test_electrocardiogram(self):
  40. # Test shape, dtype and stats of signal
  41. ecg = electrocardiogram()
  42. assert_equal(ecg.dtype, float)
  43. assert_equal(ecg.shape, (108000,))
  44. assert_almost_equal(ecg.mean(), -0.16510875)
  45. assert_almost_equal(ecg.std(), 0.5992473991177294)
  46. # hash check
  47. assert _has_hash(os.path.join(data_dir, "ecg.dat"),
  48. registry["ecg.dat"])
  49. def test_clear_cache(tmp_path):
  50. # Note: `tmp_path` is a pytest fixture, it handles cleanup
  51. dummy_basepath = tmp_path / "dummy_cache_dir"
  52. dummy_basepath.mkdir()
  53. # Create three dummy dataset files for dummy dataset methods
  54. dummy_method_map = {}
  55. for i in range(4):
  56. dummy_method_map[f"data{i}"] = [f"data{i}.dat"]
  57. data_filepath = dummy_basepath / f"data{i}.dat"
  58. data_filepath.write_text("")
  59. # clear files associated to single dataset method data0
  60. # also test callable argument instead of list of callables
  61. def data0():
  62. pass
  63. _clear_cache(datasets=data0, cache_dir=dummy_basepath,
  64. method_map=dummy_method_map)
  65. assert not os.path.exists(dummy_basepath/"data0.dat")
  66. # clear files associated to multiple dataset methods "data3" and "data4"
  67. def data1():
  68. pass
  69. def data2():
  70. pass
  71. _clear_cache(datasets=[data1, data2], cache_dir=dummy_basepath,
  72. method_map=dummy_method_map)
  73. assert not os.path.exists(dummy_basepath/"data1.dat")
  74. assert not os.path.exists(dummy_basepath/"data2.dat")
  75. # clear multiple dataset files "data3_0.dat" and "data3_1.dat"
  76. # associated with dataset method "data3"
  77. def data4():
  78. pass
  79. # create files
  80. (dummy_basepath / "data4_0.dat").write_text("")
  81. (dummy_basepath / "data4_1.dat").write_text("")
  82. dummy_method_map["data4"] = ["data4_0.dat", "data4_1.dat"]
  83. _clear_cache(datasets=[data4], cache_dir=dummy_basepath,
  84. method_map=dummy_method_map)
  85. assert not os.path.exists(dummy_basepath/"data4_0.dat")
  86. assert not os.path.exists(dummy_basepath/"data4_1.dat")
  87. # wrong dataset method should raise ValueError since it
  88. # doesn't exist in the dummy_method_map
  89. def data5():
  90. pass
  91. with pytest.raises(ValueError):
  92. _clear_cache(datasets=[data5], cache_dir=dummy_basepath,
  93. method_map=dummy_method_map)
  94. # remove all dataset cache
  95. _clear_cache(datasets=None, cache_dir=dummy_basepath)
  96. assert not os.path.exists(dummy_basepath)