Modeling the Distribution of Normal Data in Pre-Trained Deep Features for Anomaly Detection

モチベ

特定ドメインの良品画像のみを利用して異常検知を実装することが多い中で、オープンなデータセットをうまく使えないかなと悶々としていたら

もはやオープンなデータセット利用を超えて、公開されてる学習済みモデルを利用した異常検知手法が提案されたとの噂を聞いた

精度もめちゃくちゃいいとのこと。しゅごい。

これは確認しないといけないよね、ということでこちらの内容をまとめた

簡単まとめ

どんな研究?

  • ImageNetで学習済みのモデル(ResNet, EfficientNet)から最終レイヤーを除くレイヤーの出力をチャネル単位でGlobalAveragePoolingしたものを入力として利用

  • 各レイヤー、全レイヤーで多変量ガウス分布の母数を学習データから算出、マハラノビス距離を異常度のスコアとして異常画像の検知

  • MVTecADデータセットを用いた異常検知でAUROC95.8±1.2%のSoTAを達成

何がすごいの?

多変量ガウス分布の母数を求める部分以外ではデータにfittingしていないのにSoTAを達成(巨人の肩に乗れる可能性大)

どう使えそう?

  • DNN部分の学習の必要がないので、時間のない時にささっと試せる

  • 新しい画像系の手法の学習済みモデルが出たら、それをささっと流用して異常検知モデルを作れる

使えない場面は?

  • Image単位の手法なので、pixel単位の異常位置を知る必要のあるタスクには利用できない

  • 異常部位を見せることができないため、説明ができない

所感

イメージ単位でブラックボックスな手法で良い場合、この手法を適用してあげれば十分かも

詳細

概要

一般の異常検知タスクには下記の2点の特徴があるため、データセットに大きな偏りがでちゃう

  • 異常の発生はまれ

  • 異常の定義があいまい

結果、異常検知タスクには正常データのみを利用する半教師あり学習でモデルを作ることが多い

ただ、正常データ自体も件数が決して多いとは言えない

こういったタスクにはImageNet等の大規模データベースのデータを用いたモデルを作るのが有効なはずだが、残念ながらこういった方向の研究はあまり進んでいない

本研究の貢献は下記

  • 学習済みのdeepなfeature representationを異常検知タスクで利用できるようにした

  • 正常データのfeature representaionに多変量ガウス分布を当てはめ、そこからのマハラノビス距離を異常度として定義したモデルでMVTecADの異常検知タスクを解いたところ、SoTAを達成

  • feature representationを次元削減したモデルの精度をみることで、feature representationにおいて分散の小さいデータが異常検知タスクにとって重要という知見を獲得

先行研究

本提案の発想の元となった論文はLee et al.の下記論文とのこと

A Simple Unified Framework for Detecting Out-of-Distribution Samples and Adversarial Attacks

この論文では、分類タスクに関する学習済みモデルの予測結果に対してモデルの中間出力のマハラノビス距離をConfidence Scoreとして算出することで、OODやAdversarial Sampleを検知できることを示した

本提案はモデルの中間出力をImageNetで学習済みのモデルの中間出力を変更とすることで、異常検知タスクに転移学習的な要素を組み込めるように拡張したもの

方法

提案手法

  1. ImageNetで学習済みのモデルに学習用正常データを入力し中間出力のfeature representaionを取得

  2. 学習用feature representationに対して多変量ガウス分布を当てはめ母数を算出

  3. 多変量ガウス分布からの距離を異常度の指標として各データに対して異常度を算出し、異常値を判定

中間出力のfeature representationについて

ここではEfficientNet-B0を学習済みモデルとして利用する場合を例として示す

f:id:okehara_aoi:20201014183910p:plain

上図はEfficientNet-B0の構造を簡易的に模した図

本提案では単位構造ごとの中間出力をチャネル単位でGlobalAveragePoolingをとったもの(※論文での記載はないが実装ではそうなってた)をfeature representationとして利用

この時feature representationを入力に近い側からLevelという単位でラベルづけし、各レベルのみを使った場合と全レベルの結果を足し合わせた場合を検討している

