focal_loss.py 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. import torch
  2. import torch.nn.functional as F
  3. from ..utils import _log_api_usage_once
  4. def sigmoid_focal_loss(
  5. inputs: torch.Tensor,
  6. targets: torch.Tensor,
  7. alpha: float = 0.25,
  8. gamma: float = 2,
  9. reduction: str = "none",
  10. ) -> torch.Tensor:
  11. """
  12. Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
  13. Args:
  14. inputs (Tensor): A float tensor of arbitrary shape.
  15. The predictions for each example.
  16. targets (Tensor): A float tensor with the same shape as inputs. Stores the binary
  17. classification label for each element in inputs
  18. (0 for the negative class and 1 for the positive class).
  19. alpha (float): Weighting factor in range (0,1) to balance
  20. positive vs negative examples or -1 for ignore. Default: ``0.25``.
  21. gamma (float): Exponent of the modulating factor (1 - p_t) to
  22. balance easy vs hard examples. Default: ``2``.
  23. reduction (string): ``'none'`` | ``'mean'`` | ``'sum'``
  24. ``'none'``: No reduction will be applied to the output.
  25. ``'mean'``: The output will be averaged.
  26. ``'sum'``: The output will be summed. Default: ``'none'``.
  27. Returns:
  28. Loss tensor with the reduction option applied.
  29. """
  30. # Original implementation from https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py
  31. if not torch.jit.is_scripting() and not torch.jit.is_tracing():
  32. _log_api_usage_once(sigmoid_focal_loss)
  33. p = torch.sigmoid(inputs)
  34. ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
  35. p_t = p * targets + (1 - p) * (1 - targets)
  36. loss = ce_loss * ((1 - p_t) ** gamma)
  37. if alpha >= 0:
  38. alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
  39. loss = alpha_t * loss
  40. # Check reduction option and return loss accordingly
  41. if reduction == "none":
  42. pass
  43. elif reduction == "mean":
  44. loss = loss.mean()
  45. elif reduction == "sum":
  46. loss = loss.sum()
  47. else:
  48. raise ValueError(
  49. f"Invalid Value for arg 'reduction': '{reduction} \n Supported reduction modes: 'none', 'mean', 'sum'"
  50. )
  51. return loss