カテゴリー
【tensorflowjs & kerasの使い方】燃費を予測するサンプルを理解する〜非線形回帰解析①
※ 当ページには【広告/PR】を含む場合があります。
2020/12/06
2022/08/18
データの準備 ~ シェルコマンドからデータを整える
wget/sed/awkを使った実践的機械学習用のデータづくり
tmp
Auto MPG
$ mkdir ./tmp
$ cd tmp && wget https://archive.ics.uci.edu/ml/machine-learning-databases/auto-mpg/auto-mpg.data
--2020-12-04 11:07:33-- https://archive.ics.uci.edu/ml/machine-learning-databases/auto-mpg/auto-mpg.data
Resolving archive.ics.uci.edu... 128.195.10.252
Connecting to archive.ics.uci.edu|128.195.10.252|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 30286 (30K) [application/x-httpd-php]
Saving to: 'auto-mpg.data'
auto-mpg.data 100%[==============================================================================>] 29.58K --.-KB/s in 0.1s
2020-12-04 11:07:34 (213 KB/s) - 'auto-mpg.data' saved [30286/30286]
?
#...中略
25.0 4 98.00 ? 2046. 19.0 71 1 "ford pinto"
#...以下略
# Keep -i and -e separate: GNU sed parses "-ie" as "-i" with backup suffix "e" and leaves a stray auto-mpg.datae file behind.
$ sed -i -e '/?/d' ./tmp/auto-mpg.data
?
awk
#...中略
16.0 8 318.0 150.0 4498. 14.5 75 1 "plymouth grand fury"
#...以下略
a.data
# Write the first eight whitespace-separated columns of every line as one comma-separated record.
$ awk 'BEGIN { OFS="," } { print $1, $2, $3, $4, $5, $6, $7, $8 }' ./tmp/auto-mpg.data > ./tmp/a.data
b.data
# Keep the text captured between the double quotes; lines without a quoted part are not printed (-n with the p flag).
$ sed -n 's/.*"\(.*\)"/\1/p' ./tmp/auto-mpg.data > ./tmp/b.data
output.csv
# Pair line N of the first file with line N of the second file; the text from the second file is wrapped in double quotes.
$ awk -F "," '
FILENAME == ARGV[1] { values[nValues++] = $0 }
FILENAME == ARGV[2] { names[nNames++] = $0 }
END {
  for (k = 0; k < nValues; k++) {
    print values[k] ",\"" names[k] "\"";
  }
}' ./tmp/a.data ./tmp/b.data > ./tmp/output.csv
ARGV[番号]
FILENAME
#...中略
31.5,4,89.00,71.00,1990.,14.9,78,2,"volkswagen scirocco"
29.5,4,98.00,68.00,2135.,16.6,78,3,"honda accord lx"
21.5,6,231.0,115.0,3245.,15.4,79,1,"pontiac lemans v6"
19.8,6,200.0,85.00,2990.,18.2,79,1,"mercury zephyr 6"
#...以降略
# Single sed pass: replace the first tab on each line, and every run of two or more spaces, with a comma.
$ sed -e 's/\t/,/' -e 's/ \{2,\}/,/g' ./tmp/auto-mpg.data > ./tmp/output.csv
カテゴリー化のためのワンホットエンコーディング
#'MPG','Cylinders','Displacement','Horsepower','Weight','Acceleration','Model Year','Origin','Label'
#...中略
21.5,6,231.0,115.0,3245.,15.4,79,1,"pontiac lemans v6"
#...以下略
MPG: 燃費(1ガロンあたり何マイル走るか)
Cylinders: エンジン気筒数
Displacement: エンジン排気量
Horsepower: 出力馬力
Weight: 車体重量
Acceleration: 加速性
Model Year: 販売開始年度
Origin:
アメリカ=1
ヨーロッパ=2
日本=3
Label: モデル名
USA, Europe, Japan
# One-hot encode column 8 into three columns: 1 -> "1.0",0,0 / 2 -> 0,"1.0",0 / any other value -> 0,0,"1.0".
$ awk -F "," 'BEGIN { OFS="," } {
  usa    = ($8 == 1) ? "1.0" : 0;
  europe = ($8 == 2) ? "1.0" : 0;
  japan  = ($8 != 1 && $8 != 2) ? "1.0" : 0;
  print $1, $2, $3, $4, $5, $6, $7, usa, europe, japan, $9;
}' ./tmp/output.csv > ./tmp/output2.csv
#...中略
14.0,8,440.0,215.0,4312.,8.5,70,1.0,0,0,"plymouth fury iii"
14.0,8,455.0,225.0,4425.,10.0,70,1.0,0,0,"pontiac catalina"
15.0,8,390.0,190.0,3850.,8.5,70,1.0,0,0,"amc ambassador dpl"
#以下略
視覚化の準備 ~ Parcelを利用してブラウザで表示
Chart.jsサンプルコードのビルド
index.html
<!DOCTYPE html>
<!-- Host page for the chart demo: loads Chart.js and main.js, and provides the canvas that main.js draws on. -->
<html lang="en">
<head>
<meta charset="UTF-8">
<!-- Chart.js 2.9.4 from the cdnjs CDN; main.js calls `new Chart(...)` without importing it. -->
<script src="https://cdnjs.cloudflare.com/ajax/libs/Chart.js/2.9.4/Chart.min.js"></script>
<!-- Application script: fetches the CSV file and renders the chart. -->
<script src="./main.js"></script>
<title>Diagram to compare with csv data</title>
</head>
<body>
<!-- Fixed-width wrapper around the canvas. -->
<div style="width: 1024px; height: auto;">
<!-- main.js looks this canvas up by its id, "myChart". -->
<canvas id="myChart"></canvas>
</div>
</body>
</html>
main.js
/**
 * Parse CSV text into an array of rows, each row being an array of cell strings.
 *
 * Lines are split on LF or CRLF and cells on commas; quoted cells that contain
 * commas are not supported. Empty lines (for example the one produced by a
 * trailing newline) are skipped, so callers never receive a `['']` row.
 *
 * @param {string} str - Raw CSV text.
 * @returns {string[][]} One array of cell strings per non-empty line.
 */
function csv2Array(str) {
  const csvData = [];
  for (const line of str.split(/\r?\n/)) {
    if (line === '') {
      continue;
    }
    csvData.push(line.split(','));
  }
  return csvData;
}
/**
 * Draw the parsed CSV rows as a two-series line chart on the `myChart` canvas.
 *
 * Despite its name, the function creates a Chart.js chart of type `line`.
 * Column 0 of each row becomes the x-axis label, column 1 the 'Tokyo' series
 * and column 2 the 'Osaka' series.
 *
 * @param {string[][]} data - Rows as returned by csv2Array.
 * @returns {void}
 */
function drawBarChart(data) {
  const tmpLabels = [];
  const tmpData1 = [];
  const tmpData2 = [];
  // for...of visits only the row elements; for...in would also visit any
  // other enumerable key that happens to exist on the array.
  for (const row of data) {
    tmpLabels.push(row[0]);
    tmpData1.push(row[1]);
    tmpData2.push(row[2]);
  }
  const ctx = document.getElementById('myChart').getContext('2d');
  new Chart(ctx, {
    type: 'line',
    data: {
      labels: tmpLabels,
      datasets: [
        { label: 'Tokyo', data: tmpData1, borderColor: 'rgba(0,0,255,1)', backgroundColor: 'rgba(0,0,0,0)' },
        { label: 'Osaka', data: tmpData2, borderColor: 'rgba(255,0,0,1)', backgroundColor: 'rgba(0,0,0,0)' },
      ],
    },
  });
}
// Entry point: request ./data.csv asynchronously and draw the chart once the response has loaded.
;(function () {
  const req = new XMLHttpRequest();
  const filePath = './data.csv';
  req.open('GET', filePath, true);
  req.onload = function () {
    // Declared locally: a bare `data = ...` assignment creates an implicit
    // global and throws a ReferenceError in strict mode / ES modules.
    const data = csv2Array(req.responseText);
    drawBarChart(data);
  };
  req.send(null);
}());
$ parcel build index.html --no-source-maps --no-content-hash main.js --no-cache
✨ Built in 801ms.
dist/main.js 1.59 KB 339ms
dist/index.html 334 B 591ms
Done in 1.28s.
dist
index.html
main.js
data.csv
dist
January, -10.4, -5.5
February, -30.3, 1
March, 3.8, 12.3
April, 5.9, 13.5
May, 9.6, 16.4
June, 12.0, 19.4
July, 16.1, 28.2
August, 20.6, 30.3
September, 17.2, 26.2
October, 15.0, 20.8
November, 5.9, 10.1
December, 0.0, 3.3
http-server
$ http-server ./dist -a 0.0.0.0 -p 8080 -c-1
Starting up http-server, serving ./dist
Available on:
http://127.0.0.1:8080
Hit CTRL-C to stop the server
http://localhost:8080
線形回帰解析 ~ tensorflowjs + kerasでの線形回帰事始め
線形回帰
線形回帰?
線形回帰
測定データ
# Sort numerically on column 5 only: a bare "-k 5" compares text from field 5 through the end of the line.
$ sort -t "," -k 5,5n ./tmp/output2.csv > ./tmp/output3.csv
重量vs.燃費
/**
 * Parse CSV text into an array of rows, each row being an array of cell strings.
 *
 * Lines are split on LF or CRLF and cells on commas; quoted cells that contain
 * commas are not supported. Empty lines are skipped. Nothing is logged: the
 * per-line debug output was removed because it printed every line of the file
 * to the console.
 *
 * @param {string} str - Raw CSV text.
 * @returns {string[][]} One array of cell strings per non-empty line.
 */
function csv2Array(str) {
  const csvData = [];
  for (const line of str.split(/\r?\n/)) {
    if (line === '') {
      continue;
    }
    csvData.push(line.split(','));
  }
  return csvData;
}
/**
 * Draw the parsed CSV rows as a scatter chart on the `myChart` canvas.
 *
 * Despite its name, the function creates a Chart.js chart of type `scatter`.
 * For each row, column 4 is used as x, column 0 as y and column 10 as the
 * point label shown in the tooltip. The axes are titled 'Weight' and 'MPG'.
 *
 * @param {string[][]} data - Rows as returned by csv2Array.
 * @returns {void}
 */
function drawBarChart(data) {
  const tmpLabels = [];
  const tmpData = [];
  // for...of visits only the row elements; for...in would also visit any
  // other enumerable key that happens to exist on the array.
  for (const row of data) {
    tmpLabels.push(row[10]);
    tmpData.push({ x: row[4], y: row[0] });
  }
  const ctx = document.getElementById('myChart').getContext('2d');
  const options = {
    responsive: true,
    tooltips: {
      backgroundColor: 'rgba(19, 56, 95, 0.9)',
      titleFontSize: 16,
      bodyFontSize: 16,
      xPadding: 12,
      yPadding: 10,
      callbacks: {
        // Tooltip text: "<label> | <x axis title>: <x> | <y axis title>: <y>".
        // The second parameter is the chart's data object; it is named
        // chartData so it does not shadow the function's `data` rows.
        label: (tooltipItem, chartData) => {
          const groupName = chartData.labels[tooltipItem.index];
          const xAxesLabel = options.scales.xAxes[0].scaleLabel.labelString;
          const yAxesLabel = options.scales.yAxes[0].scaleLabel.labelString;
          return `${groupName} | ${xAxesLabel}: ${tooltipItem.label} | ${yAxesLabel}: ${tooltipItem.value}`;
        },
      },
    },
    scales: {
      xAxes: [{
        scaleLabel: {
          display: true,
          labelString: 'Weight',
        },
        ticks: { min: 0 },
      }],
      yAxes: [{
        scaleLabel: {
          display: true,
          labelString: 'MPG',
        },
        ticks: { min: 0 },
      }],
    },
  };
  new Chart(ctx, {
    type: 'scatter',
    data: {
      labels: tmpLabels,
      datasets: [
        {
          label: 'MPG',
          data: tmpData,
          borderColor: 'rgba(0,0,0,1)',
          backgroundColor: 'rgba(0,0,255,1)',
          pointRadius: 3,
          pointHoverRadius: 9,
        },
      ],
    },
    options,
  });
}
// Entry point: request ./output3.csv asynchronously and draw the chart once the response has loaded.
;(function () {
  const req = new XMLHttpRequest();
  const filePath = './output3.csv';
  req.open('GET', filePath, true);
  req.onload = function () {
    // Declared locally: a bare `data = ...` assignment creates an implicit
    // global and throws a ReferenceError in strict mode / ES modules.
    const data = csv2Array(req.responseText);
    drawBarChart(data);
  };
  req.send(null);
}());
output3.csv
dist
$ cp -f ./tmp/output3.csv ./dist/
tensorflowjsを使う準備
$ yarn add @tensorflow/tfjs -S
async/await
regeneratorRuntime is not defined
{
//...中略
"browserslist": [
"since 2017-06"
]
}
tensorflowjs&Kerasモデルによる機械学習
main.js
linear_regression.js
import * as tf from '@tensorflow/tfjs';
/**
 * Build a small sequential dense model, fit it to (x, y) pairs taken from
 * `rawData`, and return everything needed to plot the data and the fit.
 *
 * For each row, column 4 is parsed as x, column 0 as y, and column 10 is kept
 * as the row's label.
 *
 * @param {string[][]} rawData - CSV rows (arrays of cell strings).
 * @returns {Promise<{labels: string[], originalData: {x: number, y: number}[], predictedData: {x: number, y: number}[]}>}
 *   The row labels, the parsed points, and the points of the fitted curve
 *   produced by testModel.
 */
