\$\begingroup\$

As an effort to teach myself Swift and to get familiar with machine learning algorithms, I've been trying to implement common algorithms, starting with a Random Forest. For the moment this is just a single tree, but I have been trying to implement it from the theory alone, without looking at pseudo-code, in order to really understand the process.

It was harder than I thought, due to the lack of the convenient statistical and data-related functions and methods that are common in R or Python. This code seems to work and builds correct trees, although some of the methods I use (lots of mapping...) sometimes seem a bit convoluted.

First, a node to store the splits:

import Foundation
class Node: CustomStringConvertible
{
 let isTerminal:Bool
 var value:Double? = nil
 var leftChild:Node? = nil
 var rightChild:Node? = nil
 var variable:Int? = nil
 var description: String
 var result:Int? = nil
 init(value:Double, variable:Int)
 {
 //Split node
 self.value = value
 self.isTerminal = false
 self.variable = variable
 self.description = "\(variable): \(value)"
 }
 init(result: Int)
 {
 //Terminal node
 self.result = result
 self.isTerminal = true
 self.description = "Terminal node: \(result)\n"
 }
 func addLeftChild(child:Node)
 {
 self.leftChild = child
 self.description += " L -> \(child)\n"
 }
 func addRightChild(child:Node)
 {
 self.rightChild = child
 self.description += " R -> \(child)\n"
 }
 //For prediction
 func getChild(x:Double) -> Node
 {
 if x < value
 {
 return leftChild!
 }
 else
 {
 return rightChild!
 }
 }
}

Some data taken from the famous Iris dataset:

let x = [[6.8,6.2,5.9,5.9,5.7,7.7,4.5,5.8,5,6.3,5.1,4.3,5.7,4.9,7],
 [3,3.4,3.2,3,2.6,3,2.3,2.7,2.3,2.5,3.8,3,3.8,3,3.2],
 [5.5,5.4,4.8,5.1,3.5,6.1,1.3,4.1,3.3,4.9,1.9,1.1,1.7,1.4,4.7],
 [2.1,2.3,1.8,1.8,1,2.3,0.3,1,1,1.5,0.4,0.1,0.3,0.2,1.4]]
let y = [3,3,2,3,2,3,1,2,2,2,1,1,1,1,2]

The impurity criterion (gini) function:

func giniImpurity(y:[Int]) -> Double
{
 let len = Double(y.count)
 let countedSet = NSCountedSet(array: y)
 let squaredProbs = countedSet.map { (c) -> Double in
 let cnt = Double(countedSet.countForObject(c)) / len
 return cnt * cnt
 }
 return 1 - squaredProbs.reduce(0, combine: +)
}
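
As a cross-check on the function above, here is a minimal sketch of the same quantity (one minus the sum of squared class proportions) written without NSCountedSet. It assumes Swift 5 or later, which is newer than the syntax used in this question, and the name giniImpurityCounted is made up for illustration. Unlike the function above, it treats an empty group as pure, which is an assumption of the sketch.

```swift
// Sketch only: assumes Swift 5+ (default-value dictionary subscripts).
// `giniImpurityCounted` is a hypothetical name, not part of the question's code.
func giniImpurityCounted(_ labels: [Int]) -> Double {
    // Assumption: an empty group counts as perfectly pure (impurity 0).
    guard !labels.isEmpty else { return 0 }

    // Count how often each class label occurs.
    var counts: [Int: Int] = [:]
    for label in labels {
        counts[label, default: 0] += 1
    }

    // Gini impurity = 1 - sum of squared class proportions.
    let total = Double(labels.count)
    let sumOfSquares = counts.values.reduce(0.0) { partial, count in
        let proportion = Double(count) / total
        return partial + proportion * proportion
    }
    return 1 - sumOfSquares
}
```

For the 15 labels in y above (five 1s, six 2s, four 3s) this works out to 1 - 77/225, roughly 0.658.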

Next, iterating through one column of x to find the best split. I am not sure this is the best way to sort and iterate through the values...

func findBestSplit(x:[Double], y:[Int]) -> (bestVal: Double, maxDelta: Double)
{
 // Find the indices that sort x
 let xSortedIndices = x.indices.sort { x[$0] > x[$1] }
 let xSorted = xSortedIndices.map { x[$0] }
 //Sort y according to those
 let ySorted = xSortedIndices.map { y[$0] }
 var bestGin:Double = 0
 let origini = giniImpurity(y)
 var bestSplit = 0
 //Iterate through all values of x to find the best split
 for i in 0..<ySorted.count
 {
 let left = Array(ySorted[0..<i])
 let right = Array(ySorted[i..<ySorted.count])
 let gini = (giniImpurity(left) * Double(left.count) + giniImpurity(right) * Double(right.count)) / Double(y.count)
 let deltaGini = origini - gini
 if deltaGini > bestGin
 {
 bestGin = deltaGini
 bestSplit = i
 }
 }
 return (bestVal: xSorted[bestSplit], maxDelta: bestGin)
}
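
On the sorting question, one possible alternative to sorting indices is to pair each value with its label and sort the pairs. The sketch below assumes Swift 5 or later, and sortedTogether is a hypothetical helper name. It sorts in ascending order, whereas the function above sorts in descending order; either direction works as long as the split comparison matches it.

```swift
// Sketch only: assumes Swift 5+. `sortedTogether` is a hypothetical helper name.
// Sorts the feature values in ascending order and reorders the labels to match.
func sortedTogether(values: [Double], labels: [Int]) -> (values: [Double], labels: [Int]) {
    precondition(values.count == labels.count, "values and labels must have equal length")
    // Pair each value with its label, then sort the pairs by value.
    let pairs = zip(values, labels).sorted { $0.0 < $1.0 }
    return (pairs.map { $0.0 }, pairs.map { $0.1 })
}
```

This keeps each label attached to its value throughout, so there is no separate index array to keep in step.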

And finally the tree building:

func buildTree(x: [[Double]], y:[Int]) -> Node
{
 var bestVar:Int = 0
 var bestGini:Double = 0
 var bestVal:Double = 0
 // Apply the findBestSplit on all columns to find the best split among those
 for col in 0..<x.count
 {
 let res = findBestSplit(x[col], y: y)
 if res.maxDelta > bestGini
 {
 bestVar = col
 bestGini = res.maxDelta
 bestVal = res.bestVal
 }
 }
 let node = Node(value: bestVal, variable: bestVar)
 //Split X & Y according to the split found
 let rightIndices = x[bestVar].indices.filter { x[bestVar][$0] > bestVal}
 let leftIndices = x[bestVar].indices.filter { x[bestVar][$0] <= bestVal}
 let rightX = x.map { (col) -> [Double] in
 return rightIndices.map {col[$0]}
 }
 let leftX = x.map { (col) -> [Double] in
 return leftIndices.map {col[$0]}
 }
 let rightY = rightIndices.map {y[$0]}
 let leftY = leftIndices.map {y[$0]}
 // If pure enough, add terminal node, else recurse
 if giniImpurity(leftY) < 0.1
 {
 let countedSet = NSCountedSet(array: leftY)
 let counts = countedSet.map { countedSet.countForObject($0) }
 let result = Array(countedSet)[counts.indexOf(counts.maxElement()!)!] as! Int
 node.addLeftChild(Node(result: result))
 }
 else
 {
 node.addLeftChild(buildTree(leftX, y: leftY))
 }
 if giniImpurity(rightY) < 0.1 
 {
 let countedSet = NSCountedSet(array: rightY)
 let counts = countedSet.map { countedSet.countForObject($0) }
 let result = Array(countedSet)[counts.indexOf(counts.maxElement()!)!] as! Int
 node.addRightChild(Node(result: result))
 }
 else
 {
 node.addRightChild(buildTree(rightX, y: rightY))
 }
 return node
}
let root = buildTree(x, y: y)

I would love to have feedback on this, either on the Swift style or on the correctness of the algorithm.

Jamal
asked Apr 2, 2016 at 1:36
\$\endgroup\$
1
  • \$\begingroup\$ Welcome to Code Review! Good job on your first question. \$\endgroup\$ Commented Apr 2, 2016 at 2:25

1 Answer

\$\begingroup\$

So, let's focus on your Node class at the top. I started, without looking at anything else, by simply running SwiftLint on it.

You have a 55-line file with 20 violations.

Fortunately, 19 of them are autocorrectable. Seven of the violations are for opening brace placement. In Swift, we prefer the opening brace to be on the same line rather than on a new line. Another twelve of the violations are for your colon spacing. When declaring variables or parameters, the colon should appear next to the variable name without a space, followed by a space, and then the type.

Running

$ swiftlint autocorrect

results in a file which looks like this:

import Foundation
class Node: CustomStringConvertible {
 let isTerminal: Bool
 var value: Double? = nil
 var leftChild: Node? = nil
 var rightChild: Node? = nil
 var variable: Int? = nil
 var description: String
 var result: Int? = nil
 init(value: Double, variable: Int) {
 //Split node
 self.value = value
 self.isTerminal = false
 self.variable = variable
 self.description = "\(variable): \(value)"
 }
 init(result: Int) {
 //Terminal node
 self.result = result
 self.isTerminal = true
 self.description = "Terminal node: \(result)\n"
 }
 func addLeftChild(child: Node) {
 self.leftChild = child
 self.description += " L -> \(child)\n"
 }
 func addRightChild(child: Node) {
 self.rightChild = child
 self.description += " R -> \(child)\n"
 }
 //For prediction
 func getChild(x: Double) -> Node {
 if x < value {
 return leftChild!
 }
 else {
 return rightChild!
 }
 }
}

But swiftlint still identifies one major error.

Node.swift:39:19: error: Variable Name Violation: Variable name should be between 3 and 40 characters long: 'x' (variable_name)

The parameter name to your getChild function is unacceptably short. x does not serve as a descriptive variable name, and it makes the code hard to read.


But... this linting was with swiftlint's default configuration, which is FAR too lax in my opinion, and egregiously, it lets you get away with force unwrapping!

If I pull in the configuration file that I use *, swiftlint finds six new serious violations.

Node.swift:14:1: error: Space After Comment Violation: There should be a space after // (comments_space)
Node.swift:22:1: error: Space After Comment Violation: There should be a space after // (comments_space)
Node.swift:38:1: error: Space After Comment Violation: There should be a space after // (comments_space)
Node.swift:3:1: error: Empty First Line Violation: There should be an empty line after a declaration (empty_first_line)
Node.swift:41:29: error: Force Unwrapping Violation: Force unwrapping should be avoided. (force_unwrapping)
Node.swift:44:30: error: Force Unwrapping Violation: Force unwrapping should be avoided. (force_unwrapping)

The first four violations here are, I believe, pretty self-explanatory.

The last two are also fairly self-explanatory, but fixing them will probably require a lot more thought on your part about how you're going to handle your child nodes.
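
To make the force-unwrapping point concrete, here is a minimal sketch of one way to avoid the two `!` operators: return an optional and let the caller decide what a missing child means. It assumes Swift 5 or later, and OptionalChildNode and child(for:) are hypothetical names, not part of the code under review.

```swift
// Sketch only: assumes Swift 5+. `OptionalChildNode` and `child(for:)` are
// hypothetical names used for illustration.
final class OptionalChildNode {
    let threshold: Double
    var leftChild: OptionalChildNode?
    var rightChild: OptionalChildNode?

    init(threshold: Double) {
        self.threshold = threshold
    }

    // Returns nil instead of crashing when the requested child is missing.
    func child(for featureValue: Double) -> OptionalChildNode? {
        return featureValue < threshold ? leftChild : rightChild
    }
}
```

A caller would then unwrap with `if let` or `guard let` and handle the nil case explicitly, rather than crashing at runtime.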


And that's where we can start getting into an actual discussion on the logic here.

First, let's handle the simplest problem you have...

var description: String

Anyone can edit this property. Its setter has the same scope as its getter, and really, it shouldn't be writable at all. In fact, it should simply be a computed property which is calculated when it is called.

var description: String {
 return "" // build description of the object
}
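
As a small illustration of that idea, the sketch below builds the description from stored state each time it is read, so nothing outside the type can overwrite it. It assumes Swift 5 or later, LeafNode is a hypothetical type, and the text format is borrowed from the question.

```swift
// Sketch only: assumes Swift 5+. `LeafNode` is a hypothetical type.
struct LeafNode: CustomStringConvertible {
    let result: Int

    // Read-only computed property: derived from `result` on every access,
    // so there is no setter for callers to misuse.
    var description: String {
        return "Terminal node: \(result)"
    }
}
```

Because the type conforms to CustomStringConvertible, string interpolation and print pick up this description automatically.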

Some of your comments make it clear that you're thinking of two different sorts of objects...

init(value: Double, variable: Int) {
 //Split node
init(result: Int) {
 //Terminal node

Importantly, the getChild method doesn't even make sense to call on a terminal node. And the isTerminal property becomes completely obsolete if we just make these two different nodes into two different types.

It also allows us to get rid of all of the optional values (unless it makes sense for a split node to have just one child).

So, I'd propose something approximately looking like this...

protocol Node: CustomStringConvertible {}
class TerminalNode: Node {
 let result: Int
 var description: String {
 // one possible description, reusing the question's format
 return "Terminal node: \(result)"
 }
 init(result: Int) {
 self.result = result
 }
}
class SplitNode: Node {
 let value: Double
 let variable: Int
 let leftChild: Node
 let rightChild: Node
 var description: String {
 // one possible description, reusing the question's format
 return "\(variable): \(value)\n L -> \(leftChild)\n R -> \(rightChild)"
 }
 init(value: Double, variable: Int, leftChild: Node, rightChild: Node) {
 self.value = value
 self.variable = variable
 self.leftChild = leftChild
 self.rightChild = rightChild
 }
 func childForValue(value: Double) -> Node {
 return value < self.value ? leftChild : rightChild
 }
}
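
To show how a split between terminal and split nodes might be used for prediction, here is a self-contained sketch in the same shape as the proposal above. It assumes Swift 5 or later; the names TreeNode, Leaf, Split and predict are hypothetical, and the thresholds in the example tree are invented for illustration, not produced by buildTree.

```swift
// Sketch only: assumes Swift 5+. All type and function names are hypothetical.
protocol TreeNode: CustomStringConvertible {}

final class Leaf: TreeNode {
    let result: Int
    var description: String { return "Terminal node: \(result)" }
    init(result: Int) { self.result = result }
}

final class Split: TreeNode {
    let value: Double
    let variable: Int
    let leftChild: TreeNode
    let rightChild: TreeNode
    var description: String {
        return "\(variable): \(value)\n L -> \(leftChild)\n R -> \(rightChild)"
    }
    init(value: Double, variable: Int, leftChild: TreeNode, rightChild: TreeNode) {
        self.value = value
        self.variable = variable
        self.leftChild = leftChild
        self.rightChild = rightChild
    }
    func child(forValue featureValue: Double) -> TreeNode {
        return featureValue < value ? leftChild : rightChild
    }
}

// Walks from the root down to a leaf; features[i] is the value of variable i.
// Returns nil if a split refers to a variable the sample does not have.
func predict(_ features: [Double], from root: TreeNode) -> Int? {
    var current: TreeNode = root
    while let split = current as? Split {
        guard features.indices.contains(split.variable) else { return nil }
        current = split.child(forValue: features[split.variable])
    }
    return (current as? Leaf)?.result
}

// Example tree with made-up thresholds on variables 2 and 3.
let exampleTree = Split(
    value: 2.5, variable: 2,
    leftChild: Leaf(result: 1),
    rightChild: Split(value: 1.8, variable: 3,
                      leftChild: Leaf(result: 2),
                      rightChild: Leaf(result: 3)))
let predictedClass = predict([5.1, 3.8, 1.9, 0.4], from: exampleTree)
```

Since both children are required by the initializer, the tree has to be built bottom-up, which is what removes the need for optionals and force unwrapping.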
answered Apr 2, 2016 at 3:11
\$\endgroup\$
2
  • \$\begingroup\$ Very informative answer, thank you. I should indeed have spent more time thinking about the design of the node... As for the position of the braces, I thought it was a matter of taste, as I get confused when I don't see my opening brace in the same column as the closing one... I did think that "x" and "y" were too short as names, but I wasn't sure what to call them. Most machine learning libraries refer to predictor and response variables as x and y, and I didn't want to break with that style. \$\endgroup\$ Commented Apr 2, 2016 at 8:03
  • \$\begingroup\$ Even if we consider brace style to be a matter of taste, what is not a matter of taste is that it should be applied consistently, which you have not done. Consider your line let rightX = x.map { (col) -> [Double] in: there, you've broken the consistency. It's when you get to the functional bits of Swift like this that you realize the only way to be consistent and look good in all cases is with same-line braces. \$\endgroup\$ Commented Apr 2, 2016 at 14:10
