dataframe_wrapper.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. _pandas = None
  2. _WITH_PANDAS = None
  3. def _try_import_pandas() -> bool:
  4. try:
  5. import pandas # type: ignore[import]
  6. global _pandas
  7. _pandas = pandas
  8. return True
  9. except ImportError:
  10. return False
# pandas is used only for prototyping and will shortly be replaced with TorchArrow.
  12. def _with_pandas() -> bool:
  13. global _WITH_PANDAS
  14. if _WITH_PANDAS is None:
  15. _WITH_PANDAS = _try_import_pandas()
  16. return _WITH_PANDAS
  17. class PandasWrapper:
  18. @classmethod
  19. def create_dataframe(cls, data, columns):
  20. if not _with_pandas():
  21. raise Exception("DataFrames prototype requires pandas to function")
  22. return _pandas.DataFrame(data, columns=columns) # type: ignore[union-attr]
  23. @classmethod
  24. def is_dataframe(cls, data):
  25. if not _with_pandas():
  26. return False
  27. return isinstance(data, _pandas.core.frame.DataFrame) # type: ignore[union-attr]
  28. @classmethod
  29. def is_column(cls, data):
  30. if not _with_pandas():
  31. return False
  32. return isinstance(data, _pandas.core.series.Series) # type: ignore[union-attr]
  33. @classmethod
  34. def iterate(cls, data):
  35. if not _with_pandas():
  36. raise Exception("DataFrames prototype requires pandas to function")
  37. yield from data.itertuples(index=False)
  38. @classmethod
  39. def concat(cls, buffer):
  40. if not _with_pandas():
  41. raise Exception("DataFrames prototype requires pandas to function")
  42. return _pandas.concat(buffer) # type: ignore[union-attr]
  43. @classmethod
  44. def get_item(cls, data, idx):
  45. if not _with_pandas():
  46. raise Exception("DataFrames prototype requires pandas to function")
  47. return data[idx: idx + 1]
  48. @classmethod
  49. def get_len(cls, df):
  50. if not _with_pandas():
  51. raise Exception("DataFrames prototype requires pandas to function")
  52. return len(df.index)
  53. @classmethod
  54. def get_columns(cls, df):
  55. if not _with_pandas():
  56. raise Exception("DataFrames prototype requires pandas to function")
  57. return list(df.columns.values.tolist())
  58. # When you build own implementation just override it with dataframe_wrapper.set_df_wrapper(new_wrapper_class)
  59. default_wrapper = PandasWrapper
  60. def get_df_wrapper():
  61. return default_wrapper
  62. def set_df_wrapper(wrapper):
  63. global default_wrapper
  64. default_wrapper = wrapper
  65. def create_dataframe(data, columns=None):
  66. wrapper = get_df_wrapper()
  67. return wrapper.create_dataframe(data, columns)
  68. def is_dataframe(data):
  69. wrapper = get_df_wrapper()
  70. return wrapper.is_dataframe(data)
  71. def get_columns(data):
  72. wrapper = get_df_wrapper()
  73. return wrapper.get_columns(data)
  74. def is_column(data):
  75. wrapper = get_df_wrapper()
  76. return wrapper.is_column(data)
  77. def concat(buffer):
  78. wrapper = get_df_wrapper()
  79. return wrapper.concat(buffer)
  80. def iterate(data):
  81. wrapper = get_df_wrapper()
  82. return wrapper.iterate(data)
  83. def get_item(data, idx):
  84. wrapper = get_df_wrapper()
  85. return wrapper.get_item(data, idx)
  86. def get_len(df):
  87. wrapper = get_df_wrapper()
  88. return wrapper.get_len(df)