nokoのブログ

こちらは暫定のメモ置き場ですので悪しからず

metric learningについて調べてたことメモ

はじめに

metric learningとは

  • metric learningとは データ間の計量(距離や類似度など)を学習する手法です。
  • 下記の図のように、同じクラスは近く、異なるクラスは遠くなるように学習します。 スクリーンショット 2019-07-03 22.22.02.png http://ml.cecs.ucf.edu/node/28

  • 代表的な手法として下記がある(詳細は後述)

    • triplet loss
    • n-pair loss
    • angular loss

なぜmetric learningを用いるのか

  • metric learningは主に、異常検知・個体識別・情報検索(類似画像検索など)などに用いられる
  • データ間の計量(距離や類似度)を学習するので、意味的な距離を考慮した学習ができる
    • 例えば、スニーカー同士、ブーツ同士、ヒール同士が近くに配置されているだけでなく、ヒールが高いほど右方向に配置されている
      スクリーンショット 2019-07-04 22.53.52.png

https://vision.cornell.edu/se3/embeddings-and-metric-learning/

  • classificationではなくmetric learningを用いる理由
    • 個体識別などのタスクは、「クラスに属する画像を事前に得ることができない」ため
    • クラスあたりのデータ数が少ないケースでは、overfittingしやすいため

triplet loss

  • metric learningの代表的な手法の一つ

  • 学習にはAnchorとPositiveとNegativeの3枚1組のデータセットを用意する。AnchorとPositiveが近く、AnchorとNegativeが遠くなるように学習する

    • アノテーションも楽になる場合がある。(ミスしにくくなる場合がある。)
      • -> 似ている順番に人手で並べてラベル付けするのではなく、3枚の画像セットに対して、Anchorを決めて、似ている方をPositive。似ていない方をNegativeとすればよいため。
        ※ PositiveはAnchorと同じラベルのデータで、Negativeは異なるラベルのデータから選択する場合もある。

スクリーンショット 2019-07-04 22.56.45.png

https://arxiv.org/pdf/1503.03832.pdf

スクリーンショット 2019-07-08 22.42.18.png

  • 損失関数
  • スクリーンショット 2019-07-08 21.05.59.png

https://arxiv.org/pdf/1503.03832.pdf

  • コードイメージ
class TripletMarginLoss(nn.Module):

def __init__(self, margin):
    super(TripletMarginLoss, self).__init__()
    self.margin = margin

def forward(self, anchor, positive, negative):
    dist = torch.sum(
        torch.pow((anchor - positive),2) - torch.pow((anchor - negative),2),
        dim=1) + self.margin
    dist_hinge = torch.clamp(dist, min=0.0)  #max(dist, 0.0)と等価
    loss = torch.mean(dist_hinge)
    return loss

triplet lossの改善手法

  • triplet lossの弱点
    • (dp, dn, αの)相対的な関係値のみを記述している = 学習する際に適度な難しさのpairを選ぶことができていない場合がある
      • semi-hard sample miningが必要
  • Improved Triplet Loss
    • 従来のtriplet lossの項に加えて、anchorとpositiveの距離を一定の大きさβより小さくする項を追加。
      • 相対的な位置関係とは別に、クラス内の距離がβより小さくなるように働く。
        • 十分に学習が進めば、クラス内距離は全てβ以下となる

スクリーンショット 2019-07-08 23.07.14.png

  • Hardness-Aware Deep Metric Learning
    ※ triplet lossだけでなく、metric lossの種類によらず利用可能
    • 学習状況に適切な難しさのnegativeサンプルを生成する
      • 潜在空間上の線形補間により調整

スクリーンショット 2019-07-09 12.13.16.png

  • ①feature空間のanchor, positive, negative
  • ②①をembedding spaceに射影
  • ③線形補間により、より難しいz^-を生成
  • ④feature spaceに戻す
  • ⑤z^-とz-が同じラベルとは限らない->y-と同じラベルになるようなy~-をマップ
  • ⑥⑤をembedding spaceに射影

