@@ -23,20 +23,42 @@ import Foundation
23
23
24
24
// MARK: Randomization & Statistical Helpers
25
25
26
/// Namespace for pseudo-random helpers built on `drand48`.
struct Random {
    /// Records whether `srand48` has been called yet.
    private static var seeded = false

    /// Returns a pseudo-random Double obtained by scaling `drand48()` onto the
    /// span between *from* and *to*. The generator is seeded from the current
    /// time the first time this is called. Assumes *from* < *to*.
    static func double(from: Double, to: Double) -> Double {
        if seeded == false {
            srand48(time(nil))
            seeded = true
        }

        let span = to - from
        return from + drand48() * span
    }
}
39
+
26
40
/// Build an array holding *number* random Doubles, each between 0.0 and 1.0
func randomWeights(number: Int) -> [Double] {
    var weights: [Double] = []
    for _ in 0..<number {
        weights.append(Random.double(from: 0.0, to: 1.0))
    }
    return weights
}
30
44
31
45
/// Build an array holding *number* random Doubles, each between 0.0 and *limit*
func randomNums(number: Int, limit: Double) -> [Double] {
    var values: [Double] = []
    for _ in 0..<number {
        values.append(Random.double(from: 0.0, to: limit))
    }
    return values
}
35
49
36
- /// primitive shuffle - not fisher yates... not uniform distribution
37
- extension Sequence where Iterator. Element : Comparable {
38
- var shuffled : [ Self . Iterator . Element ] {
39
- return sorted { _, _ in arc4random ( ) % 2 == 0 }
50
// Fisher-Yates style shuffle: returns a shuffled copy and leaves the receiver untouched
extension Array {
    public func shuffled() -> Array<Element> {
        // zero or one element: nothing to reorder
        guard count >= 2 else { return self }

        var result = self // Array is a value type, so this is an independent copy
        var last = count - 1
        // walk from the final index down to 1
        while last >= 1 {
            // pick a random index in 0...last to exchange with the current end
            let pick = Int(arc4random_uniform(UInt32(last + 1)))
            if pick != last { // skip self swaps
                result.swapAt(last, pick)
            }
            last -= 1
        }
        return result
    }
}
42
64
@@ -106,36 +128,7 @@ public func sum(x: [Double]) -> Double {
106
128
return result
107
129
}
108
130
109
- // MARK: Random Number Generation
110
131
111
- // this struct & the randomFractional() function
112
- // based on http://stackoverflow.com/a/35919911/281461
113
- struct Math {
114
- private static var seeded = false
115
-
116
- static func randomFractional( ) -> Double {
117
-
118
- if !Math. seeded {
119
- let time = Int ( NSDate ( ) . timeIntervalSinceReferenceDate)
120
- srand48 ( time)
121
- Math . seeded = true
122
- }
123
-
124
- return drand48 ( )
125
- }
126
-
127
- // addition, just multiplies random number by *limit*
128
- static func randomTo( limit: Double ) -> Double {
129
-
130
- if !Math. seeded {
131
- let time = Int ( NSDate ( ) . timeIntervalSinceReferenceDate)
132
- srand48 ( time)
133
- Math . seeded = true
134
- }
135
-
136
- return drand48 ( ) * limit
137
- }
138
- }
139
132
140
133
/// An individual node in a layer
141
134
class Neuron {
@@ -282,8 +275,7 @@ class Network {
282
275
283
276
/// for generalized results that require classification
284
277
/// this function will return the correct number of trials
285
- /// and the percentge correct out of the total
286
- /// See the unit tests for some examples
278
+ /// and the percentage correct out of the total
287
279
func validate< T: Equatable > ( inputs: [ [ Double ] ] , expecteds: [ T ] , interpretOutput: ( [ Double ] ) -> T ) -> ( correct: Int , total: Int , percentage: Double ) {
288
280
var correct = 0
289
281
for (input, expected) in zip ( inputs, expecteds) {
@@ -295,75 +287,120 @@ class Network {
295
287
let percentage = Double ( correct) / Double( inputs. count)
296
288
return ( correct, inputs. count, percentage)
297
289
}
298
-
299
- // for when result is a single neuron
300
- func validate( inputs: [ [ Double ] ] , expecteds: [ Double ] , accuracy: Double ) -> ( correct: Int , total: Int , percentage: Double ) {
301
- var correct = 0
302
- for (input, expected) in zip ( inputs, expecteds) {
303
- let result = outputs ( input: input) [ 0 ]
304
- if abs ( expected - result) < accuracy {
305
- correct += 1
306
- }
307
- }
308
- let percentage = Double ( correct) / Double( inputs. count)
309
- return ( correct, inputs. count, percentage)
310
- }
311
290
}
312
291
313
- var network : Network = Network ( layerStructure: [ 13 , 7 , 3 ] , learningRate: 7.0 )
314
- // for training
315
- var wineParameters : [ [ Double ] ] = [ [ Double] ] ( )
316
- var wineClassifications : [ [ Double ] ] = [ [ Double] ] ( )
317
- // for testing/validation
318
- var wineSamples : [ [ Double ] ] = [ [ Double] ] ( )
319
- var wineCultivars : [ Int ] = [ Int] ( )
292
+ /// Wine Test
293
+
294
+ //var network: Network = Network(layerStructure: [13,7,3], learningRate: 7.0)
295
+ //// for training
296
+ //var wineParameters: [[Double]] = [[Double]]()
297
+ //var wineClassifications: [[Double]] = [[Double]]()
298
+ //// for testing/validation
299
+ //var wineSamples: [[Double]] = [[Double]]()
300
+ //var wineCultivars: [Int] = [Int]()
301
+ //
302
+ //func parseWineCSV() {
303
+ // let myBundle = Bundle.main
304
+ // let urlpath = myBundle.path(forResource: "wine", ofType: "csv")
305
+ // let url = URL(fileURLWithPath: urlpath!)
306
+ // let csv = try! String.init(contentsOf: url)
307
+ // let lines = csv.components(separatedBy: "\n")
308
+ //
309
+ // let shuffledLines = lines.shuffled()
310
+ // for line in shuffledLines {
311
+ // if line == "" { continue }
312
+ // let items = line.components(separatedBy: ",")
313
+ // let parameters = items[1...13].map{ Double($0)! }
314
+ // wineParameters.append(parameters)
315
+ // let species = Int(items[0])!
316
+ // if species == 1 {
317
+ // wineClassifications.append([1.0, 0.0, 0.0])
318
+ // } else if species == 2 {
319
+ // wineClassifications.append([0.0, 1.0, 0.0])
320
+ // } else {
321
+ // wineClassifications.append([0.0, 0.0, 1.0])
322
+ // }
323
+ // wineCultivars.append(species)
324
+ // }
325
+ // normalizeByColumnMax(dataset: &wineParameters)
326
+ // wineSamples = Array(wineParameters.dropFirst(150))
327
+ // wineCultivars = Array(wineCultivars.dropFirst(150))
328
+ // wineParameters = Array(wineParameters.dropLast(28))
329
+ //}
330
+ //
331
+ //func interpretOutput(output: [Double]) -> Int {
332
+ // if output.max()! == output[0] {
333
+ // return 1
334
+ // } else if output.max()! == output[1] {
335
+ // return 2
336
+ // } else {
337
+ // return 3
338
+ // }
339
+ //}
340
+ //
341
+ //parseWineCSV()
342
+ //// train over entire data set 5 times
343
+ //for _ in 0..<5 {
344
+ // network.train(inputs: wineParameters, expecteds: wineClassifications, printError: false)
345
+ //}
346
+ //
347
+ //let results = network.validate(inputs: wineSamples, expecteds: wineCultivars, interpretOutput: interpretOutput)
348
+ //print("\(results.correct) correct of \(results.total) = \(results.percentage * 100)%")
349
+
350
// Network for the iris data: layer structure [4, 5, 3], learning rate 0.3
var network: Network = Network(layerStructure: [4, 5, 3], learningRate: 0.3)
// Per-row numeric parameters, one-hot classifications, and species names (index-aligned)
var irisParameters: [[Double]] = []
var irisClassifications: [[Double]] = []
var irisSpecies: [String] = []
320
354
321
- func parseWineCSV ( ) {
355
/// Reads the bundled `iris.csv`, shuffles its rows, and appends each row to the global
/// `irisParameters`, `irisClassifications`, and `irisSpecies` arrays (kept index-aligned).
/// Afterwards the parameter columns are passed through `normalizeByColumnMax`.
///
/// Rows are trimmed before parsing so CRLF line endings do not leave a trailing `\r`
/// on the species label. Rows without four numeric columns plus a label are skipped.
func parseIrisCSV() {
    // A missing or unreadable bundled resource is a setup error: stop with a clear message.
    guard let path = Bundle.main.path(forResource: "iris", ofType: "csv") else {
        fatalError("iris.csv was not found in the main bundle")
    }
    let csv: String
    do {
        csv = try String(contentsOf: URL(fileURLWithPath: path))
    } catch {
        fatalError("iris.csv could not be read: \(error)")
    }

    let shuffledLines = csv.components(separatedBy: "\n").shuffled()
    for rawLine in shuffledLines {
        let line = rawLine.trimmingCharacters(in: .whitespacesAndNewlines)
        if line.isEmpty { continue }

        let items = line
            .components(separatedBy: ",")
            .map { $0.trimmingCharacters(in: .whitespaces) }
        guard items.count >= 5 else { continue } // too few columns

        let parameters = items[0...3].compactMap { Double($0) }
        guard parameters.count == 4 else { continue } // header or non-numeric row

        irisParameters.append(parameters)

        // One-hot encode the species; any label other than the first two maps to the third slot.
        let species = items[4]
        switch species {
        case "Iris-setosa":
            irisClassifications.append([1.0, 0.0, 0.0])
        case "Iris-versicolor":
            irisClassifications.append([0.0, 1.0, 0.0])
        default:
            irisClassifications.append([0.0, 0.0, 1.0])
        }
        irisSpecies.append(species)
    }
    normalizeByColumnMax(dataset: &irisParameters)
}
349
380
350
- func interpretOutput( output: [ Double ] ) -> Int {
381
/// Maps a network output vector to a species name according to where its largest
/// value sits: index 0 gives "Iris-setosa", index 1 gives "Iris-versicolor",
/// and anything else gives "Iris-virginica".
func interpretOutput(output: [Double]) -> String {
    let peak = output.max()!
    if peak == output[0] { return "Iris-setosa" }
    if peak == output[1] { return "Iris-versicolor" }
    return "Iris-virginica"
}
359
390
360
- parseWineCSV ( )
361
- // train over entire data set 5 times
362
- for _ in 0 ..< 5 {
363
- network. train ( inputs: wineParameters, expecteds: wineClassifications, printError: false )
391
// Load the iris data set into the global arrays before training
parseIrisCSV()

// train over the first 140 irises in the data set, 20 passes
let trainers = Array(irisParameters[0..<140])
let trainersCorrects = Array(irisClassifications[0..<140])
(0..<20).forEach { _ in
    network.train(inputs: trainers, expecteds: trainersCorrects, printError: false)
}

// test over irises 140 through 149 of the data set
let testers = Array(irisParameters[140..<150])
let testersCorrects = Array(irisSpecies[140..<150])
let results = network.validate(inputs: testers, expecteds: testersCorrects, interpretOutput: interpretOutput)
print("\(results.correct) correct of \(results.total) = \(results.percentage * 100)%")
368
405
369
406
//: [Next](@next)
0 commit comments