Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 8ae780d

Browse files
Update models
Rename & change models
1 parent 4f6729e commit 8ae780d

File tree

2 files changed

+36
-63
lines changed

2 files changed

+36
-63
lines changed

‎index.html‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
<div id="skin-bmi-cont" />
1515
<div id="loss-cont"></div>
1616
<div id="acc-cont"></div>
17+
<div id="confusion-matrix"></div>
1718

1819
<script src="src/index.js"></script>
1920
</body>

‎src/index.js‎

Lines changed: 35 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ Papa.parsePromise = function(file) {
2222
});
2323
};
2424

25-
const oneHot = outcome => Array.from(tf.oneHot([outcome], 2).dataSync());
25+
const oneHot = outcome => Array.from(tf.oneHot(outcome, 2).dataSync());
2626

2727
const prepareData = async () => {
2828
const csv = await Papa.parsePromise(
@@ -32,56 +32,7 @@ const prepareData = async () => {
3232
return csv.data;
3333
};
3434

35-
const normalize = (val, min, max) => {
36-
var delta = max - min;
37-
return (val - min) / delta;
38-
};
39-
4035
const createDataSets = (data, features, testSize, batchSize) => {
41-
// const data = _.shuffle(csvData);
42-
43-
// const X = data.map(r => features.map(f => r[f]));
44-
// const y = data.map(r => oneHot(r.Outcome));
45-
46-
// const [xTrain, xTest] = _.chunk(X, parseInt((1 - testSize) * X.length, 10));
47-
// const [yTrain, yTest] = _.chunk(y, parseInt((1 - testSize) * y.length, 10));
48-
49-
// return [
50-
// tf.data.array(xTrain),
51-
// tf.data.array(xTest),
52-
// tf.data.array(yTrain),
53-
// tf.data.array(yTest)
54-
// ];
55-
56-
// const normalized = [];
57-
58-
// for (const f of features) {
59-
// const values = data.map(r => r[f]);
60-
// const min = _.min(values);
61-
// const max = _.max(values);
62-
63-
// const norm = values.map(val => {
64-
// if (val === undefined) {
65-
// return 0;
66-
// }
67-
// return normalize(val, min, max);
68-
// });
69-
// normalized.push(norm);
70-
// }
71-
72-
// const rowCount = data.length;
73-
// const colCount = features.length;
74-
75-
// const X = [];
76-
77-
// for (let row = 0; row < rowCount; row++) {
78-
// X.push([]);
79-
// for (let col = 0; col < colCount; col++) {
80-
// // i
81-
// X[row].push(normalized[col][row]);
82-
// }
83-
// }
84-
8536
const X = data.map(r =>
8637
features.map(f => {
8738
const val = r[f];
@@ -101,7 +52,9 @@ const createDataSets = (data, features, testSize, batchSize) => {
10152

10253
return [
10354
ds.take(splitIdx).batch(batchSize),
104-
ds.skip(splitIdx + 1).batch(batchSize)
55+
ds.skip(splitIdx + 1).batch(batchSize),
56+
tf.tensor(X.slice(splitIdx)),
57+
tf.tensor(y.slice(splitIdx))
10558
];
10659
};
10760

@@ -200,11 +153,7 @@ const renderScatter = (container, data, columns, config) => {
200153
});
201154
};
202155

203-
const trainSimpleModel = async (featureCount, trainDs, validDs) => {
204-
// const arr = await ds.take(10).toArray();
205-
206-
// console.log(arr[0].xs.arraySync());
207-
156+
const trainLogisticRegression = async (featureCount, trainDs, validDs) => {
208157
const model = tf.sequential();
209158
model.add(
210159
tf.layers.dense({
@@ -224,7 +173,7 @@ const trainSimpleModel = async (featureCount, trainDs, validDs) => {
224173
const accContainer = document.getElementById("acc-cont");
225174
console.log("Training...");
226175
await model.fitDataset(trainDs, {
227-
epochs: 30,
176+
epochs: 100,
228177
validationData: validDs,
229178
callbacks: {
230179
onEpochEnd: async (epoch, logs) => {
@@ -234,6 +183,8 @@ const trainSimpleModel = async (featureCount, trainDs, validDs) => {
234183
}
235184
}
236185
});
186+
187+
return model;
237188
};
238189

239190
const trainComplexModel = async (featureCount, trainDs, validDs) => {
@@ -255,7 +206,7 @@ const trainComplexModel = async (featureCount, trainDs, validDs) => {
255206
activation: "softmax"
256207
})
257208
);
258-
const optimizer = tf.train.adam(0.001);
209+
const optimizer = tf.train.adam(0.0001);
259210
model.compile({
260211
optimizer: optimizer,
261212
loss: "binaryCrossentropy",
@@ -266,7 +217,7 @@ const trainComplexModel = async (featureCount, trainDs, validDs) => {
266217
const accContainer = document.getElementById("acc-cont");
267218
console.log("Training...");
268219
await model.fitDataset(trainDs, {
269-
epochs: 30,
220+
epochs: 100,
270221
validationData: validDs,
271222
callbacks: {
272223
onEpochEnd: async (epoch, logs) => {
@@ -276,6 +227,8 @@ const trainComplexModel = async (featureCount, trainDs, validDs) => {
276227
}
277228
}
278229
});
230+
231+
return model;
279232
};
280233

281234
const run = async () => {
@@ -321,11 +274,30 @@ const run = async () => {
321274

322275
// const [trainDs, validDs] = createDataSets(data, features, 0.1, 16);
323276

324-
// trainSimpleModel(features.length, trainDs, validDs);
277+
// trainLogisticRegression(features.length, trainDs, validDs);
278+
279+
const features = ["Glucose", "Age", "Insulin", "BloodPressure"];
280+
281+
const [trainDs, validDs, xTest, yTest] = createDataSets(
282+
data,
283+
features,
284+
0.1,
285+
16
286+
);
287+
288+
const model = await trainComplexModel(features.length, trainDs, validDs);
325289

326-
// const preds = model.predict(tf.slice2d(xTest, 1, 1)).dataSync();
327-
// console.log("Prediction:" + preds);
328-
// console.log("Real:" + yTest.slice(1, 1).dataSync());
290+
const preds = model.predict(xTest).argMax(-1);
291+
const labels = yTest.argMax(-1);
292+
293+
const confusionMatrix = await tfvis.metrics.confusionMatrix(labels, preds);
294+
295+
const container = document.getElementById("confusion-matrix");
296+
297+
tfvis.render.confusionMatrix(container, {
298+
values: confusionMatrix,
299+
tickLabels: ["Healthy", "Diabetic"]
300+
});
329301
};
330302

331303
if (document.readyState !== "loading") {

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /