utils.py

# Ultralytics YOLO 🚀, AGPL-3.0 license
"""Module utils."""

import copy
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import uniform_

__all__ = 'multi_scale_deformable_attn_pytorch', 'inverse_sigmoid'

def _get_clones(module, n):
    """Create a ModuleList of n independent deep copies of the given module."""
    return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])

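# Example (illustrative; not part of the original module): replicate a layer n
# times, e.g. for stacked decoder layers. copy.deepcopy gives each clone its
# own, unshared parameters. The layer size here is arbitrary.
#     >>> layers = _get_clones(nn.Linear(256, 256), 6)
#     >>> layers[0].weight is layers[5].weight
#     False
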
def bias_init_with_prob(prior_prob=0.01):
    """Initialize conv/fc bias value according to a given probability value."""
    return float(-np.log((1 - prior_prob) / prior_prob))  # return bias_init

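# Example (illustrative; not part of the original module): with the default
# prior of 0.01, the returned bias is -log(0.99 / 0.01) = -log(99) ≈ -4.595,
# so a sigmoid over the freshly initialized logits starts near the 1% prior.
#     >>> b = bias_init_with_prob(0.01)   # ≈ -4.595
#     >>> torch.sigmoid(torch.tensor(b))  # ≈ tensor(0.0100)
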
def linear_init_(module):
    """Initialize a linear module's weight (and bias, if present) from a uniform distribution."""
    bound = 1 / math.sqrt(module.weight.shape[0])
    uniform_(module.weight, -bound, bound)
    if hasattr(module, 'bias') and module.bias is not None:
        uniform_(module.bias, -bound, bound)

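# Example (illustrative; not part of the original module): note the bound is
# derived from module.weight.shape[0], so for nn.Linear(256, 91) the weights
# and bias are drawn from U(-1/sqrt(91), 1/sqrt(91)).
#     >>> fc = nn.Linear(256, 91)
#     >>> linear_init_(fc)
#     >>> fc.weight.abs().max() <= 1 / math.sqrt(91)
#     tensor(True)
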
def inverse_sigmoid(x, eps=1e-5):
    """Compute the inverse of the sigmoid (logit) function, clamping for numerical stability."""
    x = x.clamp(min=0, max=1)
    x1 = x.clamp(min=eps)
    x2 = (1 - x).clamp(min=eps)
    return torch.log(x1 / x2)

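# Example (illustrative; not part of the original module): inverse_sigmoid is
# the logit function, so it inverts torch.sigmoid away from the clamped ends.
#     >>> t = torch.tensor([0.25, 0.5, 0.75])
#     >>> inverse_sigmoid(t)                 # ≈ tensor([-1.0986, 0.0000, 1.0986])
#     >>> torch.sigmoid(inverse_sigmoid(t))  # ≈ t
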
def multi_scale_deformable_attn_pytorch(value: torch.Tensor, value_spatial_shapes: torch.Tensor,
                                        sampling_locations: torch.Tensor,
                                        attention_weights: torch.Tensor) -> torch.Tensor:
    """
    Multi-scale deformable attention.

    https://github.com/IDEA-Research/detrex/blob/main/detrex/layers/multi_scale_deform_attn.py
    """
    bs, _, num_heads, embed_dims = value.shape
    _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
    value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
    sampling_grids = 2 * sampling_locations - 1
    sampling_value_list = []
    for level, (H_, W_) in enumerate(value_spatial_shapes):
        # bs, H_*W_, num_heads, embed_dims ->
        # bs, H_*W_, num_heads*embed_dims ->
        # bs, num_heads*embed_dims, H_*W_ ->
        # bs*num_heads, embed_dims, H_, W_
        value_l_ = (value_list[level].flatten(2).transpose(1, 2).reshape(bs * num_heads, embed_dims, H_, W_))
        # bs, num_queries, num_heads, num_points, 2 ->
        # bs, num_heads, num_queries, num_points, 2 ->
        # bs*num_heads, num_queries, num_points, 2
        sampling_grid_l_ = sampling_grids[:, :, :, level].transpose(1, 2).flatten(0, 1)
        # bs*num_heads, embed_dims, num_queries, num_points
        sampling_value_l_ = F.grid_sample(value_l_,
                                          sampling_grid_l_,
                                          mode='bilinear',
                                          padding_mode='zeros',
                                          align_corners=False)
        sampling_value_list.append(sampling_value_l_)
    # (bs, num_queries, num_heads, num_levels, num_points) ->
    # (bs, num_heads, num_queries, num_levels, num_points) ->
    # (bs*num_heads, 1, num_queries, num_levels*num_points)
    attention_weights = attention_weights.transpose(1, 2).reshape(bs * num_heads, 1, num_queries,
                                                                  num_levels * num_points)
    output = ((torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(
        bs, num_heads * embed_dims, num_queries))
    return output.transpose(1, 2).contiguous()

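# --- Self-check sketch (illustrative; not part of the original module) -------
# Exercises multi_scale_deformable_attn_pytorch with random tensors; the shapes
# mirror the unpacking inside the function and all sizes below are arbitrary.
# Note: value_spatial_shapes is passed here as a plain list of (H, W) pairs,
# which is what the per-level loop and value.split() consume.
if __name__ == '__main__':
    bs, num_heads, embed_dims = 2, 8, 32
    num_queries, num_points = 100, 4
    spatial_shapes = [(32, 32), (16, 16)]  # two feature levels
    num_levels = len(spatial_shapes)
    num_value = sum(h * w for h, w in spatial_shapes)

    value = torch.rand(bs, num_value, num_heads, embed_dims)
    # Sampling locations are normalized to [0, 1]; the function maps them to
    # the [-1, 1] range expected by F.grid_sample.
    sampling_locations = torch.rand(bs, num_queries, num_heads, num_levels, num_points, 2)
    # Attention weights should sum to 1 over (levels x points) per query/head.
    attention_weights = torch.rand(bs, num_queries, num_heads, num_levels, num_points)
    attention_weights = attention_weights / attention_weights.sum(dim=(-2, -1), keepdim=True)

    out = multi_scale_deformable_attn_pytorch(value, spatial_shapes, sampling_locations, attention_weights)
    assert out.shape == (bs, num_queries, num_heads * embed_dims)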