patches.py 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. # Ultralytics YOLO 🚀, AGPL-3.0 license
  2. """
  3. Monkey patches to update/extend functionality of existing functions
  4. """
  5. from pathlib import Path
  6. import cv2
  7. import numpy as np
  8. import torch
  9. # OpenCV Multilanguage-friendly functions ------------------------------------------------------------------------------
# Keep a handle on OpenCV's own imshow so the imshow() wrapper defined in this module can call it.
_imshow = cv2.imshow  # copy to avoid recursion errors
  11. def imread(filename: str, flags: int = cv2.IMREAD_COLOR):
  12. """Read an image from a file.
  13. Args:
  14. filename (str): Path to the file to read.
  15. flags (int, optional): Flag that can take values of cv2.IMREAD_*. Defaults to cv2.IMREAD_COLOR.
  16. Returns:
  17. (np.ndarray): The read image.
  18. """
  19. return cv2.imdecode(np.fromfile(filename, np.uint8), flags)
  20. def imwrite(filename: str, img: np.ndarray, params=None):
  21. """Write an image to a file.
  22. Args:
  23. filename (str): Path to the file to write.
  24. img (np.ndarray): Image to write.
  25. params (list of ints, optional): Additional parameters. See OpenCV documentation.
  26. Returns:
  27. (bool): True if the file was written, False otherwise.
  28. """
  29. try:
  30. cv2.imencode(Path(filename).suffix, img, params)[1].tofile(filename)
  31. return True
  32. except Exception:
  33. return False
  34. def imshow(winname: str, mat: np.ndarray):
  35. """Displays an image in the specified window.
  36. Args:
  37. winname (str): Name of the window.
  38. mat (np.ndarray): Image to be shown.
  39. """
  40. _imshow(winname.encode('unicode_escape').decode(), mat)
  41. # PyTorch functions ----------------------------------------------------------------------------------------------------
  42. _torch_save = torch.save # copy to avoid recursion errors
  43. def torch_save(*args, **kwargs):
  44. """Use dill (if exists) to serialize the lambda functions where pickle does not do this.
  45. Args:
  46. *args (tuple): Positional arguments to pass to torch.save.
  47. **kwargs (dict): Keyword arguments to pass to torch.save.
  48. """
  49. try:
  50. import dill as pickle # noqa
  51. except ImportError:
  52. import pickle
  53. if 'pickle_module' not in kwargs:
  54. kwargs['pickle_module'] = pickle # noqa
  55. return _torch_save(*args, **kwargs)