[tensorflow.js x kerasの使い方] LSTMのステートフルにすると何がいいの?かを検証する


2020/12/27

2020年も残すところ後少しになりました。

弊社も色々と苦境の中にあって激動の年になりましたが、人工知能の話題もかなり充実しており他の業界と比べてかなり活気のある一年であったように思います。

さて、前回はtensorflow.js/kerasでLSTMをステートレス(デフォルト)でシンプルな正弦波モデルを解いてみる具体例を紹介してみました。

折角の機会ですので、ステートレスLSTMが出来たならステートフルLSTMもやってみいという天の声が聞こえてきましたので、何がどう違うのかを比較してみたいと思います。

ただし、KerasのLSTMモデルをステートフルで使う方法は非常に扱いにくいです。

使ってみたところでご利益があるのかどうかは適用する問題次第ですし、利用側のプログラマーはLSTMという構造の中身を正しく理解しておく必要があります。

本記事では具体的な実装のポイントに重点をおいた内容になってますので、原理的な詳しい違いは
こちらの方の記事が詳しく解説されておりますので、根本的にステートフルLSTMを理解したい方はそちらをご参照ください。

では、ステートフルLSTMモデルの使い方と評価を検討してみます。


ステートフルLSTMは訓練も予測もかなり速い

ステートフルLSTMを使う大きなメリットは、ステートレスLSTMと比較して学習回数が少なくて済むので計算の時短ができることと、バッチデータがシャッフルしないので結果に再現性があること、などが挙げられます。

適用する問題にもよるでしょうが、今回の評価した具体例からでいうとステートレスLSTMの方が計算時間が半分程度になっているようです。

ならば全部ステートフルLSTMを使って解けばいいんじゃない、というとそうではなく、モデルの扱いがとても面倒臭くなり、だたでさえステートレスLSTMでも出入力形式がややこしいのですが、ステートフルになるとデータのバッチサイズと出力層側の考慮も必要になります。

またKerasのステートフルLSTMモデルのもっとも厄介なデメリットとして、訓練済みモデル自体での予測は不可能で、その学習した重み係数だけを頂いて別のモデルで予測する、という謎テクニックを使う必要があり、tensorflowの初学者の方の躓きやすい障害になっています。

ただ、このハードルを乗り越えられたら、さぞやメリットも大きいはず...です。

とはいえ、ステートレスとステートフルの違いは内部状態をバッチ間で共有するかどうかだけなので、十分に長い系列で学習しているなら得られる結果はほぼ同じです。ステートレスで十分に学習させているならば、コードの実装を煩雑にしてまで、必ずしもステートフルを使う必要が無いかもしれません。


ステートフルLSTMを使う上での難しさ 👉 アレコレと注文が多いモデル

まずTensorflow/KerasのステートフルRNN(SimpleRNN/LSTM/GRU)を訓練させる場合には、以下の3つの決まり事を守る必要があります。

            
            + データの並びは時系列順とする(シャッフル不可)
+ エポック毎にモデルの内部状態をリセット(resetState)する
+ 入力データの長さ=バッチサイズでなければならない
        
もうこれだけでもゲンナリとなりじゃぁステートレスでいいやと感じる方も多くいるかと思われます。

また学習モデルの訓練が終わったら、その学習モデルを中身だけ取り出して、別のモデルで予想させないと使えません。まるで硬いクルミの中身だけ取り出して美味しく頂いたら、外の要らなくなった殻は捨て去るようなイメージです。


tensorflow.js的なステートフルLSTMの実装

ここからは簡単な具体例で、ステートフルLSTMを使った実装を考えてみます。

では前回のステートレスLSTMの実装部分と比較する形でステートフルLSTMのコードを紹介していきます(以下のリンクでその内容に飛びます)。

まずはステートフルLSTMモデルを適用させるためにメインの計算部分である
calcLstm関数の再設計をしてみます。

            
            import * as tf from '@tensorflow/tfjs';

async function calcLstm(
    rawData: number[][],
    testData: number[][]
) {
    //①シーケンシャルモデルを構築
    const model = tf.sequential();
    model.add(tf.layers.lstm({
        units: 32,
        returnSequences: false,
        batchInputShape: [175, 25, 1], // 👈バッチ数必須!
        stateful: true // 👈ステートフルモードに設定!
    }));
    model.add(tf.layers.dense({ units: 128, activation: "relu" }));
    model.add(tf.layers.dense({ units: 1, activation: "linear", useBias: true }));
    model.summary();

    model.compile({
        optimizer: tf.train.adam(3e-4),
        loss: tf.losses.meanSquaredError,
        metrics: ['accuracy']
    });

    //②生のデータセット配列から入力データとラベルデータをテンソル形式で返す関数
    const {inputs, labels} = await convertToTensor(rawData);

    //③モデルの訓練を開始
    console.log('Training started...');
    const hrstart = process.hrtime(); //👈計測用(Nodeの場合)

    const maxEpochs = 2000;
    let oldLoss: number = NaN;
    //👇model.fitはエポック毎に回す
    for (let epoch = 0; epoch < maxEpochs; epoch++) {
        model.resetStates();
        console.log(`Epoch#${epoch}:`);
        const history = await model.fit(inputs, labels, {
            epochs: 1, //👈バッチ数 = エポック数 = 1
            batchSize: 175, //👈時系列データはバッチデータ全てを使う
            shuffle: false, //👈シャッフルはしない
            callbacks:[
                new tf.CustomCallback({
                    onEpochEnd: async(epoch: any, logs: any) => {
                        console.log(`Loss(${logs.loss}) : acc(${logs.acc})`);
                    },
                })
            ]
        });
        const currentLoss = history.history.loss[0] as number;
        if (!Number.isNaN(oldLoss) && customStopper(currentLoss, oldLoss, 1e-8)) {
            console.log('Loss is now saturated.');
            break;
        }
        oldLoss = currentLoss;
    }

    const hrend = process.hrtime(hrstart); //👈計測終了(Nodeの場合)
    console.log(`Training has done. Elapsed Time: (${hrend[0]}s${hrend[1]}ms)`);

    //④訓練済みのモデルの重みをもった別のモデルで予想データを計算(検証用)
    //実装は後述
    const inputs4eval = await convertToTensorForValidation(testData);
    const copiedModel: tf.LayersModel = transformModel(model);
    return await testModel(copiedModel, inputs4eval, 25);
}
        
では補助メソッドの中身も見ていきます。

convertToTensorconvertToTensorForValidationは前回の記事とほぼ同じですがシャッフルしていないところだけが違います。以下に念の為に再掲しておきます。

            
            // 元データの配列の列の長さは26で固定
// [0:24] > LSTM層の入力データ, [25] > ラベル列
async function convertToTensor(data: number[][], stepNum: number = 26): any {
    return tf.tidy(() => {
        const [input2d, labelTensor] = tf.tensor2d(data, [data.length, stepNum]).split([stepNum - 1, 1], 1);
        const inputTensor: any = (input2d as tf.Tensor).reshape([data.length, stepNum - 1, 1]);

        const inputMax = inputTensor.max();
        const inputMin = inputTensor.min();
        const labelMax = labelTensor.max();
        const labelMin = labelTensor.min();

        // それぞれの最大・最小でデータ正規化(数値0から1の範囲へ変換)を行う
        // =Leru等の非負数を出力にする活性化関数を適用させる目的
        const normalizedInputs = inputTensor.sub(inputMin).div(inputMax.sub(inputMin));
        const normalizedLabels = labelTensor.sub(labelMin).div(labelMax.sub(labelMin));

        return {
            inputs: normalizedInputs,
            labels: normalizedLabels,
            inputMax,
            inputMin,
            labelMax,
            labelMin
        }
    });
}

//こちらはシャッフル無しの検証用テンソルを返す
//こちらはLSTMモデルに入力するための検証用データセット(25列使用)を利用
async function convertToTensorForValidation(
    data: number[][],
    stepNum: number = 25
) {
    return tf.tidy(() => {
        const input2d = tf.tensor2d(postData, [postData.length, stepNum]);
        const inputTensor = input2d.reshape([postData.length, stepNum, 1]);

        const inputMax = inputTensor.max();
        const inputMin = inputTensor.min();

        // それぞれの最大・最小でデータ正規化(数値0から1の範囲へ変換)を行う
        const normalizedInputs = inputTensor.sub(inputMin).div(inputMax.sub(inputMin));

        return {
            inputs: normalizedInputs,
            inputMax,
            inputMin
        }
    });
}
        

次に前回のステートレスより大幅に変わった・追加された
testModeltransformModelcustomStopperは以下のようになります。

            
            function transformModel(oldModel: tf.LayersModel): tf.LayersModel {
    // 予測専用モデルのバッチサイズは必ず1にする
    const nBatch = 1;
    // モデルの再構築
    const newModel = tf.sequential();
    newModel.add(tf.layers.lstm({
        units: 32,
        returnSequences: false,
        batchInputShape: [nBatch, 25, 1],
        stateful: true
    }));
    newModel.add(tf.layers.dense({ units: 128, activation: "relu" }));
    newModel.add(tf.layers.dense({ units: 1, activation: "linear", useBias: true }));
    newModel.summary();

    //👇訓練済モデルから各レイヤーの内部の重みを複製
    newModel.setWeights(oldModel.getWeights());

    return newModel;
}

async function testModel(
    model: any,
    normalizationData: any, // 入力データは正規化した集合[0,1]^Nであることに注意
    offset: number = 1 //x軸方向のオフセット(チャート描画用)
) : Array<{x: number, y: number}> {

    const {
        inputs, //正規化済みの集合I: [0,1]^N
        inputMax,
        inputMin
    } = normalizationData;

    //👇inputs([バッチ数,25,1])から一旦、[25,1]のテンソルのバッチ数分の配列に仕立てる
    const inputArray = tf.unstack(inputs);

    //👇予想結果を格納する配列
    const predictedPoints: any[] = [];
    for (const item of inputArray) {
        //👇stackで[25,1]のテンソルを[1,25,1]のテンソルに戻す
        const inputPerBatch = tf.stack([item]);
        const [preds] = tf.tidy(() => {
            //👇コピーされたモデルは[1,25,1]の入力データなら受け付けてくれる!
            const preds: any = model.predict(inputPerBatch);
            const unNormPreds = preds.mul(inputMax.sub(inputMin)).add(inputMin);
            return [unNormPreds.dataSync()];
        });
        const predictedPoint = Array.from(preds)[0];
        predictedPoints.push(predictedPoint);
    }

    return predictedPoints.map((val, i) => {
        return {
            x: i + offset,
            y: val as number
        };
    });
}

// 通常のmodel.fitではないのでearlyStoppingの代替として使うメソッド
function customStopper(
    currentLoss: number,
    oldLoss: number,
    delta: number = 1.0e-5
): boolean {
    return Math.abs(currentLoss - oldLoss) < delta;
}
        
前回と同じように、計算を走らせて得られたデータから曲線を表示してみると以下のようになります。

合同会社タコスキングダム|蛸壺の技術ブログ

ほぼステートレスLSTMと同じような結果が得られました。

なお、計算過程のコンソール出力の一部は以下のようになっています。

            
            #...中略
