【tensorflow.js x kerasの使い方】データの非線形回帰解析① ~ 燃費を予測するサンプルを理解する


2020/12/06

tensorflowのjavascript版・
tensorflow.jsとnode.jsを組み合わせて、実験データのような数値をもつデータセットを非線形回帰で解析する方法を使ってみようのコーナー第一弾です。

最終的に非線形性解析をtensorflow.jsで行わせてみるのが狙いです。

今回は先んじて線型回帰解析の手順を考えていきます。

そこでTensorflowの線形回帰のサンプルとしてよく凡例に取り上げられる
Auto MPG(Miles per Gallon)で自動車の走行距離vs.燃費の関係を解析してみましょう。


データの準備 ~ シェルコマンドからデータを整える

まず質のいい機械学習を行う前段階して、品質の良いデータを取得し、きれいに整形することがとても重要です。

慣例として、本家tensorflowはpythonで圧倒的に利用されているので、データを入手したり、tensorflowで使いやすいフォーマットに捌くのもやはりpython内で行われることが多いのが現状です。

この記事内ではlinuxシェルコマンドを利用してデータの下処理をしていきます。

データ処理をシェルコマンドで行うことのメリットは、下拵えしたデータセットを機械学習モデルで訓練させるにしても、本家tensorflow(python)でも、tensorflow.jsでも、他の機械学習用のアプリケーションでも、共通してデータを食わせることができる再生産性の高いスクリプトツールとして使うことができるからです。

参考サイト:
Create a machine learning model with Bash (海外のサイト)

なお、LinuxやMacの環境をお持ちの方は、bash相当がコマンドとして利用できますのでこのスクリプトは問答無用で走るはずです。

windows使いの方は、wslなどで仮想Linux環境を構築することもできますが、wslをインストールするのも結構面倒くさい話です。

以前の記事で、
WindowsでもLinuxシェルコマンドをいつでも簡単に使いたい!と思い立ったときの魔法のツール・busyboxの内容にてbusyboxを使えば、wslのようにwindows大工事を行わずとも簡単にLinuxシェルコマンドが使えるようになります。

busybox自体は非常に軽量で実用性も高くオススメのアプリケーションですので何かのためにwindowsにインストールしておくと良いと思います。

wget/sed/awkを使った実践的機械学習用のデータづくり

ではまず先んじて、tmpフォルダを作り、そこにAuto MPG用のデータセットをダウンロードしてみます。

            
            $ mkdir ./tmp
$ cd tmp && wget https://archive.ics.uci.edu/ml/machine-learning-databases/auto-mpg/auto-mpg.data
--2020-12-04 11:07:33-- https://archive.ics.uci.edu/ml/machine-learning-databases/auto-mpg/auto-mpg.data
Resolving archive.ics.uci.edu... 128.195.10.252
Connecting to archive.ics.uci.edu|128.195.10.252|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 30286 (30K) [application/x-httpd-php]
Saving to: 'auto-mpg.data'

auto-mpg.data                           100%[==============================================================================>]  29.58K  --.-KB/s    in 0.1s

2020-12-04 11:07:34 (213 KB/s) - 'auto-mpg.data' saved [30286/30286]
        
チュートリアルにも述べられている通り、ダウンロードした生ファイルには、いくつかの欠損値が存在しています。

例えば、以下の行のように4列目の数値が
?を不明な値として与えられているのですが、

            
            #...中略
25.0   4   98.00      ?          2046.      19.0   71  1    "ford pinto"
#...以下略
        
このままtensorflowに食わせるとエラーの元になるので、sedの正規表現を使って無用な欠損行は前処理として弾きます。

            
            $ sed -ie '/?/d' ./tmp/auto-mpg.data
        
これで?を含む行をデータから一括削除出ました。

次に空白文字切りのフォーマットではなく、csv(コンマ切り)にしたいので、
awkを使ってファイルをコンマ切りでで書き出します。

もう一度生ファイルの行のフォーマットをよくみて見ると、9列目はデータラベルとして文字列としてダブルクォーテーション(")の中に、複数の空白文字が含まれています。

            
            #...中略
16.0   8   318.0      150.0      4498.      14.5   75  1    "plymouth grand fury"
#...以下略
        
awkはデフォルトで区切り文字として空白文字(スペースとタブ)を使いますので、単にawkしてしまうと、ラベル列も切り刻まれてしまいます。

そこでまず数値列とラベル列を別々のファイルに一旦吐き出しておきます。

数値の列(1-8列目)は以下のように
a.dataとして一次保存します。

            
            $ cat ./tmp/auto-mpg.data | awk '{ print $1 "," $2 "," $3 "," $4 "," $5 "," $6 "," $7 "," $8}' > ./tmp/a.data
        
ラベル列は独立してb.dataへ一次保存します。

            
            $ cat ./tmp/auto-mpg.data | sed -n 's/.*"\(.*\)"/\1/p' > ./tmp/b.data
        
最後にa.dataとb.dataをawkで行番号を合わせて結合させたものをoutput.csvに保存させます。

            
            $ awk -F "," 'BEGIN {
    i=0;
    j=0;
}
FILENAME == ARGV[1] {
    label[i]=$0;
    i++;
}
FILENAME == ARGV[2] {
    price[j]=$0;
    j++;
}
END {
    for(i=0; i < length(label); i++) {
        print label[i] ",\"" price[i] "\"";
    }
}' ./tmp/a.data ./tmp/b.data > ./tmp/output.csv
        
ちなみにawkでは複数の読み込んだファイルを同時に捌くことが可能です。

処理の際には、ファイルを読み込んだ順に
ARGV[番号]でファイル名が引き出すことができ、現在処理中のファイルがFILENAME変数で取り出せますので、それらを使って上手く配列化しているようなスクリプトです。

仕上がったデータが以下のようにできると思います。

            
            #...中略
31.5,4,89.00,71.00,1990.,14.9,78,2,"volkswagen scirocco"
29.5,4,98.00,68.00,2135.,16.6,78,3,"honda accord lx"
21.5,6,231.0,115.0,3245.,15.4,79,1,"pontiac lemans v6"
19.8,6,200.0,85.00,2990.,18.2,79,1,"mercury zephyr 6"
#...以降略
        
余談で、上のデータ整形スクリプトはawkを使って汎用性のある正攻法のようなやり方となります。

ファイル特有の法則性を見抜くことができればもっとスマートにデータ整形が可能です。

例えば今回の生ファイルに限ったやり方ですが、

            
            $ cat ./tmp/auto-mpg.data | sed -e 's/\t/,/' -e 's/ \{2,\}/,/g' > ./tmp/output.csv
        
としたほうがもっと短い手数で同じ結果を得ます。

こういったスクリプトをワンライナーで書くシェル芸が、シェル上級者には好まれます。

高度(短い手数)になるほど不慣れな人間には難解なテクニックが使われることになるのでチーム開発ではなるべくシェル芸は控えた方がいいかもしれません。

カテゴリー化のためのワンホットエンコーディング

ここで改めて各列値が何を表しているかを整理します。

            
            #'MPG','Cylinders','Displacement','Horsepower','Weight','Acceleration','Model Year','Origin','Label'
#...中略
21.5,6,231.0,115.0,3245.,15.4,79,1,"pontiac lemans v6"
#...以下略
        
現状で9列ありまして、それぞれは

            
            MPG: 燃費(1ガロンあたり何マイル走るか)
Cylinders: エンジン気筒数
Displacement: エンジン排気量
Horsepower: 出力馬力
Weight: 車体重量
Acceleration: 加速性
Model Year: 販売開始年度
Origin:
    アメリカ=1
    ヨーロッパ=2
    日本=3
Label: モデル名
        
となっています。

さて、この内Originの列は数値ではなくカテゴリーになっているので、機械学習で取り扱うためには
ワンホットのフォーマットに直す必要があります。

そこでこのOriginの一列を、
USA, Europe, Japanの三列の形式に拡張してみましょう。

