Keras Conv1Dで心電図の不整脈を検出する

今回は、Keras Conv1Dで MITの心電図の波形データセットから不整脈を検出してみます。

こんにちは cedro です。

最近、Preferred Networks は、日本メディカルAI学会公認資格「メディカルAI専門コース」のオンライン講座資料を一般公開しました。

この講座は、コードを実際に実行しながら学べるように、Google ColaboratoryのGPUマシンを使ってJupyter Notebookで動かせる仕様になっていて、無料で非常に快適な学習が進められるため、おすすめです。

ただ、Preferred Networks が作ったので当然フレームワークは Chainer です。Keras もまだ良く分かっていない私にとっては、本格的にやるのはもう少し後にしようかなと考えているところです。

そうした中、ざっと講座の内容を見たら、8章に「心電図の波形データから不整脈を検知する」という課題があり、MITの心電図の波形データセットがあることが分りました。最近、異常検知にハマっている私にとっては、これはちょっと遅いクリスマスプレゼントの様なものです(笑)。

ということで、今回は、Keras Conv1Dで MITの心電図の波形データセットから不整脈を検知してみます。

 

データセットを作成します

MITの心電図の波形データをダウンロードし、データの前処理をするコードです。

まず、WFDB ライブラリーを pip install wfdb でイントールします。このWFDBライブラリーは、MITの生理学的信号の記録を閲覧、分析、作成するためのもので、MITが過去20年間に渡って収集したデータにアクセスできます。さすが、MIT!

ダウンロードするのは、波形ファイル(dat)、属性ファイル(atr)、位置ファイル(hea)の3種類×各48個=144個です。そして、属性ファイルと位置ファイルの情報を元に、波形ファイルを2秒間隔で切り取り、正常か異常に分類してデータセットを作成します。

7-9行目はデータのダウンロード、12−13行目は波形データの切り取り間隔の指定、16-22行目は学習・テストに使用するファイルの指定、26-28行目はどの属性を正常・異常とするかの指定です。

30-39行目はファイルを読み込む関数、41−52行目は波形データを2秒間隔で切り取る関数、54-65行目は他の2つの関数と連携して作成したデータを保管する関数です。

このコードを dataset.py という名前で適切なフォルダー(egc_anomalyとしました)に保存して、実行します。

MITのサーバーとやり取りしてファイル(トータル94MB)をダウンロードするので少々時間が掛かります。Finished downloading files と表示されればOKです。

 

コード実行後はこんな形になります。一番左がecg_anomaly フォルダーの中で、dataset フォルダーが出来ています。

dataset フォルダーの下に、前処理を行なった x_test.npy, x_train.npy, y_train.npy, y_test.npy4つのnumpy 形式のデータセットが出来ています。

ちなみに、dataset / download フォルダーの下にあるのは、MITからダウンロードした生データです。

 

データセットの内容を見てみます

まず、予備知識として心電図のデータ波形の基本について見てみましょう。

これは、心電図の波形データの模式図(出典:Wikipedia)です。P波は心房の興奮(収縮)、QRS波は心室の興奮(収縮)、T波は心室の回復(拡張)を表しています。

この波形データがどの範囲であれば正常で、どこからが異常なのかを経験豊かな専門家が読み解くのが診断です。今回は、この診断をニューラルネットワークにやらせようと言うことです。

データセットの作成に当たっては、中央にあるQRS波Rのピークを中心として、前後1秒づつ合計2秒間を1つのデータにしています。

 

データセットの内容を見るコードです。

13-29行目は波形データを可視化する部分です。y_trainのラベルが0なら正常、1なら異常なので、それに該当するx_train を選択します。x_trainの1つのデータは720個の数値が並んでいるだけなので、並びを時間軸(Time)に、数値を信号の強さ(Signal strength)にしてグラフ化します。

32-45行目は学習・テストデータのシェイプや正常の個数と異常の個数を表示する部分です。このコードにdataset_disp.py という名前を付けてecg_anomaly フォルダーに保存し、実行します。

 

正常な波形サンプルです。横軸が時間で最大値は720(2秒間)、縦軸がシグナルの強さです。確かに、先程見た模式図と似ていますね。

 

異常な波形サンプルです。P波、T波が大きく、QRS波の中のS波が大きく落ち込んでいて、模式図とは異なっていることが分かります。

 

データのシェイプと数に関する情報です。学習データは47738個で、その内正常データは43995個、異常データは3743個です。テストデータは45349個で、その内正常データは42149個、異常データは3200個です。

 

心電図の波形データを正常か異常か学習・判定するコードです。

今回のデータセットはかなり量があるので、CPUで軽く動く様にデータセットをダイエットしています(勿体ないですが)。

x_trainを20%にダイエットし、その80%(7637個)を学習データ、20%(1910個)をテストデータに使用しています。学習データはさらに、6109個を本当の学習に、1528個をバリデーションに分けて使っています。

 

ネットワークSummary です。最初に、CONV1Dという1次元の畳み込みを使い、その後全結合層を3つ繋いでいます。畳み込みで良く使うのは2次元ですが、今回は1次元の時系列データなので、1次元の畳み込みを使っています。

53-74行目は精度とロスの推移グラフを描く部分、77-98行目はConfusion Matrix を描く部分です。train.py の名前を付けてecg_anmaly フォルダーに保存し、実行します。

データの量を抑えたのでMacbookAirで軽快に動き、 50sec / epochで、5epochが4分ちょっとで完了します。

Validation_accの推移グラフです。5epoch 完了時の値は99.28%で、最終テストデータでのaccyracy(精度)は99.42%でした。ただ、今回の場合、異常データが少ないので、精度だけ見ていては不十分です。

 

Confusion Matrix です。accuracy = (1754+144) / (1754+2+9+144) = 99.42%、precision = 144 / (144+2) = 98.6 %、Recall = 144 / (144+9) = 94.1% です。

今回のネットワークは本当に異常な人をどれくらいの精度で検出できるか、つまりRecall が重要です。結果は94.1%で、本当に異常な153人の内、144人は正しく異常と判断出来、見逃す人は9人と言う結果でした。

まあ、データ数を大幅にダイエットした中では、結構イイ感じではないでしょうか。なお、今回のデータセット作成や内容を見る部分のコードは、Preferred Networks の資料を参考にさせて頂きました。ありがとうございました。

さて、2018年ももうすぐ終わりですね。皆さん良いお年をお迎え下さい。

では、また。

コメントを残す

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です

日本語が含まれない投稿は無視されますのでご注意ください。(スパム対策)

ABOUTこの記事をかいた人

2017年8月に、SONY Neural Network Console に一目惚れして、ディープラーニングを始めました。初心者の試行錯誤をブログにしています。