12345678910111213141516 |
- import torch.nn as nn
- import torchvision.models as models
class EmbeddingNet(nn.Module):
    """Wrap a backbone network and L2-normalize the embeddings it produces.

    Args:
        backbone: Module mapping an input batch to embedding vectors. When
            ``None``, a torchvision ResNet-50 is built with ``embedding_dim``
            output units (``num_classes=embedding_dim``); no other
            constructor arguments are passed.
        embedding_dim: Output width of the default ResNet-50 backbone.
            Ignored when ``backbone`` is supplied. Defaults to 128, the value
            that was previously hard-coded.
    """

    def __init__(self, backbone=None, *, embedding_dim=128):
        super().__init__()
        if backbone is None:
            backbone = models.resnet50(num_classes=embedding_dim)
        self.backbone = backbone

    def forward(self, x):
        """Run ``x`` through the backbone and scale each embedding to unit L2 norm.

        Normalization is applied along dim 1. This assumes the backbone
        returns ``(batch, embedding_dim)`` tensors, as the default ResNet-50
        head does; confirm this for custom backbones. All-zero rows stay zero
        because ``normalize`` clamps the divisor to a small epsilon.
        """
        embeddings = self.backbone(x)
        return nn.functional.normalize(embeddings, p=2.0, dim=1)
|