はじめに
- metric learningまわりについて勉強したときのメモです。
- 下記を読んで勉強させていただきました。
metric learningとは
- metric learningとは データ間の計量(距離や類似度など)を学習する手法です。
下記の図のように、同じクラスは近く、異なるクラスは遠くなるように学習します。 http://ml.cecs.ucf.edu/node/28
代表的な手法として下記がある(詳細は後述)
- triplet loss
- n-pair loss
- angular loss
なぜmetric learningを用いるのか
- metric learningは主に、異常検知・個体識別・情報検索(類似画像検索など)などに用いられる
- データ間の計量(距離や類似度)を学習するので、意味的な距離を考慮した学習ができる
- 例えば、スニーカー同士、ブーツ同士、ヒール同士が近くに配置されているだけでなく、ヒールが高いほど右方向に配置されている
- 例えば、スニーカー同士、ブーツ同士、ヒール同士が近くに配置されているだけでなく、ヒールが高いほど右方向に配置されている
https://vision.cornell.edu/se3/embeddings-and-metric-learning/
- classificationではなくmetric learningを用いる理由
- 個体識別などのタスクは、「クラスに属する画像を事前に得ることができない」ため
- クラスあたりのデータ数が少ないケースでは、overfittingしやすいため
triplet loss
metric learningの代表的な手法の一つ
学習にはAnchorとPositiveとNegativeの3枚1組のデータセットを用意する。AnchorとPositiveが近く、AnchorとNegativeが遠くなるように学習する
- アノテーションも楽になる場合がある。(ミスしにくくなる場合がある。)
- -> 似ている順番に人手で並べてラベル付けするのではなく、3枚の画像セットに対して、Anchorを決めて、似ている方をPositive。似ていない方をNegativeとすればよいため。
※ PositiveはAnchorと同じラベルのデータで、Negativeは異なるラベルのデータから選択する場合もある。
- -> 似ている順番に人手で並べてラベル付けするのではなく、3枚の画像セットに対して、Anchorを決めて、似ている方をPositive。似ていない方をNegativeとすればよいため。
- アノテーションも楽になる場合がある。(ミスしにくくなる場合がある。)
https://arxiv.org/pdf/1503.03832.pdf
- 損失関数
https://arxiv.org/pdf/1503.03832.pdf
- コードイメージ
class TripletMarginLoss(nn.Module): def __init__(self, margin): super(TripletMarginLoss, self).__init__() self.margin = margin def forward(self, anchor, positive, negative): dist = torch.sum( torch.pow((anchor - positive),2) - torch.pow((anchor - negative),2), dim=1) + self.margin dist_hinge = torch.clamp(dist, min=0.0) #max(dist, 0.0)と等価 loss = torch.mean(dist_hinge) return loss
triplet lossの改善手法
- triplet lossの弱点
- (dp, dn, αの)相対的な関係値のみを記述している = 学習する際に適度な難しさのpairを選ぶことができていない場合がある
- semi-hard sample miningが必要
- (dp, dn, αの)相対的な関係値のみを記述している = 学習する際に適度な難しさのpairを選ぶことができていない場合がある
- Improved Triplet Loss
- 従来のtriplet lossの項に加えて、anchorとpositiveの距離を一定の大きさβより小さくする項を追加。
- 相対的な位置関係とは別に、クラス内の距離がβより小さくなるように働く。
- 十分に学習が進めば、クラス内距離は全てβ以下となる
- 相対的な位置関係とは別に、クラス内の距離がβより小さくなるように働く。
- 従来のtriplet lossの項に加えて、anchorとpositiveの距離を一定の大きさβより小さくする項を追加。
- Hardness-Aware Deep Metric Learning
※ triplet lossだけでなく、metric lossの種類によらず利用可能- 学習状況に適切な難しさのnegativeサンプルを生成する
- 潜在空間上の線形補間により調整
- 学習状況に適切な難しさのnegativeサンプルを生成する
- ①feature空間のanchor, positive, negative
- ②①をembedding spaceに射影
- ③線形補間により、より難しいz^-を生成
- ④feature spaceに戻す
- ⑤z^-とz-が同じラベルとは限らない->y-と同じラベルになるようなy~-をマップ
- ⑥⑤をembedding spaceに射影
https://arxiv.org/pdf/1903.05503.pdf
- Javg(1つ前のepochのAverage metric loss)が大きくなる->z^-がz-に近づき、easyになる。
https://arxiv.org/pdf/1903.05503.pdf
n-pair loss
- tripletにおけるnegativeサンプルをN個にしたバージョン
N個の異なるラベルのnegativeサンプルを用いることで、一つのpositiveサンプルに対して、各negativeクラス間の相対的な位置関係が分かりやすくなり、学習が安定する
n-pair sampling
- n-pair samplingでデータセットを作り、損失関数としてはangular lossを使う、などもアリ。
- コードイメージ
def n_pair_sampling(base_dir, path_text, n_pair_index_text, epoch_number, N): labels = os.listdir(base_dir) label_names = [] for label in tqdm(labels): images = os.path.join(base_dir, label) for im_name in os.listdir(images): label_names.append(int(label)) path = os.path.join(label, im_name) with open(path_text, mode='a') as f: f.write("{}\n".format(path)) label_names = np.array(label_names) for _ in tqdm(range(epoch_number)): pair_samples = [] categories = [int(i) for i in os.listdir(base_dir)] select_classes = np.random.choice(categories, N, replace=False) for select_class in select_classes: pair_sample = np.random.choice(np.where(label_names==select_class)[0], 2, replace=False) #[x1, x2] pair_samples.append(pair_sample) pair_samples = np.array(pair_samples) # print("pair", pair_samples) anchors = pair_samples[:,0] positives = pair_samples[:,1] # print("anchors", anchors,"positives" , positives) with open(n_pair_index_text, mode='a') as f: for anchor_index in anchors: f.write("{} ".format(anchor_index)) f.write(",") for postive_index in positives: f.write("{} ".format(postive_index)) f.write("\n")
- n-pair loss
- コードイメージ
class n_pair_mc_loss(): def __init__(self): super(n_pair_mc_loss, self).__init__() def forward(self, f, f_p): n_pairs = len(f) term1 = torch.matmul(f, torch.transpose(f_p, 0, 1)) term2 = torch.sum(f * f_p, keepdim=True, dim=1) f_apn = term1 - term2 mask = torch.ones_like(f_apn) - torch.eye(n_pairs).cuda() f_apn = f_apn * mask return torch.mean(torch.logsumexp(f_apn, dim=1))
angular loss
- 角度を使ったloss
- データセットは、n-pair samplingでもよい
- 損失関数(triplet版)
https://arxiv.org/abs/1708.01682
- コードイメージ(triplet版)
class AngularLoss(nn.Module): def __init__(self, alpha=45, in_degree=True): super(AngularLoss, self).__init__() if in_degree: alpha = np.deg2rad(alpha) self.tan_alpha = np.tan(alpha) ** 2 def forward(self, a, p, n): c = (a + p) / 2 sq_dist_ap = (a - p).pow(2).sum(1) sq_dist_nc = (n - c).pow(2).sum(1) loss = sq_dist_ap - 4*self.tan_alpha*sq_dist_nc return F.relu(loss).mean()
おわりに
- n-pair samplingでangular lossがよい?
- 損失関数をどうするかより、negative samplingどうするかが大事で、みんなそこを研究している?
- ...結局、Metric Learningで一番イケてる方法は?