In this reference, we use triplet loss to learn embeddings which can be used to differentiate images. This learning technique was popularized by FaceNet: A Unified Embedding for Face Recognition and Clustering and has been quite effective in learning embeddings to differentiate between faces.
This reference can be directly applied to other use cases that call for discriminative image embeddings.
By default, the training script trains ResNet50 on the FashionMNIST dataset to learn image embeddings that can be used to differentiate between images by measuring the Euclidean distance between their embeddings. This can be changed to suit your requirements.
Image embeddings of the same class should be 'close' to each other, while image embeddings of different classes should be 'far' apart.
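For example, after training, two images can be compared by embedding both and measuring the distance between the embeddings. A minimal sketch with random stand-in tensors:

```python
import torch

emb_a = torch.randn(128)  # stands in for the embedding of image A
emb_b = torch.randn(128)  # stands in for the embedding of image B
print(torch.dist(emb_a, emb_b, p=2))  # small distance => likely same class
```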
To run the training script:

```bash
python train.py -h  # Lists all optional arguments
python train.py     # Runs the training script with default args
```
Running the training script as is should yield 97% accuracy on the FMNIST test set within 10 epochs.
`TripletMarginLoss` is a loss function which takes in a triplet of samples. A valid triplet has an:

* **anchor**: a reference sample
* **positive**: a sample with the same label as the anchor
* **negative**: a sample with a different label from the anchor

`TripletMarginLoss` (refer to `loss.py`) does the following:
```
loss = max(dist(anchor, positive) - dist(anchor, negative) + margin, 0)
```

where `dist` is a distance function. Minimizing this loss effectively minimizes `dist(anchor, positive)` while maximizing `dist(anchor, negative)`.
The FaceNet paper describes this loss in more detail.
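To make the formula concrete, here is a minimal sketch in PyTorch; the implementation actually used by this reference lives in `loss.py` and may differ in details. `torch.nn.TripletMarginLoss` implements the same formula.

```python
import torch
import torch.nn.functional as F

def triplet_margin_loss(anchor, positive, negative, margin=1.0):
    """Sketch of the formula above for batched embeddings of shape
    (batch_size, embedding_dim)."""
    d_ap = F.pairwise_distance(anchor, positive)  # dist(anchor, positive)
    d_an = F.pairwise_distance(anchor, negative)  # dist(anchor, negative)
    # Hinge: the loss is zero once the negative is at least `margin`
    # farther from the anchor than the positive.
    return F.relu(d_ap - d_an + margin).mean()

anchor, positive, negative = (torch.randn(32, 128) for _ in range(3))
loss = triplet_margin_loss(anchor, positive, negative)
builtin = torch.nn.TripletMarginLoss(margin=1.0)(anchor, positive, negative)
```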
In order to generate valid triplets from a batch of samples, we need to make sure that each batch has multiple samples with the same label. We do this using `PKSampler` (refer to `sampler.py`), which ensures that each batch of size `p * k` contains samples from exactly `p` classes, with `k` samples per class.
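As a rough illustration of the idea, a `p * k` batch sampler could be sketched as below. `SimplePKSampler` is a hypothetical name; the real `PKSampler` in `sampler.py` may handle details such as classes with fewer than `k` samples differently.

```python
import random
from collections import defaultdict

from torch.utils.data import Sampler

class SimplePKSampler(Sampler):
    """Hypothetical sketch: yields dataset indices so that every
    consecutive run of p * k indices covers exactly p classes,
    with k samples from each."""

    def __init__(self, labels, p, k, num_batches=100):
        self.p, self.k, self.num_batches = p, k, num_batches
        self.indices_by_class = defaultdict(list)
        for idx, label in enumerate(labels):
            self.indices_by_class[label].append(idx)

    def __len__(self):
        return self.num_batches * self.p * self.k

    def __iter__(self):
        classes = list(self.indices_by_class)
        for _ in range(self.num_batches):
            for cls in random.sample(classes, self.p):  # p distinct classes
                # k samples per class (with replacement, for simplicity)
                yield from random.choices(self.indices_by_class[cls], k=self.k)
```

Passing such a sampler to a `DataLoader` with `batch_size=p * k` guarantees every batch can form valid triplets.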
`TripletMarginLoss` currently supports the following mining techniques:

* `batch_all`: generates all possible triplets from a batch and excludes the 'easy' triplets (those with `loss = 0`) before passing them through the loss function.
* `batch_hard`: for every anchor, creates a triplet with the 'hardest' positive (the farthest positive) and the 'hardest' negative (the closest negative); a sketch of this idea follows below.

These mining strategies usually speed up training.
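For intuition, batch-hard mining over a batch's pairwise distance matrix might look like the following sketch; the actual mining code in `loss.py` may be organized differently.

```python
import torch

def batch_hard_loss(embeddings, labels, margin=1.0):
    """Hypothetical sketch of batch-hard mining: for each anchor, use
    the farthest same-class sample and the closest other-class sample."""
    dist = torch.cdist(embeddings, embeddings)         # (B, B) pairwise distances
    same = labels.unsqueeze(0) == labels.unsqueeze(1)  # (B, B) same-label mask
    eye = torch.eye(len(labels), dtype=torch.bool, device=labels.device)

    pos_mask = same & ~eye  # positives: same label, excluding the anchor itself
    neg_mask = ~same        # negatives: any different label

    # Hardest positive: the farthest same-class sample.
    hardest_pos = (dist * pos_mask).max(dim=1).values
    # Hardest negative: the closest other-class sample.
    hardest_neg = dist.masked_fill(~neg_mask, float("inf")).min(dim=1).values

    return torch.relu(hardest_pos - hardest_neg + margin).mean()
```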
This webpage describes the sampling and mining strategies in more detail.