export async function linearRegression(rawData) {
  const model = tf.sequential();
  // Input layer
  model.add(tf.layers.dense({ inputShape: [1], units: 1, useBias: true }));
  // Hidden layers go here (the code to add is described later in the article)
  // Output layer
  model.add(tf.layers.dense({ units: 1, useBias: true, activation: 'linear' }));
  console.log('Training started.');
  const originalData = [];
  const labels = [];
  // for...of visits only the row elements; for...in would also visit any
  // other enumerable key that happens to exist on the array.
  for (const row of rawData) {
    originalData.push({ x: parseFloat(row[4]), y: parseFloat(row[0]) });
    labels.push(row[10]);
  }
  // Convert to tensors and normalize the data
  const tensorData = convertToTensor(originalData);
  // Train the model
  await trainModel(model, tensorData.inputs, tensorData.labels);
  // Evaluate the trained model to get the fitted curve
  const predictedData = testModel(model, tensorData);
  console.log('Fitting has done.');
  return { labels, originalData, predictedData };
}
// モデルの学習を行う関数
/**
 * Compile `model` and fit it to the given tensors.
 *
 * The model is compiled with the Adam optimizer, mean squared error as the
 * loss, and 'mse' as a reported metric. Fitting runs for 50 epochs with a
 * batch size of 32 and shuffling enabled; the loss and mse of every epoch are
 * written to the console.
 *
 * @param {object} model - Model exposing `compile` and `fit`.
 * @param {object} inputs - Input tensor passed to `model.fit`.
 * @param {object} labels - Label tensor passed to `model.fit`.
 * @returns {Promise<object>} Whatever `model.fit` resolves to.
 */
async function trainModel(model, inputs, labels) {
  // Compiling fixes how the model learns: optimizer, loss and metrics.
  const compileConfig = {
    optimizer: tf.train.adam(),
    loss: tf.losses.meanSquaredError,
    metrics: ['mse'],
  };
  model.compile(compileConfig);

  // Print the epoch index together with its loss and mse.
  const reportEpoch = async (epoch, logs) => {
    console.log(`Epoch#${epoch} : Loss(${logs.loss}) : mse(${logs.mse})`);
  };

  const fitConfig = {
    batchSize: 32,
    epochs: 50,
    shuffle: true,
    callbacks: { onEpochEnd: reportEpoch },
  };
  const history = await model.fit(inputs, labels, fitConfig);
  return history;
}
// 学習済みモデルからフィッティング曲線を生成
/**
 * Sample the trained model to produce the points of the fitted curve.
 *
 * 100 evenly spaced inputs between 0 and 1 are fed to `model.predict`; both
 * the inputs and the predictions are then mapped back to the original scale
 * using the min/max tensors in `normalizationData`.
 *
 * @param {object} model - Model exposing `predict`.
 * @param {{inputMax: object, inputMin: object, labelMax: object, labelMin: object}} normalizationData
 *   Min/max tensors used to undo the normalization.
 * @returns {{x: number, y: number}[]} Curve points in the form Chart.js plots.
 */
function testModel(model, normalizationData) {
  const SAMPLE_COUNT = 100;
  const { inputMax, inputMin, labelMin, labelMax } = normalizationData;

  const [xValues, yValues] = tf.tidy(() => {
    // tf.linspace yields SAMPLE_COUNT evenly spaced values from 0 to 1.
    const normalizedXs = tf.linspace(0, 1, SAMPLE_COUNT);
    const normalizedPreds = model.predict(normalizedXs.reshape([SAMPLE_COUNT, 1]));

    // The model works on normalized values, so scale both back:
    // value * (max - min) + min.
    const inputRange = inputMax.sub(inputMin);
    const labelRange = labelMax.sub(labelMin);
    const restoredXs = normalizedXs.mul(inputRange).add(inputMin);
    const restoredPreds = normalizedPreds.mul(labelRange).add(labelMin);

    // Read the tensors out as typed arrays.
    return [restoredXs.dataSync(), restoredPreds.dataSync()];
  });

  // Pair x and y by index into the {x, y} objects Chart.js expects.
  return Array.from(xValues, (x, i) => ({ x, y: yValues[i] }));
}
// 学習データをTensor型に変換する関数
/**
 * Convert {x, y} training points into min-max normalized tensors.
 *
 * The points are shuffled before conversion, but the shuffle is applied to a
 * copy so the caller's array keeps its order.
 *
 * @param {{x: number, y: number}[]} data - Training points; not modified.
 * @returns {{inputs: object, labels: object, inputMax: object, inputMin: object, labelMax: object, labelMin: object}}
 *   Normalized N x 1 input and label tensors plus the min/max tensors needed
 *   to undo the normalization later.
 */
function convertToTensor(data) {
  return tf.tidy(() => {
    // tf.util.shuffle reorders its argument in place. Shuffling `data` itself
    // would scramble the caller's array, which linearRegression returns next
    // to an index-aligned list of labels.
    const shuffled = [...data];
    tf.util.shuffle(shuffled);
    // Turn the x and y values into N x 1 tensors.
    const inputs = shuffled.map((d) => d.x);
    const labels = shuffled.map((d) => d.y);
    const inputTensor = tf.tensor2d(inputs, [inputs.length, 1]);
    const labelTensor = tf.tensor2d(labels, [labels.length, 1]);
    // Min-max normalization: (value - min) / (max - min).
    const inputMax = inputTensor.max();
    const inputMin = inputTensor.min();
    const labelMax = labelTensor.max();
    const labelMin = labelTensor.min();
    const normalizedInputs = inputTensor.sub(inputMin).div(inputMax.sub(inputMin));
    const normalizedLabels = labelTensor.sub(labelMin).div(labelMax.sub(labelMin));
    return {
      inputs: normalizedInputs,
      labels: normalizedLabels,
      inputMax,
      inputMin,
      labelMax,
      labelMin,
    };
  });
}
main.js
const lr = require('./linear_regression');
/**
 * Parse CSV text into rows of cell strings.
 *
 * The text is split on '\n'; empty lines are dropped and every remaining line
 * is split on commas.
 *
 * @param {string} str - Raw CSV text.
 * @returns {string[][]} One array of cell strings per non-empty line.
 */
function csv2Array(str) {
  return str
    .split('\n')
    .filter((line) => line !== '')
    .map((line) => line.split(','));
}
/**
 * Draw the measured points and the fitted curve together on the `myChart` canvas.
 *
 * The chart is a Chart.js scatter chart with two datasets: 'Prediction',
 * drawn as a line without point markers on the right-hand y axis, and 'MPG',
 * drawn as scatter points on the left-hand y axis. Both y axes run from 0 to
 * 50 in steps of 10; the x axis is titled 'Weight' and runs from 0 to 7000 in
 * steps of 1000.
 *
 * @param {string[]} labelData_ - Labels passed to the chart as `data.labels`.
 * @param {{x: number, y: number}[]} originalData_ - Points of the 'MPG' dataset.
 * @param {{x: number, y: number}[]} predictedData_ - Points of the 'Prediction' dataset.
 * @returns {void}
 */
function drawPredictedCurve(labelData_, originalData_, predictedData_) {
  // Both y axes share one definition apart from their id and side.
  const buildYAxis = (id, position) => ({
    id,
    type: 'linear',
    position,
    ticks: { max: 50, min: 0, stepSize: 10 },
  });

  const xAxis = {
    scaleLabel: { display: true, labelString: 'Weight' },
    ticks: { min: 0, max: 7000, stepSize: 1000 },
  };

  const tooltipStyle = {
    backgroundColor: 'rgba(19, 56, 95, 0.9)',
    titleFontSize: 16,
    bodyFontSize: 16,
    xPadding: 12,
    yPadding: 10,
  };

  const predictionSeries = {
    label: 'Prediction',
    type: 'line',
    data: predictedData_,
    borderColor: 'rgba(0,0,0,1)',
    backgroundColor: 'rgba(0,0,0,0)',
    pointRadius: 0,
    yAxisID: 'y-axis-p',
  };

  const measuredSeries = {
    label: 'MPG',
    type: 'scatter',
    data: originalData_,
    borderColor: 'rgba(0,0,0,1)',
    backgroundColor: 'rgba(0,0,255,1)',
    pointRadius: 3,
    pointHoverRadius: 9,
    yAxisID: 'y-axis-o',
  };

  const canvas = document.getElementById('myChart');
  new Chart(canvas.getContext('2d'), {
    type: 'scatter',
    data: {
      labels: labelData_,
      datasets: [predictionSeries, measuredSeries],
    },
    options: {
      responsive: true,
      tooltips: tooltipStyle,
      scales: {
        xAxes: [xAxis],
        yAxes: [buildYAxis('y-axis-p', 'right'), buildYAxis('y-axis-o', 'left')],
      },
    },
  });
}
// Entry point: load ./output3.csv, run the regression on its rows, then draw
// the measured points together with the fitted curve.
;(() => {
  const req = new XMLHttpRequest();
  const filePath = './output3.csv';
  req.open('GET', filePath, true);
  // The handler itself must be async: `await` is a syntax error inside a
  // plain function, even when an enclosing function is async.
  req.onload = async () => {
    try {
      // Declared locally instead of assigning to an implicit global.
      const data = csv2Array(req.responseText);
      // linearRegression returns its label list under the key `labels`.
      const { labels, originalData, predictedData } = await lr.linearRegression(data);
      drawPredictedCurve(labels, originalData, predictedData);
    } catch (err) {
      // Nothing awaits this handler, so report failures here instead of
      // leaving an unhandled promise rejection.
      console.error(err);
    }
  };
  req.send(null);
})();
解析結果
//...中略
// 👇 Hidden layer: a dense layer with 16 units and a bias term.
// No `activation` option is passed to this layer.
// NOTE(review): without a nonlinear activation the stacked dense layers presumably still form a linear model; confirm this is intended.
model.add(tf.layers.dense({
units: 16,
useBias: true
}));
//...以下略
まとめ
参考にしたサイト
記事を書いた人
ナンデモ系エンジニア
主にAngularでフロントエンド開発することが多いです。 開発環境はLinuxメインで進めているので、シェルコマンドも多用しております。 コツコツとプログラミングするのが好きな人間です。
カテゴリー