Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 1ada79c

Browse files
committed
Replaced random & shuffle with versions from previous chapters; tuned wine and iris examples.
1 parent adaf753 commit 1ada79c

File tree

1 file changed

+119
-82
lines changed
  • Classic Computer Science Problems in Swift.playground/Pages/Chapter 7.xcplaygroundpage

1 file changed

+119
-82
lines changed

‎Classic Computer Science Problems in Swift.playground/Pages/Chapter 7.xcplaygroundpage/Contents.swift

Lines changed: 119 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -23,20 +23,42 @@ import Foundation
2323

2424
// MARK: Randomization & Statistical Helpers
2525

26+
struct Random {
27+
private static var seeded = false
28+
29+
// a random Double between *from* and *to*, assumes *from* < *to*
30+
static func double(from: Double, to: Double) -> Double {
31+
if !Random.seeded {
32+
srand48(time(nil))
33+
Random.seeded = true
34+
}
35+
36+
return (drand48() * (to - from)) + from
37+
}
38+
}
39+
2640
/// Create *number* of random Doubles between 0.0 and 1.0
2741
func randomWeights(number: Int) -> [Double] {
28-
return (0..<number).map{ _ in Math.randomFractional() }
42+
return (0..<number).map{ _ in Random.double(from:0.0, to:1.0) }
2943
}
3044

3145
/// Create *number* of random Doubles between 0.0 and *limit*
3246
func randomNums(number: Int, limit: Double) -> [Double] {
33-
return (0..<number).map{ _ in Math.randomTo(limit: limit) }
47+
return (0..<number).map{ _ in Random.double(from:0.0, to: limit) }
3448
}
3549

36-
/// primitive shuffle - not fisher yates... not uniform distribution
37-
extension Sequence where Iterator.Element : Comparable {
38-
var shuffled: [Self.Iterator.Element] {
39-
return sorted { _, _ in arc4random() % 2 == 0 }
50+
// A derivative of the Fisher-Yates algorithm to shuffle an array
51+
extension Array {
52+
public func shuffled() -> Array<Element> {
53+
var shuffledArray = self // value semantics (Array is Struct) makes this a copy
54+
if count < 2 { return shuffledArray } // already shuffled
55+
for i in (1..<count).reversed() { // count backwards
56+
let position = Int(arc4random_uniform(UInt32(i + 1))) // random to swap
57+
if i != position { // swap with the end, don't bother with selp swaps
58+
shuffledArray.swapAt(i, position)
59+
}
60+
}
61+
return shuffledArray
4062
}
4163
}
4264

@@ -106,36 +128,7 @@ public func sum(x: [Double]) -> Double {
106128
return result
107129
}
108130

109-
// MARK: Random Number Generation
110131

111-
// this struct & the randomFractional() function
112-
// based on http://stackoverflow.com/a/35919911/281461
113-
struct Math {
114-
private static var seeded = false
115-
116-
static func randomFractional() -> Double {
117-
118-
if !Math.seeded {
119-
let time = Int(NSDate().timeIntervalSinceReferenceDate)
120-
srand48(time)
121-
Math.seeded = true
122-
}
123-
124-
return drand48()
125-
}
126-
127-
// addition, just multiplies random number by *limit*
128-
static func randomTo(limit: Double) -> Double {
129-
130-
if !Math.seeded {
131-
let time = Int(NSDate().timeIntervalSinceReferenceDate)
132-
srand48(time)
133-
Math.seeded = true
134-
}
135-
136-
return drand48() * limit
137-
}
138-
}
139132

140133
/// An individual node in a layer
141134
class Neuron {
@@ -282,8 +275,7 @@ class Network {
282275

283276
/// for generalized results that require classification
284277
/// this function will return the correct number of trials
285-
/// and the percentge correct out of the total
286-
/// See the unit tests for some examples
278+
/// and the percentage correct out of the total
287279
func validate<T: Equatable>(inputs:[[Double]], expecteds:[T], interpretOutput: ([Double]) -> T) -> (correct: Int, total: Int, percentage: Double) {
288280
var correct = 0
289281
for (input, expected) in zip(inputs, expecteds) {
@@ -295,75 +287,120 @@ class Network {
295287
let percentage = Double(correct) / Double(inputs.count)
296288
return (correct, inputs.count, percentage)
297289
}
298-
299-
// for when result is a single neuron
300-
func validate(inputs:[[Double]], expecteds:[Double], accuracy: Double) -> (correct: Int, total: Int, percentage: Double) {
301-
var correct = 0
302-
for (input, expected) in zip(inputs, expecteds) {
303-
let result = outputs(input: input)[0]
304-
if abs(expected - result) < accuracy {
305-
correct += 1
306-
}
307-
}
308-
let percentage = Double(correct) / Double(inputs.count)
309-
return (correct, inputs.count, percentage)
310-
}
311290
}
312291

313-
var network: Network = Network(layerStructure: [13,7,3], learningRate: 7.0)
314-
// for training
315-
var wineParameters: [[Double]] = [[Double]]()
316-
var wineClassifications: [[Double]] = [[Double]]()
317-
// for testing/validation
318-
var wineSamples: [[Double]] = [[Double]]()
319-
var wineCultivars: [Int] = [Int]()
292+
/// Wine Test
293+
294+
//var network: Network = Network(layerStructure: [13,7,3], learningRate: 7.0)
295+
//// for training
296+
//var wineParameters: [[Double]] = [[Double]]()
297+
//var wineClassifications: [[Double]] = [[Double]]()
298+
//// for testing/validation
299+
//var wineSamples: [[Double]] = [[Double]]()
300+
//var wineCultivars: [Int] = [Int]()
301+
//
302+
//func parseWineCSV() {
303+
// let myBundle = Bundle.main
304+
// let urlpath = myBundle.path(forResource: "wine", ofType: "csv")
305+
// let url = URL(fileURLWithPath: urlpath!)
306+
// let csv = try! String.init(contentsOf: url)
307+
// let lines = csv.components(separatedBy: "\n")
308+
//
309+
// let shuffledLines = lines.shuffled()
310+
// for line in shuffledLines {
311+
// if line == "" { continue }
312+
// let items = line.components(separatedBy: ",")
313+
// let parameters = items[1...13].map{ Double(0ドル)! }
314+
// wineParameters.append(parameters)
315+
// let species = Int(items[0])!
316+
// if species == 1 {
317+
// wineClassifications.append([1.0, 0.0, 0.0])
318+
// } else if species == 2 {
319+
// wineClassifications.append([0.0, 1.0, 0.0])
320+
// } else {
321+
// wineClassifications.append([0.0, 0.0, 1.0])
322+
// }
323+
// wineCultivars.append(species)
324+
// }
325+
// normalizeByColumnMax(dataset: &wineParameters)
326+
// wineSamples = Array(wineParameters.dropFirst(150))
327+
// wineCultivars = Array(wineCultivars.dropFirst(150))
328+
// wineParameters = Array(wineParameters.dropLast(28))
329+
//}
330+
//
331+
//func interpretOutput(output: [Double]) -> Int {
332+
// if output.max()! == output[0] {
333+
// return 1
334+
// } else if output.max()! == output[1] {
335+
// return 2
336+
// } else {
337+
// return 3
338+
// }
339+
//}
340+
//
341+
//parseWineCSV()
342+
//// train over entire data set 5 times
343+
//for _ in 0..<5 {
344+
// network.train(inputs: wineParameters, expecteds: wineClassifications, printError: false)
345+
//}
346+
//
347+
//let results = network.validate(inputs: wineSamples, expecteds: wineCultivars, interpretOutput: interpretOutput)
348+
//print("\(results.correct) correct of \(results.total) = \(results.percentage * 100)%")
349+
350+
var network: Network = Network(layerStructure: [4,5,3], learningRate: 0.3)
351+
var irisParameters: [[Double]] = [[Double]]()
352+
var irisClassifications: [[Double]] = [[Double]]()
353+
var irisSpecies: [String] = [String]()
320354

321-
func parseWineCSV() {
355+
func parseIrisCSV() {
322356
let myBundle = Bundle.main
323-
let urlpath = myBundle.path(forResource: "wine", ofType: "csv")
357+
let urlpath = myBundle.path(forResource: "iris", ofType: "csv")
324358
let url = URL(fileURLWithPath: urlpath!)
325359
let csv = try! String.init(contentsOf: url)
326360
let lines = csv.components(separatedBy: "\n")
327361

328-
let shuffledLines = lines.shuffled
362+
let shuffledLines = lines.shuffled()
329363
for line in shuffledLines {
330364
if line == "" { continue }
331365
let items = line.components(separatedBy: ",")
332-
let parameters = items[1...13].map{ Double(0ドル)! }
333-
wineParameters.append(parameters)
334-
let species = Int(items[0])!
335-
if species == 1 {
336-
wineClassifications.append([1.0, 0.0, 0.0])
337-
} else if species == 2 {
338-
wineClassifications.append([0.0, 1.0, 0.0])
366+
let parameters = items[0...3].map{ Double(0ドル)! }
367+
irisParameters.append(parameters)
368+
let species = items[4]
369+
if species == "Iris-setosa" {
370+
irisClassifications.append([1.0, 0.0, 0.0])
371+
} else if species == "Iris-versicolor" {
372+
irisClassifications.append([0.0, 1.0, 0.0])
339373
} else {
340-
wineClassifications.append([0.0, 0.0, 1.0])
374+
irisClassifications.append([0.0, 0.0, 1.0])
341375
}
342-
wineCultivars.append(species)
376+
irisSpecies.append(species)
343377
}
344-
normalizeByColumnMax(dataset: &wineParameters)
345-
wineSamples = Array(wineParameters.dropFirst(150))
346-
wineCultivars = Array(wineCultivars.dropFirst(150))
347-
wineParameters = Array(wineParameters.dropLast(28))
378+
normalizeByColumnMax(dataset: &irisParameters)
348379
}
349380

350-
func interpretOutput(output: [Double]) -> Int {
381+
func interpretOutput(output: [Double]) -> String {
351382
if output.max()! == output[0] {
352-
return 1
383+
return "Iris-setosa"
353384
} else if output.max()! == output[1] {
354-
return 2
385+
return "Iris-versicolor"
355386
} else {
356-
return 3
387+
return "Iris-virginica"
357388
}
358389
}
359390

360-
parseWineCSV()
361-
// train over entire data set 5 times
362-
for _ in 0..<5 {
363-
network.train(inputs: wineParameters, expecteds: wineClassifications, printError: false)
391+
// Put setup code here. This method is called before the invocation of each test method in the class.
392+
parseIrisCSV()
393+
// train over first 140 irises in data set 20 times
394+
let trainers = Array(irisParameters[0..<140])
395+
let trainersCorrects = Array(irisClassifications[0..<140])
396+
for _ in 0..<20 {
397+
network.train(inputs: trainers, expecteds: trainersCorrects, printError: false)
364398
}
365399

366-
let results = network.validate(inputs: wineSamples, expecteds: wineCultivars, interpretOutput: interpretOutput)
400+
// test over the last 10 of the irses in the data set
401+
let testers = Array(irisParameters[140..<150])
402+
let testersCorrects = Array(irisSpecies[140..<150])
403+
let results = network.validate(inputs: testers, expecteds: testersCorrects, interpretOutput: interpretOutput)
367404
print("\(results.correct) correct of \(results.total) = \(results.percentage * 100)%")
368405

369406
//: [Next](@next)

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /