metrics.py 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. from typing import Dict, List, Optional, Tuple
  2. from torch import Tensor
  3. AVAILABLE_METRICS = ["mae", "rmse", "epe", "bad1", "bad2", "epe", "1px", "3px", "5px", "fl-all", "relepe"]
  4. def compute_metrics(
  5. flow_pred: Tensor, flow_gt: Tensor, valid_flow_mask: Optional[Tensor], metrics: List[str]
  6. ) -> Tuple[Dict[str, float], int]:
  7. for m in metrics:
  8. if m not in AVAILABLE_METRICS:
  9. raise ValueError(f"Invalid metric: {m}. Valid metrics are: {AVAILABLE_METRICS}")
  10. metrics_dict = {}
  11. pixels_diffs = (flow_pred - flow_gt).abs()
  12. # there is no Y flow in Stereo Matching, therefore flow.abs() = flow.pow(2).sum(dim=1).sqrt()
  13. flow_norm = flow_gt.abs()
  14. if valid_flow_mask is not None:
  15. valid_flow_mask = valid_flow_mask.unsqueeze(1)
  16. pixels_diffs = pixels_diffs[valid_flow_mask]
  17. flow_norm = flow_norm[valid_flow_mask]
  18. num_pixels = pixels_diffs.numel()
  19. if "bad1" in metrics:
  20. metrics_dict["bad1"] = (pixels_diffs > 1).float().mean().item()
  21. if "bad2" in metrics:
  22. metrics_dict["bad2"] = (pixels_diffs > 2).float().mean().item()
  23. if "mae" in metrics:
  24. metrics_dict["mae"] = pixels_diffs.mean().item()
  25. if "rmse" in metrics:
  26. metrics_dict["rmse"] = pixels_diffs.pow(2).mean().sqrt().item()
  27. if "epe" in metrics:
  28. metrics_dict["epe"] = pixels_diffs.mean().item()
  29. if "1px" in metrics:
  30. metrics_dict["1px"] = (pixels_diffs < 1).float().mean().item()
  31. if "3px" in metrics:
  32. metrics_dict["3px"] = (pixels_diffs < 3).float().mean().item()
  33. if "5px" in metrics:
  34. metrics_dict["5px"] = (pixels_diffs < 5).float().mean().item()
  35. if "fl-all" in metrics:
  36. metrics_dict["fl-all"] = ((pixels_diffs < 3) & ((pixels_diffs / flow_norm) < 0.05)).float().mean().item() * 100
  37. if "relepe" in metrics:
  38. metrics_dict["relepe"] = (pixels_diffs / flow_norm).mean().item()
  39. return metrics_dict, num_pixels