Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 41c8d05

Browse files
committed
cleaned up iris example
1 parent 8d5f17d commit 41c8d05

File tree

1 file changed

+24
-24
lines changed
  • Classic Computer Science Problems in Swift.playground/Pages/Chapter 7.xcplaygroundpage

1 file changed

+24
-24
lines changed

‎Classic Computer Science Problems in Swift.playground/Pages/Chapter 7.xcplaygroundpage/Contents.swift

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -255,31 +255,28 @@ class Network {
255255
/// MARK: Normalization
256256

257257
/// assumes all rows are of equal length
258-
/// and divide each column by its max throughout the data set
259-
/// for that column
260-
func normalizeByColumnMax( dataset:inout [[Double]]) {
258+
/// and feature scale each column to be in the range 0 – 1
259+
func normalizeByFeatureScaling(dataset: inout [[Double]]) {
261260
for colNum in 0..<dataset[0].count {
262261
let column = dataset.map { 0ドル[colNum] }
263262
let maximum = column.max()!
263+
let minimum = column.min()!
264264
for rowNum in 0..<dataset.count {
265-
dataset[rowNum][colNum] = dataset[rowNum][colNum] /maximum
265+
dataset[rowNum][colNum] = (dataset[rowNum][colNum] - minimum)/(maximum- minimum)
266266
}
267267
}
268268
}
269269

270270
// MARK: Iris Test
271271

272-
var network: Network = Network(layerStructure: [4,6,3], learningRate: 0.3)
273-
var irisParameters: [[Double]] = [[Double]]()
274-
var irisClassifications: [[Double]] = [[Double]]()
275-
var irisSpecies: [String] = [String]()
276-
277-
func parseIrisCSV() {
278-
let myBundle = Bundle.main
279-
let urlpath = myBundle.path(forResource: "iris", ofType: "csv")
272+
func parseIrisCSV() -> (parameters: [[Double]], classifications: [[Double]], species: [String]) {
273+
let urlpath = Bundle.main.path(forResource: "iris", ofType: "csv")
280274
let url = URL(fileURLWithPath: urlpath!)
281275
let csv = try! String.init(contentsOf: url)
282276
let lines = csv.components(separatedBy: "\n")
277+
var irisParameters: [[Double]] = [[Double]]()
278+
var irisClassifications: [[Double]] = [[Double]]()
279+
var irisSpecies: [String] = [String]()
283280

284281
let shuffledLines = lines.shuffled()
285282
for line in shuffledLines {
@@ -297,10 +294,15 @@ func parseIrisCSV() {
297294
}
298295
irisSpecies.append(species)
299296
}
300-
normalizeByColumnMax(dataset: &irisParameters)
297+
normalizeByFeatureScaling(dataset: &irisParameters)
298+
return (irisParameters, irisClassifications, irisSpecies)
301299
}
302300

303-
func interpretOutput(output: [Double]) -> String {
301+
let (irisParameters, irisClassifications, irisSpecies) = parseIrisCSV()
302+
303+
var irisNetwork: Network = Network(layerStructure: [4,6,3], learningRate: 0.3)
304+
305+
func irisInterpretOutput(output: [Double]) -> String {
304306
if output.max()! == output[0] {
305307
return "Iris-setosa"
306308
} else if output.max()! == output[1] {
@@ -310,20 +312,18 @@ func interpretOutput(output: [Double]) -> String {
310312
}
311313
}
312314

313-
// Put setup code here. This method is called before the invocation of each test method in the class.
314-
parseIrisCSV()
315315
// train over first 140 irises in data set 20 times
316-
let trainers = Array(irisParameters[0..<140])
317-
let trainersCorrects = Array(irisClassifications[0..<140])
316+
let irisTrainers = Array(irisParameters[0..<140])
317+
let irisTrainersCorrects = Array(irisClassifications[0..<140])
318318
for _ in 0..<20 {
319-
network.train(inputs: trainers, expecteds: trainersCorrects, printError: false)
319+
irisNetwork.train(inputs: irisTrainers, expecteds: irisTrainersCorrects, printError: false)
320320
}
321321

322322
// test over the last 10 of the irses in the data set
323-
let testers = Array(irisParameters[140..<150])
324-
let testersCorrects = Array(irisSpecies[140..<150])
325-
let results = network.validate(inputs: testers, expecteds: testersCorrects, interpretOutput: interpretOutput)
326-
print("\(results.correct) correct of \(results.total) = \(results.percentage * 100)%")
323+
let irisTesters = Array(irisParameters[140..<150])
324+
let irisTestersCorrects = Array(irisSpecies[140..<150])
325+
let irisResults = irisNetwork.validate(inputs: irisTesters, expecteds: irisTestersCorrects, interpretOutput: irisInterpretOutput)
326+
print("\(irisResults.correct) correct of \(irisResults.total) = \(irisResults.percentage * 100)%")
327327

328328
/// Wine Test
329329

@@ -358,7 +358,7 @@ print("\(results.correct) correct of \(results.total) = \(results.percentage * 1
358358
// }
359359
// wineCultivars.append(species)
360360
// }
361-
// normalizeByColumnMax(dataset: &wineParameters)
361+
// normalizeByFeatureScaling(dataset: &wineParameters)
362362
// wineSamples = Array(wineParameters.dropFirst(150))
363363
// wineCultivars = Array(wineCultivars.dropFirst(150))
364364
// wineParameters = Array(wineParameters.dropLast(28))

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /