@@ -269,119 +269,118 @@ func normalizeByFeatureScaling(dataset: inout [[Double]]) {
269
269
270
270
// MARK: Iris Test
271
271
272
- func parseIrisCSV( ) -> ( parameters: [ [ Double ] ] , classifications: [ [ Double ] ] , species: [ String ] ) {
273
- let urlpath = Bundle . main. path ( forResource: " iris " , ofType: " csv " )
272
+ //func parseIrisCSV() -> (parameters: [[Double]], classifications: [[Double]], species: [String]) {
273
+ // let urlpath = Bundle.main.path(forResource: "iris", ofType: "csv")
274
+ // let url = URL(fileURLWithPath: urlpath!)
275
+ // let csv = try! String.init(contentsOf: url)
276
+ // let lines = csv.components(separatedBy: "\n")
277
+ // var irisParameters: [[Double]] = [[Double]]()
278
+ // var irisClassifications: [[Double]] = [[Double]]()
279
+ // var irisSpecies: [String] = [String]()
280
+ //
281
+ // let shuffledLines = lines.shuffled()
282
+ // for line in shuffledLines {
283
+ // if line == "" { continue } // skip blank lines
284
+ // let items = line.components(separatedBy: ",")
285
+ // let parameters = items[0...3].map{ Double(0ドル)! }
286
+ // irisParameters.append(parameters)
287
+ // let species = items[4]
288
+ // if species == "Iris-setosa" {
289
+ // irisClassifications.append([1.0, 0.0, 0.0])
290
+ // } else if species == "Iris-versicolor" {
291
+ // irisClassifications.append([0.0, 1.0, 0.0])
292
+ // } else {
293
+ // irisClassifications.append([0.0, 0.0, 1.0])
294
+ // }
295
+ // irisSpecies.append(species)
296
+ // }
297
+ // normalizeByFeatureScaling(dataset: &irisParameters)
298
+ // return (irisParameters, irisClassifications, irisSpecies)
299
+ //}
300
+ //
301
+ //let (irisParameters, irisClassifications, irisSpecies) = parseIrisCSV()
302
+ //
303
+ //let irisNetwork: Network = Network(layerStructure: [4, 6, 3], learningRate: 0.3)
304
+ //
305
+ //func irisInterpretOutput(output: [Double]) -> String {
306
+ // if output.max()! == output[0] {
307
+ // return "Iris-setosa"
308
+ // } else if output.max()! == output[1] {
309
+ // return "Iris-versicolor"
310
+ // } else {
311
+ // return "Iris-virginica"
312
+ // }
313
+ //}
314
+ //
315
+ //// train over first 140 irises in data set 20 times
316
+ //let irisTrainers = Array(irisParameters[0..<140])
317
+ //let irisTrainersCorrects = Array(irisClassifications[0..<140])
318
+ //for _ in 0..<20 {
319
+ // irisNetwork.train(inputs: irisTrainers, expecteds: irisTrainersCorrects, printError: false)
320
+ //}
321
+ //
322
+ //// test over the last 10 of the irses in the data set
323
+ //let irisTesters = Array(irisParameters[140..<150])
324
+ //let irisTestersCorrects = Array(irisSpecies[140..<150])
325
+ //let irisResults = irisNetwork.validate(inputs: irisTesters, expecteds: irisTestersCorrects, interpretOutput: irisInterpretOutput)
326
+ //print("\(irisResults.correct) correct of \(irisResults.total) = \(irisResults.percentage * 100)%")
327
+
328
+ /// Wine Test
329
+
330
+ func parseWineCSV( ) -> ( parameters: [ [ Double ] ] , classifications: [ [ Double ] ] , species: [ Int ] ) {
331
+ let urlpath = Bundle . main. path ( forResource: " wine " , ofType: " csv " )
274
332
let url = URL ( fileURLWithPath: urlpath!)
275
333
let csv = try ! String . init ( contentsOf: url)
276
334
let lines = csv. components ( separatedBy: " \n " )
277
- var irisParameters : [ [ Double ] ] = [ [ Double] ] ( )
278
- var irisClassifications : [ [ Double ] ] = [ [ Double] ] ( )
279
- var irisSpecies : [ String ] = [ String ] ( )
280
-
335
+ var wineParameters : [ [ Double ] ] = [ [ Double] ] ( )
336
+ var wineClassifications : [ [ Double ] ] = [ [ Double] ] ( )
337
+ var wineSpecies : [ Int ] = [ Int ] ( )
338
+
281
339
let shuffledLines = lines. shuffled ( )
282
340
for line in shuffledLines {
283
- if line == " " { continue }
341
+ if line == " " { continue } // skip blank lines
284
342
let items = line. components ( separatedBy: " , " )
285
- let parameters = items [ 0 ... 3 ] . map { Double ( 0ドル) ! }
286
- irisParameters . append ( parameters)
287
- let species = items [ 4 ]
288
- if species == " Iris-setosa " {
289
- irisClassifications . append ( [ 1.0 , 0.0 , 0.0 ] )
290
- } else if species == " Iris-versicolor " {
291
- irisClassifications . append ( [ 0.0 , 1.0 , 0.0 ] )
343
+ let parameters = items [ 1 ... 13 ] . map { Double ( 0ドル) ! }
344
+ wineParameters . append ( parameters)
345
+ let species = Int ( items [ 0 ] ) !
346
+ if species == 1 {
347
+ wineClassifications . append ( [ 1.0 , 0.0 , 0.0 ] )
348
+ } else if species == 2 {
349
+ wineClassifications . append ( [ 0.0 , 1.0 , 0.0 ] )
292
350
} else {
293
- irisClassifications . append ( [ 0.0 , 0.0 , 1.0 ] )
351
+ wineClassifications . append ( [ 0.0 , 0.0 , 1.0 ] )
294
352
}
295
- irisSpecies . append ( species)
353
+ wineSpecies . append ( species)
296
354
}
297
- normalizeByFeatureScaling ( dataset: & irisParameters )
298
- return ( irisParameters , irisClassifications , irisSpecies )
355
+ normalizeByFeatureScaling ( dataset: & wineParameters )
356
+ return ( wineParameters , wineClassifications , wineSpecies )
299
357
}
300
358
301
- let ( irisParameters , irisClassifications , irisSpecies ) = parseIrisCSV ( )
359
+ let ( wineParameters , wineClassifications , wineSpecies ) = parseWineCSV ( )
302
360
303
- var irisNetwork : Network = Network ( layerStructure: [ 4 , 6 , 3 ] , learningRate: 0.3 )
361
+ let wineNetwork : Network = Network ( layerStructure: [ 13 , 7 , 3 ] , learningRate: 0.9 )
304
362
305
- func irisInterpretOutput ( output: [ Double ] ) -> String {
363
+ func wineInterpretOutput ( output: [ Double ] ) -> Int {
306
364
if output. max ( ) ! == output [ 0 ] {
307
- return " Iris-setosa "
365
+ return 1
308
366
} else if output. max ( ) ! == output [ 1 ] {
309
- return " Iris-versicolor "
367
+ return 2
310
368
} else {
311
- return " Iris-virginica "
369
+ return 3
312
370
}
313
371
}
314
372
315
- // train over first 140 irises in data set 20 times
316
- let irisTrainers = Array ( irisParameters [ 0 ..< 140 ] )
317
- let irisTrainersCorrects = Array ( irisClassifications [ 0 ..< 140 ] )
318
- for _ in 0 ..< 20 {
319
- irisNetwork . train ( inputs: irisTrainers , expecteds: irisTrainersCorrects , printError: false )
373
+ // train over the first 150 samples 5 times
374
+ let wineTrainers = Array ( wineParameters . dropLast ( 28 ) )
375
+ let wineTrainersCorrects = Array ( wineClassifications . dropLast ( 28 ) )
376
+ for _ in 0 ..< 5 {
377
+ wineNetwork . train ( inputs: wineTrainers , expecteds: wineTrainersCorrects , printError: false )
320
378
}
321
379
322
- // test over the last 10 of the irses in the data set
323
- let irisTesters = Array ( irisParameters [ 140 ..< 150 ] )
324
- let irisTestersCorrects = Array ( irisSpecies [ 140 ..< 150 ] )
325
- let irisResults = irisNetwork. validate ( inputs: irisTesters, expecteds: irisTestersCorrects, interpretOutput: irisInterpretOutput)
326
- print ( " \( irisResults. correct) correct of \( irisResults. total) = \( irisResults. percentage * 100 ) % " )
327
-
328
- /// Wine Test
329
-
330
- //var network: Network = Network(layerStructure: [13,7,3], learningRate: 7.0)
331
- //// for training
332
- //var wineParameters: [[Double]] = [[Double]]()
333
- //var wineClassifications: [[Double]] = [[Double]]()
334
- //// for testing/validation
335
- //var wineSamples: [[Double]] = [[Double]]()
336
- //var wineCultivars: [Int] = [Int]()
337
- //
338
- //func parseWineCSV() {
339
- // let myBundle = Bundle.main
340
- // let urlpath = myBundle.path(forResource: "wine", ofType: "csv")
341
- // let url = URL(fileURLWithPath: urlpath!)
342
- // let csv = try! String.init(contentsOf: url)
343
- // let lines = csv.components(separatedBy: "\n")
344
- //
345
- // let shuffledLines = lines.shuffled()
346
- // for line in shuffledLines {
347
- // if line == "" { continue }
348
- // let items = line.components(separatedBy: ",")
349
- // let parameters = items[1...13].map{ Double(0ドル)! }
350
- // wineParameters.append(parameters)
351
- // let species = Int(items[0])!
352
- // if species == 1 {
353
- // wineClassifications.append([1.0, 0.0, 0.0])
354
- // } else if species == 2 {
355
- // wineClassifications.append([0.0, 1.0, 0.0])
356
- // } else {
357
- // wineClassifications.append([0.0, 0.0, 1.0])
358
- // }
359
- // wineCultivars.append(species)
360
- // }
361
- // normalizeByFeatureScaling(dataset: &wineParameters)
362
- // wineSamples = Array(wineParameters.dropFirst(150))
363
- // wineCultivars = Array(wineCultivars.dropFirst(150))
364
- // wineParameters = Array(wineParameters.dropLast(28))
365
- //}
366
- //
367
- //func interpretOutput(output: [Double]) -> Int {
368
- // if output.max()! == output[0] {
369
- // return 1
370
- // } else if output.max()! == output[1] {
371
- // return 2
372
- // } else {
373
- // return 3
374
- // }
375
- //}
376
- //
377
- //parseWineCSV()
378
- //// train over entire data set 5 times
379
- //for _ in 0..<5 {
380
- // network.train(inputs: wineParameters, expecteds: wineClassifications, printError: false)
381
- //}
382
- //
383
- //let results = network.validate(inputs: wineSamples, expecteds: wineCultivars, interpretOutput: interpretOutput)
384
- //print("\(results.correct) correct of \(results.total) = \(results.percentage * 100)%")
380
+ let wineTesters = Array ( wineParameters. dropFirst ( 150 ) )
381
+ let wineTestersCorrects = Array ( wineSpecies. dropFirst ( 150 ) )
382
+ let results = wineNetwork. validate ( inputs: wineTesters, expecteds: wineTestersCorrects, interpretOutput: wineInterpretOutput)
383
+ print ( " \( results. correct) correct of \( results. total) = \( results. percentage * 100 ) % " )
385
384
386
385
//: [Next](@next)
387
386
0 commit comments