Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 44b7763

Browse files
committed
cleaned up wine example
1 parent 41c8d05 commit 44b7763

File tree

1 file changed

+91
-92
lines changed
  • Classic Computer Science Problems in Swift.playground/Pages/Chapter 7.xcplaygroundpage

1 file changed

+91
-92
lines changed

‎Classic Computer Science Problems in Swift.playground/Pages/Chapter 7.xcplaygroundpage/Contents.swift

Lines changed: 91 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -269,119 +269,118 @@ func normalizeByFeatureScaling(dataset: inout [[Double]]) {
269269

270270
// MARK: Iris Test
271271

272-
func parseIrisCSV() -> (parameters: [[Double]], classifications: [[Double]], species: [String]) {
273-
let urlpath = Bundle.main.path(forResource: "iris", ofType: "csv")
272+
//func parseIrisCSV() -> (parameters: [[Double]], classifications: [[Double]], species: [String]) {
273+
// let urlpath = Bundle.main.path(forResource: "iris", ofType: "csv")
274+
// let url = URL(fileURLWithPath: urlpath!)
275+
// let csv = try! String.init(contentsOf: url)
276+
// let lines = csv.components(separatedBy: "\n")
277+
// var irisParameters: [[Double]] = [[Double]]()
278+
// var irisClassifications: [[Double]] = [[Double]]()
279+
// var irisSpecies: [String] = [String]()
280+
//
281+
// let shuffledLines = lines.shuffled()
282+
// for line in shuffledLines {
283+
// if line == "" { continue } // skip blank lines
284+
// let items = line.components(separatedBy: ",")
285+
// let parameters = items[0...3].map{ Double(0ドル)! }
286+
// irisParameters.append(parameters)
287+
// let species = items[4]
288+
// if species == "Iris-setosa" {
289+
// irisClassifications.append([1.0, 0.0, 0.0])
290+
// } else if species == "Iris-versicolor" {
291+
// irisClassifications.append([0.0, 1.0, 0.0])
292+
// } else {
293+
// irisClassifications.append([0.0, 0.0, 1.0])
294+
// }
295+
// irisSpecies.append(species)
296+
// }
297+
// normalizeByFeatureScaling(dataset: &irisParameters)
298+
// return (irisParameters, irisClassifications, irisSpecies)
299+
//}
300+
//
301+
//let (irisParameters, irisClassifications, irisSpecies) = parseIrisCSV()
302+
//
303+
//let irisNetwork: Network = Network(layerStructure: [4, 6, 3], learningRate: 0.3)
304+
//
305+
//func irisInterpretOutput(output: [Double]) -> String {
306+
// if output.max()! == output[0] {
307+
// return "Iris-setosa"
308+
// } else if output.max()! == output[1] {
309+
// return "Iris-versicolor"
310+
// } else {
311+
// return "Iris-virginica"
312+
// }
313+
//}
314+
//
315+
//// train over first 140 irises in data set 20 times
316+
//let irisTrainers = Array(irisParameters[0..<140])
317+
//let irisTrainersCorrects = Array(irisClassifications[0..<140])
318+
//for _ in 0..<20 {
319+
// irisNetwork.train(inputs: irisTrainers, expecteds: irisTrainersCorrects, printError: false)
320+
//}
321+
//
322+
//// test over the last 10 of the irses in the data set
323+
//let irisTesters = Array(irisParameters[140..<150])
324+
//let irisTestersCorrects = Array(irisSpecies[140..<150])
325+
//let irisResults = irisNetwork.validate(inputs: irisTesters, expecteds: irisTestersCorrects, interpretOutput: irisInterpretOutput)
326+
//print("\(irisResults.correct) correct of \(irisResults.total) = \(irisResults.percentage * 100)%")
327+
328+
/// Wine Test
329+
330+
func parseWineCSV() -> (parameters: [[Double]], classifications: [[Double]], species: [Int]) {
331+
let urlpath = Bundle.main.path(forResource: "wine", ofType: "csv")
274332
let url = URL(fileURLWithPath: urlpath!)
275333
let csv = try! String.init(contentsOf: url)
276334
let lines = csv.components(separatedBy: "\n")
277-
var irisParameters: [[Double]] = [[Double]]()
278-
var irisClassifications: [[Double]] = [[Double]]()
279-
var irisSpecies: [String] = [String]()
280-
335+
var wineParameters: [[Double]] = [[Double]]()
336+
var wineClassifications: [[Double]] = [[Double]]()
337+
var wineSpecies: [Int] = [Int]()
338+
281339
let shuffledLines = lines.shuffled()
282340
for line in shuffledLines {
283-
if line == "" { continue }
341+
if line == "" { continue } // skip blank lines
284342
let items = line.components(separatedBy: ",")
285-
let parameters = items[0...3].map{ Double(0ドル)! }
286-
irisParameters.append(parameters)
287-
let species = items[4]
288-
if species == "Iris-setosa" {
289-
irisClassifications.append([1.0, 0.0, 0.0])
290-
} else if species == "Iris-versicolor" {
291-
irisClassifications.append([0.0, 1.0, 0.0])
343+
let parameters = items[1...13].map{ Double(0ドル)! }
344+
wineParameters.append(parameters)
345+
let species = Int(items[0])!
346+
if species == 1 {
347+
wineClassifications.append([1.0, 0.0, 0.0])
348+
} else if species == 2 {
349+
wineClassifications.append([0.0, 1.0, 0.0])
292350
} else {
293-
irisClassifications.append([0.0, 0.0, 1.0])
351+
wineClassifications.append([0.0, 0.0, 1.0])
294352
}
295-
irisSpecies.append(species)
353+
wineSpecies.append(species)
296354
}
297-
normalizeByFeatureScaling(dataset: &irisParameters)
298-
return (irisParameters, irisClassifications, irisSpecies)
355+
normalizeByFeatureScaling(dataset: &wineParameters)
356+
return (wineParameters, wineClassifications, wineSpecies)
299357
}
300358

301-
let (irisParameters, irisClassifications, irisSpecies) = parseIrisCSV()
359+
let (wineParameters, wineClassifications, wineSpecies) = parseWineCSV()
302360

303-
varirisNetwork: Network = Network(layerStructure: [4,6,3], learningRate: 0.3)
361+
letwineNetwork: Network = Network(layerStructure: [13,7,3], learningRate: 0.9)
304362

305-
func irisInterpretOutput(output: [Double]) -> String {
363+
func wineInterpretOutput(output: [Double]) -> Int {
306364
if output.max()! == output[0] {
307-
return "Iris-setosa"
365+
return 1
308366
} else if output.max()! == output[1] {
309-
return "Iris-versicolor"
367+
return 2
310368
} else {
311-
return "Iris-virginica"
369+
return 3
312370
}
313371
}
314372

315-
// train over first 140 irises in data set 20 times
316-
let irisTrainers = Array(irisParameters[0..<140])
317-
let irisTrainersCorrects = Array(irisClassifications[0..<140])
318-
for _ in 0..<20 {
319-
irisNetwork.train(inputs: irisTrainers, expecteds: irisTrainersCorrects, printError: false)
373+
// train over the first 150 samples 5 times
374+
let wineTrainers = Array(wineParameters.dropLast(28))
375+
let wineTrainersCorrects = Array(wineClassifications.dropLast(28))
376+
for _ in 0..<5 {
377+
wineNetwork.train(inputs: wineTrainers, expecteds: wineTrainersCorrects, printError: false)
320378
}
321379

322-
// test over the last 10 of the irses in the data set
323-
let irisTesters = Array(irisParameters[140..<150])
324-
let irisTestersCorrects = Array(irisSpecies[140..<150])
325-
let irisResults = irisNetwork.validate(inputs: irisTesters, expecteds: irisTestersCorrects, interpretOutput: irisInterpretOutput)
326-
print("\(irisResults.correct) correct of \(irisResults.total) = \(irisResults.percentage * 100)%")
327-
328-
/// Wine Test
329-
330-
//var network: Network = Network(layerStructure: [13,7,3], learningRate: 7.0)
331-
//// for training
332-
//var wineParameters: [[Double]] = [[Double]]()
333-
//var wineClassifications: [[Double]] = [[Double]]()
334-
//// for testing/validation
335-
//var wineSamples: [[Double]] = [[Double]]()
336-
//var wineCultivars: [Int] = [Int]()
337-
//
338-
//func parseWineCSV() {
339-
// let myBundle = Bundle.main
340-
// let urlpath = myBundle.path(forResource: "wine", ofType: "csv")
341-
// let url = URL(fileURLWithPath: urlpath!)
342-
// let csv = try! String.init(contentsOf: url)
343-
// let lines = csv.components(separatedBy: "\n")
344-
//
345-
// let shuffledLines = lines.shuffled()
346-
// for line in shuffledLines {
347-
// if line == "" { continue }
348-
// let items = line.components(separatedBy: ",")
349-
// let parameters = items[1...13].map{ Double(0ドル)! }
350-
// wineParameters.append(parameters)
351-
// let species = Int(items[0])!
352-
// if species == 1 {
353-
// wineClassifications.append([1.0, 0.0, 0.0])
354-
// } else if species == 2 {
355-
// wineClassifications.append([0.0, 1.0, 0.0])
356-
// } else {
357-
// wineClassifications.append([0.0, 0.0, 1.0])
358-
// }
359-
// wineCultivars.append(species)
360-
// }
361-
// normalizeByFeatureScaling(dataset: &wineParameters)
362-
// wineSamples = Array(wineParameters.dropFirst(150))
363-
// wineCultivars = Array(wineCultivars.dropFirst(150))
364-
// wineParameters = Array(wineParameters.dropLast(28))
365-
//}
366-
//
367-
//func interpretOutput(output: [Double]) -> Int {
368-
// if output.max()! == output[0] {
369-
// return 1
370-
// } else if output.max()! == output[1] {
371-
// return 2
372-
// } else {
373-
// return 3
374-
// }
375-
//}
376-
//
377-
//parseWineCSV()
378-
//// train over entire data set 5 times
379-
//for _ in 0..<5 {
380-
// network.train(inputs: wineParameters, expecteds: wineClassifications, printError: false)
381-
//}
382-
//
383-
//let results = network.validate(inputs: wineSamples, expecteds: wineCultivars, interpretOutput: interpretOutput)
384-
//print("\(results.correct) correct of \(results.total) = \(results.percentage * 100)%")
380+
let wineTesters = Array(wineParameters.dropFirst(150))
381+
let wineTestersCorrects = Array(wineSpecies.dropFirst(150))
382+
let results = wineNetwork.validate(inputs: wineTesters, expecteds: wineTestersCorrects, interpretOutput: wineInterpretOutput)
383+
print("\(results.correct) correct of \(results.total) = \(results.percentage * 100)%")
385384

386385
//: [Next](@next)
387386

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /