model.py 395 B

12345678910111213141516
  1. import torch.nn as nn
  2. import torchvision.models as models
  3. class EmbeddingNet(nn.Module):
  4. def __init__(self, backbone=None):
  5. super().__init__()
  6. if backbone is None:
  7. backbone = models.resnet50(num_classes=128)
  8. self.backbone = backbone
  9. def forward(self, x):
  10. x = self.backbone(x)
  11. x = nn.functional.normalize(x, dim=1)
  12. return x