@@ -255,31 +255,28 @@ class Network {
255
255
/// MARK: Normalization
256
256
257
257
/// assumes all rows are of equal length
258
- /// and divide each column by its max throughout the data set
259
- /// for that column
260
- func normalizeByColumnMax( dataset: inout [ [ Double ] ] ) {
258
+ /// and feature scale each column to be in the range 0 – 1
259
+ func normalizeByFeatureScaling( dataset: inout [ [ Double ] ] ) {
261
260
for colNum in 0 ..< dataset [ 0 ] . count {
262
261
let column = dataset. map { 0ドル [ colNum] }
263
262
let maximum = column. max ( ) !
263
+ let minimum = column. min ( ) !
264
264
for rowNum in 0 ..< dataset. count {
265
- dataset [ rowNum] [ colNum] = dataset [ rowNum] [ colNum] / maximum
265
+ dataset [ rowNum] [ colNum] = ( dataset [ rowNum] [ colNum] - minimum ) / ( maximum- minimum )
266
266
}
267
267
}
268
268
}
269
269
270
270
// MARK: Iris Test
271
271
272
- var network : Network = Network ( layerStructure: [ 4 , 6 , 3 ] , learningRate: 0.3 )
273
- var irisParameters : [ [ Double ] ] = [ [ Double] ] ( )
274
- var irisClassifications : [ [ Double ] ] = [ [ Double] ] ( )
275
- var irisSpecies : [ String ] = [ String] ( )
276
-
277
- func parseIrisCSV( ) {
278
- let myBundle = Bundle . main
279
- let urlpath = myBundle. path ( forResource: " iris " , ofType: " csv " )
272
+ func parseIrisCSV( ) -> ( parameters: [ [ Double ] ] , classifications: [ [ Double ] ] , species: [ String ] ) {
273
+ let urlpath = Bundle . main. path ( forResource: " iris " , ofType: " csv " )
280
274
let url = URL ( fileURLWithPath: urlpath!)
281
275
let csv = try ! String . init ( contentsOf: url)
282
276
let lines = csv. components ( separatedBy: " \n " )
277
+ var irisParameters : [ [ Double ] ] = [ [ Double] ] ( )
278
+ var irisClassifications : [ [ Double ] ] = [ [ Double] ] ( )
279
+ var irisSpecies : [ String ] = [ String] ( )
283
280
284
281
let shuffledLines = lines. shuffled ( )
285
282
for line in shuffledLines {
@@ -297,10 +294,15 @@ func parseIrisCSV() {
297
294
}
298
295
irisSpecies. append ( species)
299
296
}
300
- normalizeByColumnMax ( dataset: & irisParameters)
297
+ normalizeByFeatureScaling ( dataset: & irisParameters)
298
+ return ( irisParameters, irisClassifications, irisSpecies)
301
299
}
302
300
303
- func interpretOutput( output: [ Double ] ) -> String {
301
+ let ( irisParameters, irisClassifications, irisSpecies) = parseIrisCSV ( )
302
+
303
+ var irisNetwork : Network = Network ( layerStructure: [ 4 , 6 , 3 ] , learningRate: 0.3 )
304
+
305
+ func irisInterpretOutput( output: [ Double ] ) -> String {
304
306
if output. max ( ) ! == output [ 0 ] {
305
307
return " Iris-setosa "
306
308
} else if output. max ( ) ! == output [ 1 ] {
@@ -310,20 +312,18 @@ func interpretOutput(output: [Double]) -> String {
310
312
}
311
313
}
312
314
313
- // Put setup code here. This method is called before the invocation of each test method in the class.
314
- parseIrisCSV ( )
315
315
// train over first 140 irises in data set 20 times
316
- let trainers = Array ( irisParameters [ 0 ..< 140 ] )
317
- let trainersCorrects = Array ( irisClassifications [ 0 ..< 140 ] )
316
+ let irisTrainers = Array ( irisParameters [ 0 ..< 140 ] )
317
+ let irisTrainersCorrects = Array ( irisClassifications [ 0 ..< 140 ] )
318
318
for _ in 0 ..< 20 {
319
- network . train ( inputs: trainers , expecteds: trainersCorrects , printError: false )
319
+ irisNetwork . train ( inputs: irisTrainers , expecteds: irisTrainersCorrects , printError: false )
320
320
}
321
321
322
322
// test over the last 10 of the irses in the data set
323
- let testers = Array ( irisParameters [ 140 ..< 150 ] )
324
- let testersCorrects = Array ( irisSpecies [ 140 ..< 150 ] )
325
- let results = network . validate ( inputs: testers , expecteds: testersCorrects , interpretOutput: interpretOutput )
326
- print ( " \( results . correct) correct of \( results . total) = \( results . percentage * 100 ) % " )
323
+ let irisTesters = Array ( irisParameters [ 140 ..< 150 ] )
324
+ let irisTestersCorrects = Array ( irisSpecies [ 140 ..< 150 ] )
325
+ let irisResults = irisNetwork . validate ( inputs: irisTesters , expecteds: irisTestersCorrects , interpretOutput: irisInterpretOutput )
326
+ print ( " \( irisResults . correct) correct of \( irisResults . total) = \( irisResults . percentage * 100 ) % " )
327
327
328
328
/// Wine Test
329
329
@@ -358,7 +358,7 @@ print("\(results.correct) correct of \(results.total) = \(results.percentage * 1
358
358
// }
359
359
// wineCultivars.append(species)
360
360
// }
361
- // normalizeByColumnMax (dataset: &wineParameters)
361
+ // normalizeByFeatureScaling (dataset: &wineParameters)
362
362
// wineSamples = Array(wineParameters.dropFirst(150))
363
363
// wineCultivars = Array(wineCultivars.dropFirst(150))
364
364
// wineParameters = Array(wineParameters.dropLast(28))
0 commit comments