スクリーンショット 2019-07-08 23.08.39.png https://arxiv.org/pdf/1903.05503.pdf

  • Javg(1つ前のepochのAverage metric loss)が大きくなる->z^-がz-に近づき、easyになる。

スクリーンショット 2019-07-08 23.23.16.png

https://arxiv.org/pdf/1903.05503.pdf

n-pair loss

  • tripletにおけるnegativeサンプルをN個にしたバージョン
  • N個の異なるラベルのnegativeサンプルを用いることで、一つのpositiveサンプルに対して、各negativeクラス間の相対的な位置関係が分かりやすくなり、学習が安定する

  • n-pair sampling

    • n-pair samplingでデータセットを作り、損失関数としてはangular lossを使う、などもアリ。

スクリーンショット 2019-07-08 22.19.38.png

http://www.nec-labs.com/uploads/images/Department-Images/MediaAnalytics/papers/nips16_npairmetriclearning.pdf

  • コードイメージ
def n_pair_sampling(base_dir, path_text, n_pair_index_text, epoch_number, N):
    labels = os.listdir(base_dir)
    label_names = []
    for label in tqdm(labels):
        images = os.path.join(base_dir, label)
        for im_name in os.listdir(images):
            label_names.append(int(label))
            path = os.path.join(label, im_name)
            with open(path_text, mode='a') as f:
                f.write("{}\n".format(path))

    label_names = np.array(label_names)
    for _ in tqdm(range(epoch_number)):
        pair_samples = []
        categories = [int(i) for i in os.listdir(base_dir)]
        select_classes = np.random.choice(categories, N, replace=False)
        for select_class in select_classes:
            pair_sample = np.random.choice(np.where(label_names==select_class)[0], 2, replace=False)
            #[x1, x2]
            pair_samples.append(pair_sample)
        pair_samples = np.array(pair_samples)
        # print("pair", pair_samples)
        anchors = pair_samples[:,0]
        positives = pair_samples[:,1]
        # print("anchors", anchors,"positives" , positives)
        with open(n_pair_index_text, mode='a') as f:
            for anchor_index in anchors:
                f.write("{} ".format(anchor_index))
            f.write(",")
            for postive_index in positives:
                f.write("{} ".format(postive_index))
            f.write("\n")
  • n-pair loss

スクリーンショット 2019-07-08 22.21.21.png

http://www.nec-labs.com/uploads/images/Department-Images/MediaAnalytics/papers/nips16_npairmetriclearning.pdf

  • コードイメージ
class n_pair_mc_loss():
    def __init__(self):
        super(n_pair_mc_loss, self).__init__()
    def forward(self, f, f_p):
        n_pairs = len(f)
        term1 = torch.matmul(f, torch.transpose(f_p, 0, 1))
        term2 = torch.sum(f * f_p, keepdim=True, dim=1)
        f_apn = term1 - term2
        mask = torch.ones_like(f_apn) - torch.eye(n_pairs).cuda()
        f_apn = f_apn * mask
        return torch.mean(torch.logsumexp(f_apn, dim=1))

angular loss

  • 角度を使ったloss
  • データセットは、n-pair samplingでもよい

スクリーンショット 2019-07-08 22.23.05.png

  • 損失関数(triplet版)

スクリーンショット 2019-07-08 22.23.50.png

https://arxiv.org/abs/1708.01682

  • コードイメージ(triplet版)
class AngularLoss(nn.Module):
    def __init__(self, alpha=45, in_degree=True):
        super(AngularLoss, self).__init__()
        if in_degree:
            alpha = np.deg2rad(alpha)
        self.tan_alpha = np.tan(alpha) ** 2

    def forward(self, a, p, n):
        c = (a + p) / 2
        sq_dist_ap = (a - p).pow(2).sum(1)
        sq_dist_nc = (n - c).pow(2).sum(1)
        loss = sq_dist_ap - 4*self.tan_alpha*sq_dist_nc
        return F.relu(loss).mean()

おわりに

  • n-pair samplingでangular lossがよい?
  • 損失関数をどうするかより、negative samplingどうするかが大事で、みんなそこを研究している?
  • ...結局、Metric Learningで一番イケてる方法は?