以下のようにawkを使ってやってみます。

            
            $ awk -F "," 'BEGIN { OFS="," } {
    if ($8 == 1) {
        print $1,$2,$3,$4,$5,$6,$7,"1.0",0,0,$9;
    } else if ($8 == 2) {
        print $1,$2,$3,$4,$5,$6,$7,0,"1.0",0,$9;
    } else {
        print $1,$2,$3,$4,$5,$6,$7,0,0,"1.0",$9;
    }
}' ./tmp/output.csv > ./tmp/output2.csv
        
出力したファイルを覗くと、8列目だった箇所がワンホット化されて新たに3列分にすることができました。

            
            #...中略
14.0,8,440.0,215.0,4312.,8.5,70,1.0,0,0,"plymouth fury iii"
14.0,8,455.0,225.0,4425.,10.0,70,1.0,0,0,"pontiac catalina"
15.0,8,390.0,190.0,3850.,8.5,70,1.0,0,0,"amc ambassador dpl"
#以下略
        

視覚化の準備 ~ Parcelを利用してブラウザで表示

ここからはデータの視覚化の手順を説明していきます。

Pythonからtensorflowを扱うのであれば、matplotlibからデータを視覚化することが定番となっています。

一方でjavascriptベースのtensorflow.jsで解析した結果をビジュアル化するとなると、matplotlib相当のメジャーでお手頃な定番アプリケーションはまだないかな...といった状況です。

tfjs-visというtensorflow.jsに正式に組み込まれた視覚化ライブラリはあることあるのですが、もともとjavascriptがブラウザ向けの開発言語なこともあり、グラフをもっと綺麗に簡単にプロットするチャートライブラリが他にもいくらでも存在しているため、tfjs-visをわざわざ使う必要性も感じないのが現状です。

今回は
Chart.jsというのサードパーティ製の有名どころの描画用ライブラリを使ってプロットしてみようと思います。

グラフの描画はともかくベースとなるプロジェクトをParcelでサクッとビルドして利用してみます。

前回、AlpineDocker環境でParcelプロジェクトを導入する方法を解説しましたので、今回の内容はParcelでビルドできるようになった状態から話を進めさせていただきます。

Parcelの開発環境の内容を知りたい方は下のリンク記事の方をご参照ください。

Chart.jsサンプルコードのビルド

本格的に自分専用のグラフ描画プログラムを作成するとなると割と時間と手間がかかります。

今回は練習がてらの取っ掛かりの話として、
csvを表示させるのにちょうどいい感じのサンプルを公開してくださっている方のソースコードを利用してParcelでビルドできるかで動作確認してみます。

まずは
index.htmlの中身を以下のようにします。

CDNのmin版のchart.jsは
Chart.js - Simple HTML5 charts using the canvas element.のサイトから最新版が利用できますので、その都度チェックしてみてください。

            
            <!DOCTYPE html>
    <html lang="en">
        <head>
            <meta charset="UTF-8">
            <script src="https://cdnjs.cloudflare.com/ajax/libs/Chart.js/2.9.4/Chart.min.js"></script>
            <script src="./main.js"></script>
            <title>Diagram to compare with csv data</title>
        </head>
        <body>
        <div style="width: 1024px; height: auto;">
            <canvas id="myChart"></canvas>
        </div>
    </body>
</html>
        
jsソースコードはmain.jsという名前でプロジェクトのルートに新規作成して以下の内容にしてみます。

            
            function csv2Array(str) {
    const csvData = [];
    const lines = str.split('\n');
    for (let i = 0; i < lines.length; ++i) {
        const cells = lines[i].split(',');
        csvData.push(cells);
    }
    return csvData;
}

function drawBarChart(data) {
    const tmpLabels = [], tmpData1 = [], tmpData2 = [];
    for (const row in data) {
        tmpLabels.push(data[row][0]);
        tmpData1.push(data[row][1]);
        tmpData2.push(data[row][2]);
    };

    const ctx = document.getElementById('myChart').getContext('2d');
    new Chart(ctx, {
        type: 'line',
        data: {
            labels: tmpLabels,
            datasets: [
                { label: 'Tokyo', data: tmpData1, borderColor: 'rgba(0,0,255,1)', backgroundColor: 'rgba(0,0,0,0)' },
                { label: 'Osaka', data: tmpData2, borderColor: 'rgba(255,0,0,1)', backgroundColor: 'rgba(0,0,0,0)' }
            ]
        }
    });
}

