123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106 |
- from typing import List, Optional, Tuple, Union
- import torch
- from torch import nn, Tensor
- def _cat(tensors: List[Tensor], dim: int = 0) -> Tensor:
- """
- Efficient version of torch.cat that avoids a copy if there is only a single element in a list
- """
- # TODO add back the assert
- # assert isinstance(tensors, (list, tuple))
- if len(tensors) == 1:
- return tensors[0]
- return torch.cat(tensors, dim)
def convert_boxes_to_roi_format(boxes: List[Tensor]) -> Tensor:
    """
    Stack a list of per-entry box tensors into one tensor with a leading index column.

    Each tensor in ``boxes`` must be 2-D (it is sliced as ``b[:, :1]``).
    The rows of all tensors are concatenated along dim 0, and a first column
    is prepended holding, for every row, the position in ``boxes`` of the
    tensor that row came from. The index column takes its dtype and device
    from the box tensors (it is built with ``torch.full_like``).

    NOTE(review): callers presumably pass one ``Tensor[L, 4]`` per image, giving
    a ``Tensor[K, 5]`` result, but the column count is not checked here.
    """
    # Concatenate the coordinates first so shape errors surface before the
    # index columns are built.
    stacked_boxes = _cat([per_entry for per_entry in boxes], dim=0)

    index_columns = []
    for position, per_entry in enumerate(boxes):
        # One column, same number of rows as this entry, filled with its position.
        index_columns.append(torch.full_like(per_entry[:, :1], position))
    stacked_ids = _cat(index_columns, dim=0)

    return torch.cat([stacked_ids, stacked_boxes], dim=1)
def check_roi_boxes_shape(boxes: Union[Tensor, List[Tensor]]):
    """
    Validate the second dimension of ``boxes`` via ``torch._assert``.

    Accepted inputs:
      * a single tensor whose ``size(1)`` is 5;
      * a list or tuple of tensors, each with ``size(1)`` equal to 4.

    Any other type fails the assertion unconditionally. Only dimension 1 is
    inspected; the number of rows is not constrained. Returns ``None``.
    """
    if isinstance(boxes, torch.Tensor):
        torch._assert(boxes.size(1) == 5, "The boxes tensor shape is not correct as Tensor[K, 5]")
    elif isinstance(boxes, (list, tuple)):
        for per_entry in boxes:
            torch._assert(
                per_entry.size(1) == 4, "The shape of the tensor in the boxes list is not correct as List[Tensor[L, 4]]"
            )
    else:
        torch._assert(False, "boxes is expected to be a Tensor[L, 5] or a List[Tensor[K, 4]]")
def split_normalization_params(
    model: nn.Module, norm_classes: Optional[List[type]] = None
) -> Tuple[List[Tensor], List[Tensor]]:
    """
    Split the trainable parameters of ``model`` into normalization and other parameters.

    Every module returned by ``model.modules()`` is visited once:

      * a module that has at least one child contributes only its *direct*
        parameters (``recurse=False``), and they go to the "other" group;
      * a childless module that is an instance of one of ``norm_classes``
        contributes its parameters to the "norm" group;
      * any other childless module contributes to the "other" group.

    Parameters with ``requires_grad == False`` are skipped.

    Args:
        model: module whose parameters are split.
        norm_classes: module classes treated as normalization layers. When
            ``None`` or empty, a default list of batch/layer/group/instance/
            local-response norm classes is used.

    Returns:
        ``(norm_params, other_params)``.

    Raises:
        ValueError: if an entry of ``norm_classes`` is not a subclass of ``nn.Module``.
    """
    # Adapted from https://github.com/facebookresearch/ClassyVision/blob/659d7f78/classy_vision/generic/util.py#L501
    if not norm_classes:
        norm_classes = [
            nn.modules.batchnorm._BatchNorm,
            nn.LayerNorm,
            nn.GroupNorm,
            nn.modules.instancenorm._InstanceNorm,
            nn.LocalResponseNorm,
        ]

    for t in norm_classes:
        if not issubclass(t, nn.Module):
            raise ValueError(f"Class {t} is not a subclass of nn.Module.")

    classes = tuple(norm_classes)

    norm_params = []
    other_params = []
    for module in model.modules():
        # Compare against None explicitly: containers such as an empty
        # nn.Sequential / nn.ModuleList define __len__ and are falsy, so a
        # truthiness test would misclassify their parent as a leaf and collect
        # the descendants' parameters twice.
        has_children = next(module.children(), None) is not None
        if has_children:
            other_params.extend(p for p in module.parameters(recurse=False) if p.requires_grad)
        elif isinstance(module, classes):
            norm_params.extend(p for p in module.parameters() if p.requires_grad)
        else:
            other_params.extend(p for p in module.parameters() if p.requires_grad)
    return norm_params, other_params
- def _upcast(t: Tensor) -> Tensor:
- # Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
- if t.is_floating_point():
- return t if t.dtype in (torch.float32, torch.float64) else t.float()
- else:
- return t if t.dtype in (torch.int32, torch.int64) else t.int()
- def _upcast_non_float(t: Tensor) -> Tensor:
- # Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
- if t.dtype not in (torch.float32, torch.float64):
- return t.float()
- return t
- def _loss_inter_union(
- boxes1: torch.Tensor,
- boxes2: torch.Tensor,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- x1, y1, x2, y2 = boxes1.unbind(dim=-1)
- x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1)
- # Intersection keypoints
- xkis1 = torch.max(x1, x1g)
- ykis1 = torch.max(y1, y1g)
- xkis2 = torch.min(x2, x2g)
- ykis2 = torch.min(y2, y2g)
- intsctk = torch.zeros_like(x1)
- mask = (ykis2 > ykis1) & (xkis2 > xkis1)
- intsctk[mask] = (xkis2[mask] - xkis1[mask]) * (ykis2[mask] - ykis1[mask])
- unionk = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsctk
- return intsctk, unionk
|