Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 8de69cb

Browse files
Logistic Regressor
Complete example
1 parent 570a569 commit 8de69cb

File tree

3 files changed

+76
-7
lines changed

3 files changed

+76
-7
lines changed

‎index.html‎

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22
<html>
33

44
<head>
5-
<title>Parcel Sandbox</title>
5+
<title>Logistic Regression with TensorFlow.js</title>
66
<meta charset="UTF-8" />
77
</head>
88

99
<body>
10-
<div id="app"></div>
10+
<div id="loss-cont"></div>
11+
<div id="acc-cont"></div>
1112

1213
<script src="src/index.js">
1314
</script>

‎package.json‎

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
"dependencies": {
1111
"@tensorflow/tfjs": "1.2.2",
1212
"@tensorflow/tfjs-vis": "1.1.0",
13+
"lodash": "4.17.11",
1314
"papaparse": "5.0.0"
1415
},
1516
"devDependencies": {
@@ -19,4 +20,4 @@
1920
"parcel-bundler": "^1.6.1"
2021
},
2122
"keywords": []
22-
}
23+
}

‎src/index.js‎

Lines changed: 71 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,87 @@ import "./styles.css";
33
import * as tf from "@tensorflow/tfjs";
44
import * as tfvis from "@tensorflow/tfjs-vis";
55
import * as Papa from "papaparse";
6+
import _ from "lodash";
67

78
// data from:
89
// https://www.kaggle.com/kandij/diabetes-dataset
910

1011
Papa.parsePromise = function(file) {
1112
return new Promise(function(complete, error) {
12-
Papa.parse(file, { complete, error });
13+
Papa.parse(file, {
14+
header: true,
15+
download: true,
16+
dynamicTyping: true,
17+
complete,
18+
error
19+
});
1320
});
1421
};
1522

23+
const oneHot = outcome => Array.from(tf.oneHot([outcome], 2).dataSync());
24+
25+
const prepareData = async testSize => {
26+
const csv = await Papa.parsePromise(
27+
"https://raw.githubusercontent.com/curiousily/Logistic-Regression-with-TensorFlow-js/master/src/data/diabetes.csv"
28+
);
29+
30+
const data = _.shuffle(csv.data);
31+
32+
const X = data.map(r =>
33+
Object.values(r).slice(0, Object.values(r).length - 1)
34+
);
35+
const y = data.map(r => oneHot(r.Outcome));
36+
37+
const [xTrain, xTest] = _.chunk(X, parseInt((1 - testSize) * X.length, 10));
38+
const [yTrain, yTest] = _.chunk(y, parseInt((1 - testSize) * y.length, 10));
39+
40+
return [
41+
tf.tensor2d(xTrain),
42+
tf.tensor(xTest),
43+
tf.tensor2d(yTrain),
44+
tf.tensor(yTest)
45+
];
46+
};
47+
1648
const run = async () => {
17-
console.log(new File("./diabetes.csv"));
18-
// const csv = await Papa.parsePromise(new File("./data/diabetes.csv"));
19-
// console.log(csv);
49+
const [xTrain, xTest, yTrain, yTest] = await prepareData(0.1);
50+
51+
console.log(xTrain.shape);
52+
const model = tf.sequential();
53+
model.add(
54+
tf.layers.dense({
55+
units: 32,
56+
activation: "relu",
57+
inputShape: [xTrain.shape[1]]
58+
})
59+
);
60+
model.add(tf.layers.dense({ units: 2, activation: "softmax" }));
61+
const optimizer = tf.train.adam(0.001);
62+
model.compile({
63+
optimizer: optimizer,
64+
loss: "categoricalCrossentropy",
65+
metrics: ["accuracy"]
66+
});
67+
const trainLogs = [];
68+
const lossContainer = document.getElementById("loss-cont");
69+
const accContainer = document.getElementById("acc-cont");
70+
console.log("Training...");
71+
await model.fit(xTrain, yTrain, {
72+
validationData: [xTest, yTest],
73+
epochs: 100,
74+
shuffle: true,
75+
callbacks: {
76+
onEpochEnd: async (epoch, logs) => {
77+
trainLogs.push(logs);
78+
tfvis.show.history(lossContainer, trainLogs, ["loss", "val_loss"]);
79+
tfvis.show.history(accContainer, trainLogs, ["acc", "val_acc"]);
80+
}
81+
}
82+
});
83+
84+
const preds = model.predict(tf.slice2d(xTest, 1, 1)).dataSync();
85+
console.log("Prediction:" + preds);
86+
console.log("Real:" + yTest.slice(1, 1).dataSync());
2087
};
2188

2289
if (document.readyState !== "loading") {

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /