@@ -3,20 +3,87 @@ import "./styles.css";
33import * as tf from "@tensorflow/tfjs" ;
44import * as tfvis from "@tensorflow/tfjs-vis" ;
55import * as Papa from "papaparse" ;
6+ import _ from "lodash" ;
67
78// data from:
89// https://www.kaggle.com/kandij/diabetes-dataset
910
1011Papa . parsePromise = function ( file ) {
1112 return new Promise ( function ( complete , error ) {
12- Papa . parse ( file , { complete, error } ) ;
13+ Papa . parse ( file , {
14+ header : true ,
15+ download : true ,
16+ dynamicTyping : true ,
17+ complete,
18+ error
19+ } ) ;
1320 } ) ;
1421} ;
1522
23+ const oneHot = outcome => Array . from ( tf . oneHot ( [ outcome ] , 2 ) . dataSync ( ) ) ;
24+ 25+ const prepareData = async testSize => {
26+ const csv = await Papa . parsePromise (
27+ "https://raw.githubusercontent.com/curiousily/Logistic-Regression-with-TensorFlow-js/master/src/data/diabetes.csv"
28+ ) ;
29+ 30+ const data = _ . shuffle ( csv . data ) ;
31+ 32+ const X = data . map ( r =>
33+ Object . values ( r ) . slice ( 0 , Object . values ( r ) . length - 1 )
34+ ) ;
35+ const y = data . map ( r => oneHot ( r . Outcome ) ) ;
36+ 37+ const [ xTrain , xTest ] = _ . chunk ( X , parseInt ( ( 1 - testSize ) * X . length , 10 ) ) ;
38+ const [ yTrain , yTest ] = _ . chunk ( y , parseInt ( ( 1 - testSize ) * y . length , 10 ) ) ;
39+ 40+ return [
41+ tf . tensor2d ( xTrain ) ,
42+ tf . tensor ( xTest ) ,
43+ tf . tensor2d ( yTrain ) ,
44+ tf . tensor ( yTest )
45+ ] ;
46+ } ;
47+ 1648const run = async ( ) => {
17- console . log ( new File ( "./diabetes.csv" ) ) ;
18- // const csv = await Papa.parsePromise(new File("./data/diabetes.csv"));
19- // console.log(csv);
49+ const [ xTrain , xTest , yTrain , yTest ] = await prepareData ( 0.1 ) ;
50+ 51+ console . log ( xTrain . shape ) ;
52+ const model = tf . sequential ( ) ;
53+ model . add (
54+ tf . layers . dense ( {
55+ units : 32 ,
56+ activation : "relu" ,
57+ inputShape : [ xTrain . shape [ 1 ] ]
58+ } )
59+ ) ;
60+ model . add ( tf . layers . dense ( { units : 2 , activation : "softmax" } ) ) ;
61+ const optimizer = tf . train . adam ( 0.001 ) ;
62+ model . compile ( {
63+ optimizer : optimizer ,
64+ loss : "categoricalCrossentropy" ,
65+ metrics : [ "accuracy" ]
66+ } ) ;
67+ const trainLogs = [ ] ;
68+ const lossContainer = document . getElementById ( "loss-cont" ) ;
69+ const accContainer = document . getElementById ( "acc-cont" ) ;
70+ console . log ( "Training..." ) ;
71+ await model . fit ( xTrain , yTrain , {
72+ validationData : [ xTest , yTest ] ,
73+ epochs : 100 ,
74+ shuffle : true ,
75+ callbacks : {
76+ onEpochEnd : async ( epoch , logs ) => {
77+ trainLogs . push ( logs ) ;
78+ tfvis . show . history ( lossContainer , trainLogs , [ "loss" , "val_loss" ] ) ;
79+ tfvis . show . history ( accContainer , trainLogs , [ "acc" , "val_acc" ] ) ;
80+ }
81+ }
82+ } ) ;
83+ 84+ const preds = model . predict ( tf . slice2d ( xTest , 1 , 1 ) ) . dataSync ( ) ;
85+ console . log ( "Prediction:" + preds ) ;
86+ console . log ( "Real:" + yTest . slice ( 1 , 1 ) . dataSync ( ) ) ;
2087} ;
2188
2289if ( document . readyState !== "loading" ) {
0 commit comments