;(function () {
    const req = new XMLHttpRequest();
    const filePath = './data.csv';
    req.open('GET', filePath, true);
    req.onload = function() {
        data = csv2Array(req.responseText);
        drawBarChart(data);
    }
    req.send(null);
}());
        
この時点でアプリをビルドしておきます。

            
            $ parcel build index.html --no-source-maps --no-content-hash main.js --no-cache
✨  Built in 801ms.
dist/main.js       1.59 KB    339ms
dist/index.html      334 B    591ms
Done in 1.28s.
        
これでdistフォルダにindex.htmlmain.jsの2つが生成されました。

では、実際に読み込ませるサンプルデータを
data.csvとして、以下の内容でdistフォルダ内に保存します。

            
            January, -10.4, -5.5
Feburary, -30.3, 1
March, 3.8, 12.3
April, 5.9, 13.5
May, 9.6, 16.4
June, 12.0, 19.4
July, 16.1, 28.2
August, 20.6, 30.3
September, 17.2, 26.2
October, 15.0, 20.8
November, 5.9, 10.1
December, 0.0, 3.3
        
データファイルをdist内に配置できたら、http-serverを叩いてlocalhostからWebサーバー越しにこのSPAを立ち上げてみます。

            
            $ http-server ./dist -a 0.0.0.0 -p 8080 -c-1
Starting up http-server, serving ./dist
Available on:
  http://127.0.0.1:8080
Hit CTRL-C to stop the server
        
すると、http://localhost:8080にアクセスして以下のようにグラフが見えていたらひとまずグラフ表示テストは終了です。

合同会社タコスキングダム|蛸壺の技術ブログ

プロジェクトの下準備としてはこの程度として、以降の節で本題のtensorflow.jsでの線形回帰の話に戻りましょう。


線形回帰解析 ~ tensorflow.js + kerasでの線型回帰事始め

まず最初に線形回帰を理解してないと始まりません。そこで綺麗さっぱり忘れてしまった方(著者も含めて)のために線型回帰を復習してみます。

線形回帰?

y=ax+by = ax + bというの一次直線、y=ax2+bx+cy = ax^2 + bx + cを二次曲線、あるいはもっと次数を増やした曲線など...を中学校頃?の数学で習ったとおもいます。

そんなn次曲線の関数を利用した解析をカッコよくいうと
線形回帰と呼んでいます。

中学生から習う内容とはいえ、そこには奥深い数学的な背景が潜んでいますので、これから機械学習を勉強する方は、この内容をしっかり理解しておく必要があります。

tensorflow.jsとkerasの使い方の入門というですので、とりあえずこの
線形回帰を機械学習で最適解を導いてみましょう。

測定データ

では上の節でデータ整形したファイルを利用して線形回帰させるデータセットを抽出して、解析させてみます。

最初に簡単な例で説明したいので、x軸に車の重量(5列目)、y軸に燃費(1列目)で線形回帰させて遊んでみることにしましょう。

以下のようにsortコマンドを使って5列目を昇順にソートして、結果を保存します。

            
            $ cat ./tmp/output2.csv | sort -k 5 -t "," > ./tmp/output3.csv
        
今回は線型回帰の解析例として、このデータの一部を使って、重量vs.燃費のグラフをプロットしてみましょう。

ちなみに燃費はデータの1列目、車体重量は5列目です、車体のモデル名は11列目あります。

先程のサンプルとして使ったjsソースコード(main.js)では、直線チャートで描画していましたが、データを見せるのには散布図のほうが適切です。

そこで散布図表示させるために以下のように先程のコードを改造してみます。

            
            function csv2Array(str) {
    const csvData = [];
    const lines = str.split('\n');
    for (let i = 0; i < lines.length; ++i) {
        console.log(lines[i]);
        if (lines[i] == '') {
            continue;
        }
        const cells = lines[i].split(',');
        csvData.push(cells);
    }
    return csvData;
}

function drawBarChart(data) {
    const tmpLabels = [], tmpData = [];
    for (const row in data) {
        tmpLabels.push(data[row][10]);
        tmpData.push({x: data[row][4], y: data[row][0]});
    };
    const ctx = document.getElementById('myChart').getContext('2d');
    const options = {
        responsive: true,
        tooltips: {
            backgroundColor: "rgba(19, 56, 95, 0.9)",
            titleFontSize: 16,
            bodyFontSize: 16,
            xPadding: 12,
            yPadding: 10,
            callbacks: {
                label: (tooltipItem, data) => {
                    const groupName = data.labels[tooltipItem.index];
                    const xAxesLabel = options.scales.xAxes[0].scaleLabel.labelString;
                    const yAxesLabel = options.scales.yAxes[0].scaleLabel.labelString;
                    return `${groupName} | ${xAxesLabel}: ${tooltipItem.label} | ${yAxesLabel}: ${tooltipItem.value}`;
                }
            }
        },
        scales: {
            xAxes: [{
                scaleLabel: {
                    display: true,
                    labelString: 'Weight',
                },
                ticks: { min: 0 },
            }],
            yAxes: [{
                scaleLabel: {
                    display: true,
                    labelString: 'MPG',
                },
                ticks: { min: 0 }
            }]
        },
    };
    new Chart(ctx, {
        type: 'scatter',
        data: {
            labels: tmpLabels,
            datasets: [
                {
                    label: 'MPG',
                    data: tmpData,
                    borderColor: 'rgba(0,0,0,1)',
                    backgroundColor: 'rgba(0,0,255,1)',
                    pointRadius: 3,
                    pointHoverRadius: 9
                }
            ]
        },
        options
    });
}

;(function () {
    const req = new XMLHttpRequest();
    const filePath = './output3.csv';
    req.open('GET', filePath, true);
    req.onload = function() {
        data = csv2Array(req.responseText);
        drawBarChart(data);
    }
    req.send(null);
}());
        
色々とパワーアップしていますが変更部分の細かい説明は省きます。

Chart.jsの設定値の詳細は
APIリファレンスなどでご確認ください。

これをビルドしてグラフを描画してみると以下のように散布図を得ることができます。

合同会社タコスキングダム|蛸壺の技術ブログ

なおシェルコマンドでデータを捌くと、そのデータファイルのフォルダ移動までParcelは面倒を見てくれません。

ここでの整形済みデータファイル
output3.csvは手動でdistフォルダに移動してください。

手動で移動が嫌な方は、例えばビルド完了後にデータファイルを自動で移動できるように、

            
            $ cp -f ./tmp/output3.csv ./dist/
        
を叩くか、package.jsonのスクリプトに組み込んでしまうかなどを適宜検討してください。

tensorflow.jsを使う準備

tensorflow.jsがParcel環境で正常にビルド出来るのを確認する必要があります。

tensorflow.jsのインストールは簡単です。

            
            $ yarn add @tensorflow/tfjs -S
        
これでブラウザでtensorflowが使えるようになります。

またtensorflow.jsは機能上、
async/await構文の利用が多用されるのですが、この構文をそのまま使うとParcelでビルド中にregeneratorRuntime is not definedが発生します。そんな時はpackage.jsonに以下のフィールドを追加します。

            
            {
//...中略
    "browserslist": [
        "since 2017-06"
    ]
}
        
参考サイト: Parcelでasync/awaitを使うと「regeneratorRuntime is not defined」エラーが出る場合の対処法

tensorflow.js&Kerasモデルによる機械学習

ここからいよいよtensorflow.jsを使って線型回帰解析の中身の実装をおこなっていきます。

なお、tensorflow.jsの公式チュートリアル
TensorFlow.js — Making Predictions from 2D Dataの内容をアレンジしたものです。

まずはtensorflowの処理を行うコード部分を
main.jsとは別ファイルでlinear_regression.jsという名前で新規作成してみます。このlinear_regression.jsは以下の内容で編集しておきます。

            
            import * as tf from '@tensorflow/tfjs';

export async function linearRegression(rawData) {
    const model = tf.sequential();

    //👇入力層
    model.add(tf.layers.dense({inputShape: [1], units: 1, useBias: true}));

    //👇中間層(追加コードの内容は後述)

    //👇出力層
    model.add(tf.layers.dense({ units: 1, useBias: true, activation: 'linear'}));

    console.log('Training started.');

    const originalData = [], labels = [];
    for (const row in rawData) {
        originalData.push({x: parseFloat(rawData[row][4]), y: parseFloat(rawData[row][0])});
        labels.push(rawData[row][10]);
    };

    // テンソルに変換・データの正規化
    const tensorData = convertToTensor(originalData);

    // モデルをトレーニング
    await trainModel(model, tensorData.inputs, tensorData.labels);

    // モデルのテスト
    const predictedData = testModel(model, tensorData);

    console.log('Fitting has done.');

    return { labels, originalData, predictedData };
}

// モデルの学習を行う関数
async function trainModel(model, inputs, labels) {
    // モデルをコンパイル=学習方法を指定
    model.compile({
        optimizer: tf.train.adam(),
        loss: tf.losses.meanSquaredError,
        metrics: ['mse'],
    });

    // バッチサイズ
    const batchSize = 32;
    // エポック数
    const epochs = 50;
    // エポック回数の学習を実行する
    return await model.fit(inputs, labels, {
        batchSize,
        epochs,
        shuffle: true,
        callbacks: {
            onEpochEnd: async(epoch, logs) => {
                // 繰り返し回数と損失をコンソール出力
                console.log(`Epoch#${epoch} : Loss(${logs.loss}) : mse(${logs.mse})`);
            }
        }
    });
}

// 学習済みモデルからフィッティング曲線を生成
function testModel(model, normalizationData) {
    const {inputMax, inputMin, labelMin, labelMax} = normalizationData;
    const [xs, preds] = tf.tidy(() => {
        // tf.linespaceで、0から1までの間で等間隔刻みに100個の値を生成
        const xs = tf.linspace(0, 1, 100);
        const preds = model.predict(xs.reshape([100, 1]));

        // モデルの入出力値は正規化されていたので、これを元のスケールに復元する
        const unNormXs = xs.mul(inputMax.sub(inputMin)).add(inputMin);
        const unNormPreds = preds.mul(labelMax.sub(labelMin)).add(labelMin);

        // Tensor型から配列型に変換
        return [unNormXs.dataSync(), unNormPreds.dataSync()];
    });

    // Chart.jsで描画するデータ形式に配列を整える
    const predictedPoints = Array.from(xs).map((val, i) => {
        return {
            x: val,
            y: preds[i]
        }
    });

    return predictedPoints;
}

// 学習データをTensor型に変換する関数
function convertToTensor(data) {
    return tf.tidy(() => {
        // 学習データをシャッフル
        tf.util.shuffle(data);

        // xyデータ配列をNx1テンソルに変換
        const inputs = data.map(d => d.x)
        const labels = data.map(d => d.y);
        const inputTensor = tf.tensor2d(inputs, [inputs.length, 1]);
        const labelTensor = tf.tensor2d(labels, [labels.length, 1]);

        // 入力データの正規化
        const inputMax = inputTensor.max();
        const inputMin = inputTensor.min();
        const labelMax = labelTensor.max();
        const labelMin = labelTensor.min();
        const normalizedInputs = inputTensor.sub(inputMin).div(inputMax.sub(inputMin));
        const normalizedLabels = labelTensor.sub(labelMin).div(labelMax.sub(labelMin));

        return {
            inputs: normalizedInputs,
            labels: normalizedLabels,
            inputMax,
            inputMin,
            labelMax,
            labelMin
        }
    });
}
        
それではmain.js側でlinear_regression.jsの関数を使って線形回帰を行えるように以下のように修正してみます。

            
            const lr = require('./linear_regression');

function csv2Array(str) {
    const csvData = [];
    const lines = str.split('\n');
    for (let i = 0; i < lines.length; ++i) {
        if (lines[i] == '') { continue; }
        const cells = lines[i].split(',');
        csvData.push(cells);
    }
    return csvData;
}

function drawPredictedCurve(labelData_, originalData_, predictedData_) {
    const ctx = document.getElementById('myChart').getContext('2d');
    const options = {
        responsive: true,
        tooltips: {
            backgroundColor: "rgba(19, 56, 95, 0.9)",
            titleFontSize: 16,
            bodyFontSize: 16,
            xPadding: 12,
            yPadding: 10
        },
        scales: {
            xAxes: [{
                scaleLabel: {
                    display: true,
                    labelString: 'Weight',
                },
                ticks: {
                    min: 0,
                    max: 7000,
                    stepSize: 1000
                },
            }],
            yAxes: [
                {
                    id: "y-axis-p",
                    type: "linear",
                    position: "right",
                    ticks: {
                        max: 50,
                        min: 0,
                        stepSize: 10
                    },
                },
                {
                    id: "y-axis-o",
                    type: "linear",
                    position: "left",
                    ticks: {
                        max: 50,
                        min: 0,
                        stepSize: 10
                    },
                }
            ]
        }
    };
    new Chart(ctx, {
        type: 'scatter',
        data: {
            labels: labelData_,
            datasets: [
                {
                    label: 'Prediction',
                    type: 'line',
                    data: predictedData_,
                    borderColor: 'rgba(0,0,0,1)',
                    backgroundColor: 'rgba(0,0,0,0)',
                    pointRadius: 0,
                    yAxisID: "y-axis-p",
                },
                {
                    label: 'MPG',
                    type: 'scatter',
                    data: originalData_,
                    borderColor: 'rgba(0,0,0,1)',
                    backgroundColor: 'rgba(0,0,255,1)',
                    pointRadius: 3,
                    pointHoverRadius: 9,
                    yAxisID: "y-axis-o",
                }
            ]
        },
        options
    });
}

;(async () => {
    const req = new XMLHttpRequest();
    const filePath = './output3.csv';
    req.open('GET', filePath, true);
    req.onload = function() {
        data = csv2Array(req.responseText);
        lr.testInnerTensorflow(data).then((res) => {
            drawPredictedCurve(res.tmpLabels, res.originalData, res.predictedData);
        });
    }
    req.send(null);
})();
        
これでもっとも単純な線形回帰、いわゆる一次直線によるフィッティングを行うプログラムになっています。

解析結果

早速ビルドしてプログラムを実行しブラウザで開くと、データによる機械学習モデルの訓練と予想計算がコンソールにエポック毎に表示され、終わったらcanvas要素に描画されます。

合同会社タコスキングダム|蛸壺の技術ブログ

流石にもっとも単純なモデルだけあって、誤差も収束してなさそうですし、ブラウザをリロードするたびに結果がその都度変わります。

モデルに入力層と出力層の2層しかなく、この場合ほどんどkerasモデルが学習していないため結果として当然と言えば当然です。

そこで深層学習のキモの技術と言える中間層を1枚モデルに挿入してみたらどうなるか考えてみます。

先程のコードで中間層の追加位置として空けていたところに以下のコードを挿入します。

            
            //...中略
    //👇中間層(追加コードの内容は後述)
    model.add(tf.layers.dense({
        units: 16,
        useBias: true
    }));
//...以下略
        
隠れノード16ユニットの中間層を一層噛ませて、先程のフィッティングがどうなるかというと、

合同会社タコスキングダム|蛸壺の技術ブログ

ズバッと散布点の真ん中を貫いているように精度がかなり向上しています。

一般的に、中間層の隠れノード数を増やしたり、中間層の数自体を増やしたり、活性化関数を様々に考慮することでフィッティング精度が向上します。

今回の解説内容では、精度の比較までは深堀しません。

各自色々とパラメーターを弄ってみてください。


まとめ

今回の例では実用性はないのですが、tensorflow.jsを使った線形回帰の基本的なプログラミングの流れは掴んでいただけたのではないかと思います。

また改めて別記事で、今回の例に中間層を用いてもう少し解析の精度を高める方法を解析しようと思います。

参考にしたサイト

回帰:燃費を予測する

TensorFlow.jsでDeepLearning(Making Predictions from 2D Data)