カテゴリー
【tensorflowjs x kerasの使い方】LSTMをステートフルにすると何がいいの?を検証する
※ 当ページには【広告/PR】を含む場合があります。
2020/12/27
2022/08/18
ステートフルLSTMは訓練も予測もかなり速い
ステートフルLSTMを使う上での難しさ 👉 アレコレと注文が多いモデル
+ データの並びは時系列順とする(シャッフル不可)
+ エポック毎にモデルの内部状態をリセット(resetState)する
+ 入力データの長さ=バッチサイズでなければならない
TensorFlowを動かしながら学ぶ TensorFlowとKerasで動かしながら学ぶ ディープラーニングの仕組み 畳み込みニューラルネットワーク徹底解説
tensorflowjsでステートフルLSTMの実装
calcLstm
import * as tf from '@tensorflow/tfjs';
async function calcLstm(
rawData: number[][],
testData: number[][]
) {
//①シーケンシャルモデルを構築
const model = tf.sequential();
model.add(tf.layers.lstm({
units: 32,
returnSequences: false,
batchInputShape: [175, 25, 1], // 👈バッチ数必須!
stateful: true // 👈ステートフルモードに設定!
}));
model.add(tf.layers.dense({ units: 128, activation: "relu" }));
model.add(tf.layers.dense({ units: 1, activation: "linear", useBias: true }));
model.summary();
model.compile({
optimizer: tf.train.adam(3e-4),
loss: tf.losses.meanSquaredError,
metrics: ['accuracy']
});
//②生のデータセット配列から入力データとラベルデータをテンソル形式で返す関数
const {inputs, labels} = await convertToTensor(rawData);
//③モデルの訓練を開始
console.log('Training started...');
const hrstart = process.hrtime(); //👈計測用(Nodeの場合)
const maxEpochs = 2000;
let oldLoss: number = NaN;
//👇model.fitはエポック毎に回す
for (let epoch = 0; epoch < maxEpochs; epoch++) {
model.resetStates();
console.log(`Epoch#${epoch}:`);
const history = await model.fit(inputs, labels, {
epochs: 1, //👈バッチ数 = エポック数 = 1
batchSize: 175, //👈時系列データはバッチデータ全てを使う
shuffle: false, //👈シャッフルはしない
callbacks:[
new tf.CustomCallback({
onEpochEnd: async(epoch: any, logs: any) => {
console.log(`Loss(${logs.loss}) : acc(${logs.acc})`);
},
})
]
});
const currentLoss = history.history.loss[0] as number;
if (!Number.isNaN(oldLoss) && customStopper(currentLoss, oldLoss, 1e-8)) {
console.log('Loss is now saturated.');
break;
}
oldLoss = currentLoss;
}
const hrend = process.hrtime(hrstart); //👈計測終了(Nodeの場合)
console.log(`Training has done. Elapsed Time: (${hrend[0]}s${hrend[1]}ms)`);
//④訓練済みのモデルの重みをもった別のモデルで予想データを計算(検証用)
//実装は後述
const inputs4eval = await convertToTensorForValidation(testData);
const copiedModel: tf.LayersModel = transformModel(model);
return await testModel(copiedModel, inputs4eval, 25);
}
convertToTensor
convertToTensorForValidation
// 元データの配列の列の長さは26で固定
// [0:24] > LSTM層の入力データ, [25] > ラベル列
async function convertToTensor(data: number[][], stepNum: number = 26): any {
return tf.tidy(() => {
const [input2d, labelTensor] = tf.tensor2d(data, [data.length, stepNum]).split([stepNum - 1, 1], 1);
const inputTensor: any = (input2d as tf.Tensor).reshape([data.length, stepNum - 1, 1]);
const inputMax = inputTensor.max();
const inputMin = inputTensor.min();
const labelMax = labelTensor.max();
const labelMin = labelTensor.min();
// それぞれの最大・最小でデータ正規化(数値0から1の範囲へ変換)を行う
// =Leru等の非負数を出力にする活性化関数を適用させる目的
const normalizedInputs = inputTensor.sub(inputMin).div(inputMax.sub(inputMin));
const normalizedLabels = labelTensor.sub(labelMin).div(labelMax.sub(labelMin));
return {
inputs: normalizedInputs,
labels: normalizedLabels,
inputMax,
inputMin,
labelMax,
labelMin
}
});
}
//こちらはシャッフル無しの検証用テンソルを返す
//こちらはLSTMモデルに入力するための検証用データセット(25列使用)を利用
async function convertToTensorForValidation(
data: number[][],
stepNum: number = 25
) {
return tf.tidy(() => {
const input2d = tf.tensor2d(postData, [postData.length, stepNum]);
const inputTensor = input2d.reshape([postData.length, stepNum, 1]);
const inputMax = inputTensor.max();
const inputMin = inputTensor.min();
// それぞれの最大・最小でデータ正規化(数値0から1の範囲へ変換)を行う
const normalizedInputs = inputTensor.sub(inputMin).div(inputMax.sub(inputMin));
return {
inputs: normalizedInputs,
inputMax,
inputMin
}
});
}
testModel
transformModel
customStopper
function transformModel(oldModel: tf.LayersModel): tf.LayersModel {
// 予測専用モデルのバッチサイズは必ず1にする
const nBatch = 1;
// モデルの再構築
const newModel = tf.sequential();
newModel.add(tf.layers.lstm({
units: 32,
returnSequences: false,
batchInputShape: [nBatch, 25, 1],
stateful: true
}));
newModel.add(tf.layers.dense({ units: 128, activation: "relu" }));
newModel.add(tf.layers.dense({ units: 1, activation: "linear", useBias: true }));
newModel.summary();
//👇訓練済モデルから各レイヤーの内部の重みを複製
newModel.setWeights(oldModel.getWeights());
return newModel;
}
async function testModel(
model: any,
normalizationData: any, // 入力データは正規化した集合[0,1]^Nであることに注意
offset: number = 1 //x軸方向のオフセット(チャート描画用)
) : Array<{x: number, y: number}> {
const {
inputs, //正規化済みの集合I: [0,1]^N
inputMax,
inputMin
} = normalizationData;
//👇inputs([バッチ数,25,1])から一旦、[25,1]のテンソルのバッチ数分の配列に仕立てる
const inputArray = tf.unstack(inputs);
//👇予想結果を格納する配列
const predictedPoints: any[] = [];
for (const item of inputArray) {
//👇stackで[25,1]のテンソルを[1,25,1]のテンソルに戻す
const inputPerBatch = tf.stack([item]);
const [preds] = tf.tidy(() => {
//👇コピーされたモデルは[1,25,1]の入力データなら受け付けてくれる!
const preds: any = model.predict(inputPerBatch);
const unNormPreds = preds.mul(inputMax.sub(inputMin)).add(inputMin);
return [unNormPreds.dataSync()];
});
const predictedPoint = Array.from(preds)[0];
predictedPoints.push(predictedPoint);
}
return predictedPoints.map((val, i) => {
return {
x: i + offset,
y: val as number
};
});
}
// 通常のmodel.fitではないのでearlyStoppingの代替として使うメソッド
function customStopper(
currentLoss: number,
oldLoss: number,
delta: number = 1.0e-5
): boolean {
return Math.abs(currentLoss - oldLoss) < delta;
}
#...中略
Epoch#1141:
Loss(0.00028409407241269946) : acc(0.011428571306169033)
Loss is now saturated.
Training has done. Elapsed: (121s850999999ms)
_________________________________________________________________
Layer (type) Output shape Param #
=================================================================
lstm_LSTM2 (LSTM) [1,32] 4352
_________________________________________________________________
dense_Dense3 (Dense) [1,128] 4224
_________________________________________________________________
dense_Dense4 (Dense) [1,1] 129
=================================================================
Total params: 8705
Trainable params: 8705
Non-trainable params: 0
patience: 20
_________________________________________________________________
Layer (type) Output shape Param #
=================================================================
lstm_LSTM2 (LSTM) [1,32] 4352
_________________________________________________________________
dense_Dense3 (Dense) [1,128] 4224
_________________________________________________________________
dense_Dense4 (Dense) [1,1] 129
=================================================================
Total params: 8705
Trainable params: 8705
Non-trainable params: 0
#...中略
Epoch#399 : Loss(0.0002763311786111444) : acc(0.011428571306169033)
Training has done. Erasp Time: (213s249844999ms)
さらなる予測(結果だけ)
testModel
model.predict
TensorFlowを動かしながら学ぶ TensorFlowとKerasで動かしながら学ぶ ディープラーニングの仕組み 畳み込みニューラルネットワーク徹底解説
まとめ
+ ステートレスLSTMと比較して高速に訓練・予測ができる(処理時間短縮)
+ バッチデータがシャッフルしないので結果の再現性がある程度期待できる
記事を書いた人
ナンデモ系エンジニア
主にAngularでフロントエンド開発することが多いです。 開発環境はLinuxメインで進めているので、シェルコマンドも多用しております。 コツコツとプログラミングするのが好きな人間です。
カテゴリー