※ここではEfficientNet-B0の場合を示したがResNet等を用いた場合も基本的な考え方は一緒

距離を用いた異常度算出について

本論ではユークリッド距離の他にマハラノビス距離を利用しているので、その簡単な説明

マハラノビス距離はxをデータ点、μを平均、Σを分散共分散行列とした時、下記のように表せる

f:id:okehara_aoi:20201014184152p:plain

マハラノビス距離は特徴量の分散と共分散を考慮した指標といえる

異常度をマハラノビス距離(d)で求めるとその閾値は確率的に求められるというメリットがある

特定のd内に正常サンプルが存在する確率pとすると、pは特定距離dでの判別機のTrue Negative Rate(TNR)と見なせ、1-pはFalse Positive Rate(FPR)と見なせる

Mをマハラノビス距離、tを距離の閾値、Fをカイ2乗分布の累積確率密度関数、Dを特長量次元とすると、下記のようにtと累積密度関数の関係式が成り立つ(ガウス分布からサンプルを取得した場合、マハラノビス距離の2乗はカイ2乗分布になることについて興味がある方はこちらからどうぞ)

f:id:okehara_aoi:20201014184228p:plain

これをtに関して解くと

f:id:okehara_aoi:20201014184245p:plain

となるので、与えられたFPRに関して閾値を計算できることがわかる

最終的にこのマハラノビス距離を異常度の指標として用いるが、この有効性は実験で確認しているので説明はそちらで

実験

本論文では下記を目的とした実験を実施

  1. 異常度指標選定とモデルの挙動確認

  2. 学習済みモデルアーキテクチャ選定

  3. 学習済みモデル複雑度の影響確認

  4. 有効な特徴量の性質検証

  5. 閾値の決め方の妥当性検証

  6. 他手法との精度比較

上記実験はMVTecADのデータを利用しており、カテゴリ ごとに5-foldのCVを行っているため結果は、平均と標準誤差(SEM)で記載

1. 異常度指標選定とモデルの挙動確認

異常度として平均値からの距離を利用するのはイメージしやすいが、どのような距離の定義を利用するかについては検討の余地がある

ここでは下記の距離の定義を用いて分類の精度がどう変化するかを確認

  • L2

    • シンプルな平均からの距離
  • Standard Euclidian Distance(SED)

    • 特徴量内の分散を考慮した距離
  • Mahalanobis

    • 特徴量内、特徴量間の分散を考慮した距離
f:id:okehara_aoi:20201014184330p:plain

上図はEfficientNet-B4で各距離指標を用いた場合のAUROC

全てのLevelにおいてマハラノビス距離を用いた場合の精度が圧倒的に良いことが確認できる

この結果を受け異常度の指標としてマハラノビス距離を用いることを決定した

またこの結果より下記の2点が確認できる

  • 高いAUROCを実現していることから、多変量ガウス分布を利用したモデルが妥当であること

  • 異常検知タスクの転移学習では抽象度(Level)の高い特徴量が有用であること

2.学習済みモデルアーキテクチャ選定

EfficientNetの他にResNet-18, 34, 50のアーキテクチャを試したが最良のAUROCで89.0%±3.0%の結果しか得られなかったため、以降EfficientNetをメインで用いることに決定

3.学習済みモデル複雑度の影響確認

EfficientNetにはB0 からB7の種類が存在し、添字の値が大きくなるにつれモデルの複雑度が上がる

この実験では複雑度の変化が精度に与える影響について確認

f:id:okehara_aoi:20201014184426p:plain

上図の結果からB4、B5を頂点とした緩やかな山が形成されていることが確認できる

B5以上の複雑度のモデルで精度が改善しない理由としては、モデルがImageNetに対してオーバーフィットしてしまい、新しいドメインへの適用に適さなくなっていることなどが想定される

この結果を受け、以降はメインのモデルとしてEfficientNet-B4を、複雑度の低いモデルとしてEfficientNet-B0を用いて検証することに決定

4.有効な特徴量の性質検証実験

ここまでの結果からもdeepなfeature representationに対して多変量ガウス分布を転移学習的に当てはめる提案手法が異常検知のタスクに有効であることがわかる

ドメインに依存しないはずのfeature representationがなぜ有効なのだろうか

筆者らは「異常検知タスクにとって重要な特徴は正常データ内で大きく変化しないのではないか」という仮説をたてた

これを検証するために多変量ガウス分布を当てはめる前のfeature representationに対して下記の操作を加えたモデルで精度を検証した

  • PCA : 分散の小さい特徴を除く

  • Negated PCA(NPCA):分散の小さい特徴を残す

f:id:okehara_aoi:20201014184609p:plain

上図より分散の小さい特徴を除いた(PCA)モデルの精度は悪化し、分散の小さい特徴のみを残す(NPCA)モデルの精度は維持されていることがわかる

この結果を受け、筆者らは仮説を立証できたとしている

また、正常データのみから学習する半教師あり学習モデルの精度が事前学習済みの特徴量を利用したモデルの精度に劣る理由もここにあるのではないかとしている

※理解しきれてませんが、AEなどで埋め込みを実施する場合、正常データ中における分散の小さい細かな特徴が切り捨てられてしまうから、うまくいかないと言っているような気がする。。。

検証目的とは別だが、今回の結果より分散の小さい特徴のみを残す方法が異常検知タスクの次元削減に有効な可能性が示唆されたとしている

5.閾値の決め方の妥当性検証実験

前に軽く述べたが、マハラノビス距離を利用すると距離の閾値は与えられたFPRに対して一意に定めることができるはず

与えられたターゲットのFPRを動かした際に、実際のFPRやTPRが想定した挙動を示さない場合、閾値による異常検知ができなくなってしまう。ので、その確認

f:id:okehara_aoi:20201014185056p:plain

※FPRを全Levelを用いたモデルに対して設定するのが難しいため、ここではLevel7で作られたEfficientNet-B0とB4モデルを対象とした。また、データサイズが小さいためかデータのaugmentationを実施した上で実験を実施

上図よりTarget FPRを徐々に減少させることにより、実際のFPRも徐々に減少しており、想定した動きをしていることがわかる(ただ、augmentationを実施しなかった場合はうまくいかなかったそう。この点の検証は今後やっていくそう)

6.他手法との精度比較

MVTecADのデータセットを対象として精度比較を実施

※Pre-Trained Classifierはモデルの上限を知るために、全データを用いた教師あり二値分類問題を解いたものなので、無視して大丈夫

f:id:okehara_aoi:20201014185120p:plain

全カテゴリのAUROC平均で比較すると、当時SoTAだったSPADEと比較して10ポイント近く更新してSoTAを達成したとのこと。

PapersWithCodeで見ると現在のSoTAがDifferNetで94.9%なので、それよりもいいみたい。

めでたし、めでたし!

最後にカテゴリごとの精度を確認

f:id:okehara_aoi:20201014185227p:plain

カテゴリ ごとの結果は上記、PillやScrewと行った小さくて対称性の高いタスクは苦手な感じがありそう

今後の展望

  • 元になるモデルを異常検知タスクに関してfine tuningすることで予測精度のさらなる改善がはかれそう

  • 今回の提案した手法をマルチモーダルな異常検知に応用したい

まとめ

  • 学習済みのdeepなfeature representationを異常検知タスクで利用できるようにした

  • 正常データのfeature representaionに多変量ガウス分布を当てはめ、そこからのマハラノビス距離を異常度として定義したモデルでMVTecの異常検知タスクを解いたところ、SoTAを達成

  • feature representationをPCAにかけ次元削減したモデルの精度をみることで、学習済みモデルに正常データを食わせた結果得られるfeature representationにおいて分散の小さいデータが異常検知タスクにとって重要という知見を獲得

リンク

Modeling the Distribution of Normal Data in Pre-Trained Deep Features for Anomaly Detection
byungjae89/MahalanobisAD-pytorch(論文の実装)
A Simple Unified Framework for Detecting Out-of-Distribution Samples and Adversarial Attacks