# _utils.py (3.5 KB)
  1. from typing import List, Optional, Tuple, Union
  2. import torch
  3. from torch import nn, Tensor
  4. def _cat(tensors: List[Tensor], dim: int = 0) -> Tensor:
  5. """
  6. Efficient version of torch.cat that avoids a copy if there is only a single element in a list
  7. """
  8. # TODO add back the assert
  9. # assert isinstance(tensors, (list, tuple))
  10. if len(tensors) == 1:
  11. return tensors[0]
  12. return torch.cat(tensors, dim)
  13. def convert_boxes_to_roi_format(boxes: List[Tensor]) -> Tensor:
  14. concat_boxes = _cat([b for b in boxes], dim=0)
  15. temp = []
  16. for i, b in enumerate(boxes):
  17. temp.append(torch.full_like(b[:, :1], i))
  18. ids = _cat(temp, dim=0)
  19. rois = torch.cat([ids, concat_boxes], dim=1)
  20. return rois
  21. def check_roi_boxes_shape(boxes: Union[Tensor, List[Tensor]]):
  22. if isinstance(boxes, (list, tuple)):
  23. for _tensor in boxes:
  24. torch._assert(
  25. _tensor.size(1) == 4, "The shape of the tensor in the boxes list is not correct as List[Tensor[L, 4]]"
  26. )
  27. elif isinstance(boxes, torch.Tensor):
  28. torch._assert(boxes.size(1) == 5, "The boxes tensor shape is not correct as Tensor[K, 5]")
  29. else:
  30. torch._assert(False, "boxes is expected to be a Tensor[L, 5] or a List[Tensor[K, 4]]")
  31. return
  32. def split_normalization_params(
  33. model: nn.Module, norm_classes: Optional[List[type]] = None
  34. ) -> Tuple[List[Tensor], List[Tensor]]:
  35. # Adapted from https://github.com/facebookresearch/ClassyVision/blob/659d7f78/classy_vision/generic/util.py#L501
  36. if not norm_classes:
  37. norm_classes = [
  38. nn.modules.batchnorm._BatchNorm,
  39. nn.LayerNorm,
  40. nn.GroupNorm,
  41. nn.modules.instancenorm._InstanceNorm,
  42. nn.LocalResponseNorm,
  43. ]
  44. for t in norm_classes:
  45. if not issubclass(t, nn.Module):
  46. raise ValueError(f"Class {t} is not a subclass of nn.Module.")
  47. classes = tuple(norm_classes)
  48. norm_params = []
  49. other_params = []
  50. for module in model.modules():
  51. if next(module.children(), None):
  52. other_params.extend(p for p in module.parameters(recurse=False) if p.requires_grad)
  53. elif isinstance(module, classes):
  54. norm_params.extend(p for p in module.parameters() if p.requires_grad)
  55. else:
  56. other_params.extend(p for p in module.parameters() if p.requires_grad)
  57. return norm_params, other_params
  58. def _upcast(t: Tensor) -> Tensor:
  59. # Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
  60. if t.is_floating_point():
  61. return t if t.dtype in (torch.float32, torch.float64) else t.float()
  62. else:
  63. return t if t.dtype in (torch.int32, torch.int64) else t.int()
  64. def _upcast_non_float(t: Tensor) -> Tensor:
  65. # Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
  66. if t.dtype not in (torch.float32, torch.float64):
  67. return t.float()
  68. return t
  69. def _loss_inter_union(
  70. boxes1: torch.Tensor,
  71. boxes2: torch.Tensor,
  72. ) -> Tuple[torch.Tensor, torch.Tensor]:
  73. x1, y1, x2, y2 = boxes1.unbind(dim=-1)
  74. x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1)
  75. # Intersection keypoints
  76. xkis1 = torch.max(x1, x1g)
  77. ykis1 = torch.max(y1, y1g)
  78. xkis2 = torch.min(x2, x2g)
  79. ykis2 = torch.min(y2, y2g)
  80. intsctk = torch.zeros_like(x1)
  81. mask = (ykis2 > ykis1) & (xkis2 > xkis1)
  82. intsctk[mask] = (xkis2[mask] - xkis1[mask]) * (ykis2[mask] - ykis1[mask])
  83. unionk = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsctk
  84. return intsctk, unionk