Epoch#1141:
Loss(0.00028409407241269946) : acc(0.011428571306169033)
Loss is now saturated.
Training has done. Elapsed: (121s850999999ms)
_________________________________________________________________
Layer (type)                 Output shape              Param #   
=================================================================
lstm_LSTM2 (LSTM)            [1,32]                    4352      
_________________________________________________________________
dense_Dense3 (Dense)         [1,128]                   4224      
_________________________________________________________________
dense_Dense4 (Dense)         [1,1]                     129       
=================================================================
Total params: 8705
Trainable params: 8705
Non-trainable params: 0
        
エポックが1141回分で122秒ということは、1エポックあたりだいたい107ms程度の計算時間です。

前回のモデルでの訓練にかかる計算時間も折角なので比較してみます。

ステートフルとステートレスの学習の手順が異なるので単純な比較とはいかないのですが、学習モデルは同じにしておいて、model.fit内のearlyStoppingの
patience: 20相当でお任せして、エポック400回で走らせたときの結果と比較します。

            
            _________________________________________________________________
Layer (type)                 Output shape              Param #   
=================================================================
lstm_LSTM2 (LSTM)            [1,32]                    4352      
_________________________________________________________________
dense_Dense3 (Dense)         [1,128]                   4224      
_________________________________________________________________
dense_Dense4 (Dense)         [1,1]                     129       
=================================================================
Total params: 8705
Trainable params: 8705
Non-trainable params: 0
#...中略
Epoch#399 : Loss(0.0002763311786111444) : acc(0.011428571306169033)
Training has done. Erasp Time: (213s249844999ms)
        
こちらは400回で213秒として、1エポックあたりだいたい平均533msの計算時間を要しているようです。

結果だけみると、1エポックあたりでステートフルのほうが5倍ほど処理速度が速いという結果になりました。

ステートレスLSTMの場合には、オプティマイザの最適化処理が効きやすく最適解に素早く到達する可能性もあったり、かと思えばロスが十分に小さくなったように見えても実は局所解にハマってしまった結果まだ学習が未熟だったりと、おそらくシャッフルの影響で、その場その場の解析ごとに結果が大きく違ったりするので、おおよそステートフルにすると計算時間は(体感的に)半分以下には期待できる気がします。

この辺はLSTMを適用する問題にもよるので、今回のような単純な正弦波を予測する問題に限った結果なのですが、一般にはステートフルにするほうが処理が速い、と言っても嘘ではないというのが確認できました。

さらなる予測(結果だけ)

与えられた実測点からの一点先の予想だけでは少し物足りないですので、この訓練済み(の重みコピー)モデルから、一つづつ予測点を増やしいく手法を使って、前回と同様に正弦波4波長分先まで予測してみます。

コードの実装のほうは
testModelの中身で使ったようにmodel.predictメソッドで得られた値で更に入力データを一つづつ更新していくだけですので、ここでは具体例なコードの説明は省きます。ご自身の腕試しと思って、実装してみてください。

合同会社タコスキングダム|蛸壺の技術ブログ

さて、当然ながらステートフルLSTMでもなかなか綺麗な正弦波形が予想されている気がします。


まとめ

さて今回の記事のエッセンスを振り返りますと、ステートフルLSTMを使うアドバンテージとしては、

            
            + ステートレスLSTMと比較して高速に訓練・予測ができる(処理時間短縮)
+ バッチデータがシャッフルしないので結果の再現性がある程度期待できる
        
などが挙げられます。

ただ近年では最適解を効率的に見つけ出すオプティマイザーの研究も進み、今後はかなり優れた最適化のアルゴリズムが発見され、それを用いることで、ステートフルとステートレスの差がさほど出にくくなるかもしれません。

ディープラーニングの分野は今も物凄い速さで技術が進化していますので、何事もご利益があるうちに積極的に取り入れてみられるのがよろしかろうと思います。
記事を書いた人

記事の担当:taconocat

ナンデモ系エンジニア

主にAngularでフロントエンド開発することが多いです。 開発環境はLinuxメインで進めているので、シェルコマンドも多用しております。 コツコツとプログラミングするのが好きな人間です。