| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137 |
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- import paddle
- import paddle.nn as nn
class TripletLossV2(nn.Layer):
    """Triplet loss with batch-hard positive/negative mining.

    For every anchor in the batch the farthest sample sharing its label and
    the closest sample with a different label are selected, and a margin
    ranking loss is applied to the two resulting distances.

    Args:
        margin (float): margin for triplet.
        normalize_feature (bool): if True, each feature vector is divided by
            its L2 norm (taken along the last axis) before distances are
            computed.
    """

    def __init__(self, margin=0.5, normalize_feature=True):
        super(TripletLossV2, self).__init__()
        self.margin = margin
        self.ranking_loss = paddle.nn.loss.MarginRankingLoss(margin=margin)
        self.normalize_feature = normalize_feature

    def forward(self, input, target):
        """Compute the batch-hard triplet loss.

        Args:
            input: mapping whose ``"features"`` entry is the feature matrix
                with shape (batch_size, feat_dim).
            target: ground truth labels, one per sample. It is expanded to
                (batch_size, batch_size), so a shape such as (batch_size,)
                or (batch_size, 1) is expected.

        Returns:
            dict: ``{"TripletLossV2": loss}``.
        """
        inputs = input["features"]

        if self.normalize_feature:
            # The small epsilon guards against division by a zero norm.
            inputs = 1. * inputs / (paddle.expand_as(
                paddle.norm(
                    inputs, p=2, axis=-1, keepdim=True), inputs) + 1e-12)

        bs = inputs.shape[0]

        # Pairwise Euclidean distance: ||a||^2 + ||b||^2 - 2 * a.b
        dist = paddle.pow(inputs, 2).sum(axis=1, keepdim=True).expand([bs, bs])
        dist = dist + dist.t()
        dist = paddle.addmm(
            input=dist, x=inputs, y=inputs.t(), alpha=-2.0, beta=1.0)
        # Clip before sqrt so rounding can never produce sqrt of a
        # non-positive number.
        dist = paddle.clip(dist, min=1e-12).sqrt()

        # is_pos[i][j] is True when samples i and j carry the same label.
        # The label grid is built once and reused.
        labels = paddle.expand(target, (bs, bs))
        is_pos = labels.equal(labels.t())

        # Masked reductions instead of masked_select + reshape: the result is
        # the same when every anchor has the same number of positives and
        # negatives, and it no longer fails when those counts differ.
        neg_inf = paddle.full_like(dist, float("-inf"))
        pos_inf = paddle.full_like(dist, float("inf"))

        # `dist_ap` means distance(anchor, hardest positive), shape [N]
        dist_ap = paddle.max(paddle.where(is_pos, dist, neg_inf), axis=1)
        # `dist_an` means distance(anchor, hardest negative), shape [N]
        dist_an = paddle.min(paddle.where(is_pos, pos_inf, dist), axis=1)

        # Compute ranking hinge loss: push dist_an above dist_ap by `margin`.
        y = paddle.ones_like(dist_an)
        loss = self.ranking_loss(dist_an, dist_ap, y)
        return {"TripletLossV2": loss}
class TripletLoss(nn.Layer):
    """Triplet loss with hard positive/negative mining.

    Reference:
        Hermans et al. In Defense of the Triplet Loss for Person
        Re-Identification. arXiv:1703.07737.
        Code imported from
        https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py.

    Args:
        margin (float): margin for triplet.
    """

    def __init__(self, margin=1.0):
        super(TripletLoss, self).__init__()
        self.margin = margin
        self.ranking_loss = paddle.nn.loss.MarginRankingLoss(margin=margin)

    def forward(self, input, target):
        """Compute the batch-hard triplet loss.

        Args:
            input: mapping whose ``"features"`` entry is the feature matrix
                with shape (batch_size, feat_dim).
            target: ground truth labels, one per sample. It is expanded to
                (batch_size, batch_size), so a shape such as (batch_size,)
                or (batch_size, 1) is expected.

        Returns:
            dict: ``{"TripletLoss": loss}``.
        """
        inputs = input["features"]
        bs = inputs.shape[0]

        # Pairwise Euclidean distance: ||a||^2 + ||b||^2 - 2 * a.b
        dist = paddle.pow(inputs, 2).sum(axis=1, keepdim=True).expand([bs, bs])
        dist = dist + dist.t()
        dist = paddle.addmm(
            input=dist, x=inputs, y=inputs.t(), alpha=-2.0, beta=1.0)
        # Clip before sqrt so rounding can never produce sqrt of a
        # non-positive number.
        dist = paddle.clip(dist, min=1e-12).sqrt()

        # mask[i][j] is True when samples i and j carry the same label.
        labels = target.expand([bs, bs])
        mask = paddle.equal(labels, labels.t())

        # Vectorised hard mining. This replaces a Python double loop that
        # indexed the distance matrix one element at a time; the selected
        # distances are the same, but the work stays inside tensor ops and
        # no NumPy copy of the mask is needed.
        neg_inf = paddle.full_like(dist, float("-inf"))
        pos_inf = paddle.full_like(dist, float("inf"))

        # Farthest same-label sample per anchor, shape [N]
        dist_ap = paddle.max(paddle.where(mask, dist, neg_inf), axis=1)
        # Closest different-label sample per anchor, shape [N]
        dist_an = paddle.min(paddle.where(mask, pos_inf, dist), axis=1)

        # Compute ranking hinge loss: push dist_an above dist_ap by `margin`.
        y = paddle.ones_like(dist_an)
        loss = self.ranking_loss(dist_an, dist_ap, y)
        return {"TripletLoss": loss}
|