I am an experienced Java developer (12+ years) and have recently switched to Scala, which I love. However, I don't feel comfortable yet, and I suspect I might be using too many paradigms from the good old Java days.
That's why I would like to start with some simple code I wrote while watching a Stanford Seminar on YouTube on the topic "Deep Learning for Dummies". I tried to reproduce the results presented at around minute 15:40.
Basically, the code simulates the vibration of atoms in the Ising Model in parallel and gives a brief summary of how often each stable state has been reached.
Everything works as expected, and it reproduces the numbers shown in the presentation.
The code below is complete and runs as-is; I did not use any third-party libraries.
deeplearning/IsingModelSmall.scala
package deeplearning
import java.util.concurrent.atomic.AtomicInteger
import deeplearning.AtomState.AtomState
import deeplearning.AtomState._
/**
* Created by Julian Liebl on 25.11.15.
*
* All inspiration taken from youtu.be/hvIptUuUCdU. This class reproduces and proves the result shown in minute ~15:40.
*/
class IsingModelSmall {
case class MinMax(val min:Double, val max:Double)
val x1 = new Atom(Up) .fuse(-50 , new Atom(Up)) .fuse(99, new Atom(Up))
val x2 = new Atom(Down) .fuse(-50 , new Atom(Up)) .fuse(99, new Atom(Up))
val x3 = new Atom(Up) .fuse(-50 , new Atom(Down)) .fuse(99, new Atom(Up))
val x4 = new Atom(Up) .fuse(-50 , new Atom(Up)) .fuse(99, new Atom(Down))
val x5 = new Atom(Down) .fuse(-50 , new Atom(Down)) .fuse(99, new Atom(Up))
val x6 = new Atom(Up) .fuse(-50 , new Atom(Down)) .fuse(99, new Atom(Down))
val x7 = new Atom(Down) .fuse(-50 , new Atom(Up)) .fuse(99, new Atom(Down))
val x8 = new Atom(Down) .fuse(-50 , new Atom(Down)) .fuse(99, new Atom(Down))
/**
* Calculates the stable state of an Ising Model according to youtu.be/hvIptUuUCdU.
* It takes a random atom from the model as parameter and traverses from there all atoms it is directly or
* transitively connected to.
*
*
* Here is an example how the stable state is calculated:
*
* Model = a1(Up) <- w1(-50) -> a2(Down) <- w2(99) -> a3(Down)
* => x = -((a1 * w1 * a2) + (a2 * w2 * a3))
* => x = -((1 * -50 * -1) + (-1 * 99 * -1))
* => x = -(50 + 99)
* => x = -149
*
* @param atom A random atom from the model. Needs at least one connection. Otherwise stable state will be zero.
* @return stable state value
*/
def calcStableState(atom:Atom, touchedAtoms:Set[Atom] = Set()): Double ={
var sum:Double = 0
val a1v = getAtomStateValue(atom.atomState)
atom.getConnections().foreach(connection => {
val connectedAtom = connection.connectedAtom
if(!(touchedAtoms contains connectedAtom)){
val a2v = getAtomStateValue(connectedAtom.atomState)
sum += a1v * a2v * connection.weight
sum += calcStableState(connectedAtom, touchedAtoms + atom)
}
})
- sum
}
/**
* Retrieves the min and max weight for all atom connections in a model.
* It takes a random atom from the model as parameter and traverses from there all direct and transitive connections.
*
*
* Example:
*
* Model = a1(Up) <- w1(-50) -> a2(Down) <- w2(99) -> a3(Down) <- w3(20) -> a4(Up)
* => min = -50
* => max = 99
*
* @param atom A random atom from the model. Needs at least one connection. Otherwise min and max will be zero.
* @return min and max weight
*/
def getMinMaxWeight(atom:Atom, touchedAtoms:Set[Atom] = Set()): MinMax ={
var minMax:MinMax = MinMax(0,0)
atom.getConnections().foreach(connection => {
val connectedAtom = connection.connectedAtom
if(!(touchedAtoms contains connectedAtom)){
val currentWeight = connection.weight
if (currentWeight < minMax.min){
minMax = minMax.copy(min = currentWeight)
}
else if (currentWeight > minMax.max) {
minMax = minMax.copy(max = currentWeight)
}
val provisionalMinMax = getMinMaxWeight(connectedAtom, touchedAtoms + atom)
if(provisionalMinMax.min < minMax.min) minMax = minMax.copy(min = provisionalMinMax.min)
if(provisionalMinMax.max > minMax.max) minMax = minMax.copy(max = provisionalMinMax.max)
}
})
minMax
}
/**
* Atom vibration simulation.
* It takes a random atom from the model as parameter and traverses all connections from there. It assigns a random
* initial state to that atom and, taking the probability of every connection and sub-connection into account,
* produces the same connections but possibly with different states than before.
*
* @param atom A random atom from the model. Needs at least one connection. Otherwise the given atom will just be
* returned.
* @return The new atom with the same connections but possibly different states.
*/
def vibrate(atom:Atom): Atom ={
var touchedAtoms:Set[Atom] = scala.collection.immutable.Set()
val currentMinMaxWeight = getMinMaxWeight(atom)
val minWeight = currentMinMaxWeight.min
val maxWeight = currentMinMaxWeight.max
val weightRange = if(Math.abs(minWeight) > Math.abs(maxWeight)) Math.abs(minWeight) else Math.abs(maxWeight)
val scaledWeightRange = weightRange * 1.2
val random = scala.util.Random
def vibrateInner(innerAtom:Atom, currentAtomState:AtomState):Atom ={
val newAtom = new Atom(currentAtomState)
touchedAtoms += newAtom
innerAtom.getConnections().foreach(connection => {
val connectedAtom = connection.connectedAtom
connectedAtom.removeConnection(innerAtom)
if(!(touchedAtoms contains connectedAtom)){
val weight = connection.weight
val probability = Math.abs(weight) / scaledWeightRange
val randomDouble = random.nextDouble()
val isFollowing = probability - randomDouble >= 0
if(weight != 0){
var connectedAtomState:AtomState = null
if(weight < 0) {
connectedAtomState = if (isFollowing) getOppositeState(currentAtomState) else currentAtomState
}else{
connectedAtomState = if (isFollowing) currentAtomState else getOppositeState(currentAtomState)
}
connectedAtom.atomState = connectedAtomState
newAtom.fuse(connection.weight, vibrateInner(connectedAtom, connectedAtomState))
}else{
println("Error: Weight should never be 0!")
return newAtom
}
}
})
newAtom
}
vibrateInner(atom, getRandomAtomState())
}
}
object IsingModelSmall{
def main(args: Array[String]) {
val model = new IsingModelSmall
println("E(x1,w) = " + model.calcStableState(model.x1))
println("E(x2,w) = " + model.calcStableState(model.x2))
println("E(x3,w) = " + model.calcStableState(model.x3))
println("E(x4,w) = " + model.calcStableState(model.x4))
println("E(x5,w) = " + model.calcStableState(model.x5))
println("E(x6,w) = " + model.calcStableState(model.x6))
println("E(x7,w) = " + model.calcStableState(model.x7))
println("E(x8,w) = " + model.calcStableState(model.x8))
println(model.getMinMaxWeight(model.x1))
val vibrationLoopCount:Int = 10000
val atomicLoopIndex = new AtomicInteger()
println("Simulating vibration of atom " + vibrationLoopCount + " times.")
val statesToCount = (1 to vibrationLoopCount).toTraversable.par.map(loopIndex => {
val vibratedX1 = model.vibrate(model.x1)
if(atomicLoopIndex.incrementAndGet() % 10000 == 0) print("\r" + atomicLoopIndex.get())
model.calcStableState(vibratedX1)
}).groupBy(identity).mapValues(_.size)
println("\r" + atomicLoopIndex.get())
val states = statesToCount.keySet.toList.sorted
states.foreach(state => println(state + "\t: " + statesToCount.get(state).get))
}
}
deeplearning/Atom.scala
package deeplearning
import deeplearning.AtomState.AtomState
import scala.collection.mutable.ListBuffer
/**
* Created by Julian Liebl on 26.11.15.
*
* Class which represents an atom in the Ising Model.
*/
class Atom(var atomState: AtomState) {
var connections:ListBuffer[AtomConnection] = ListBuffer()
def addConnection(atomConnection: AtomConnection): Unit ={
connections += atomConnection
}
def removeConnection(atomConnection: AtomConnection): Unit ={
connections -= atomConnection
}
def removeConnection(atom:Atom): Unit ={
connections = connections.filter(connection => !(connection.connectedAtom equals atom))
}
def removeConnections(atoms:Seq[Atom]): Unit ={
connections = connections.filter(connection => !(atoms contains connection.connectedAtom))
}
def getConnections(): Seq[AtomConnection] ={
connections
}
/**
* Creates a weighted connection between this atom and otherAtom. Returns the other atom in order to be able to
* chain the creation of a model.
*
* @param weight weight of the connection
* @param otherAtom other atom
* @return other atom
*/
def fuse(weight:Double, otherAtom:Atom): Atom ={
AtomConnection.fuse(this, otherAtom, weight)
otherAtom
}
}
deeplearning/AtomConnection.scala
package deeplearning
/**
* Created by Julian Liebl on 26.11.15.
*
* Class which represents an atom connection in the Ising Model.
*/
case class AtomConnection(connectedAtom:Atom, weight:Double)
object AtomConnection{
/**
* Creates a weighted connection between two atoms.
*
* @param a1 first atom
* @param a2 second atom
* @param weight weight of the connection
*/
def fuse(a1:Atom, a2:Atom, weight:Double): Unit ={
a1.addConnection(new AtomConnection(a2, weight))
a2.addConnection(new AtomConnection(a1, weight))
}
}
deeplearning/AtomState.scala
package deeplearning
/**
* Created by Julian Liebl on 25.11.15.
*
* Class which represents an atom state in the Ising Model.
*/
object AtomState extends Enumeration {
type AtomState = Value
val Up, Down = Value
/**
* Helper method which returns the numerical state value.
*
* Up = 1
* Down = -1
*
* @param atomState atom state
* @return the numerical representation of the atom state
*/
def getAtomStateValue(atomState: AtomState): Int ={
if(atomState equals Up) 1 else -1
}
/**
* Helper method which returns a random atom state.
* @return the random atom state
*/
def getRandomAtomState(): AtomState ={
val r = scala.util.Random
if(r.nextInt(2) equals 0) Up else Down
}
/**
* Helper method which returns the opposite atom state.
*
* @param atomState atom state
* @return the opposite atom state
*/
def getOppositeState(atomState: AtomState) ={
if(atomState equals Up) Down else Up
}
}
Thank you for this interesting question. Generally, your code is quite clean, and your years of experience show. I'm also a Java developer switching to Scala, which I love more and more.
Here are my observations about this code.
class Atom
The main issue here is the use of var connections. Mutable variables (var) are discouraged in Scala; they should be kept to a narrow scope without external exposure (connections has public access here).
So we can just turn the field into a val:
val connections = ListBuffer[AtomConnection]()
// note that the type declaration with ":" is not necessary here
Now we cannot reassign the field, but we can still modify its contents when necessary. The removeConnection(Atom) method would look like this:
def removeConnection(atom : Atom): Unit ={
connections.find(connection => connection.connectedAtom == atom) match {
case None => {}
case Some(atomConnection) => { removeConnection(atomConnection) }
}
}
The principle is: when the first AtomConnection element containing the specified Atom is found, the method delegates to the overloaded removeConnection function with that element; otherwise it does nothing.
Please note that equality is checked with ==, which in Scala delegates to Java's equals (and additionally handles null safely).
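One thing worth checking: the find-based version removes only the first matching connection, whereas the original filter removed every connection to that atom. A minimal sketch that keeps the original semantics while connections stays a val, assuming a simplified Atom whose state is a plain Int instead of the question's AtomState enumeration, could look like this:

```scala
import scala.collection.mutable.ListBuffer

// Simplified stand-ins for the question's classes (state is an Int here, not AtomState).
final case class AtomConnection(connectedAtom: Atom, weight: Double)

class Atom(var atomState: Int) {
  // The reference is fixed (val); only the buffer's contents change.
  val connections: ListBuffer[AtomConnection] = ListBuffer()

  // Removes *all* connections pointing at `atom`, like the original filter did.
  // filter builds a separate buffer first, so we never mutate while iterating it.
  def removeConnection(atom: Atom): Unit =
    connections --= connections.filter(_.connectedAtom == atom)
}
```

Because Atom is a plain class, == on connectedAtom compares references here, which matches how the question uses it.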
The same kind of change can also be applied to the removeConnections(atoms: Seq[Atom]) function, but it doesn't seem to be used anywhere in the code.
The getConnections() function is not necessary at all. It looks like a leftover from Java: val connections already has public access, and the reference cannot be reassigned.
Generally, I'm not sure the choice of var atomState and the exposed ListBuffer for the connections field in the Atom class is sound. This sort of mutability goes somewhat against Scala's principles, but I don't yet see how to work around it.
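For what it's worth, one common way around this kind of mutability is to stop letting atoms point at each other and instead keep the whole model in immutable values: the states in a Vector and the connections as index-based triples. The sketch below is only an illustration under that assumption (IsingModel, Connection and withState are hypothetical names, and spins are encoded as 1/-1); it computes the energy with the formula from the calcStableState doc comment rather than reproducing the question's recursive traversal:

```scala
// Hypothetical immutable representation: spins (1 or -1) live in a Vector,
// connections are plain (i, j, weight) triples referring to Vector indices.
final case class Connection(i: Int, j: Int, weight: Double)

final case class IsingModel(states: Vector[Int], connections: List[Connection]) {
  // E = -(sum of s_i * w_ij * s_j over all connections), as in the doc comment.
  def energy: Double = {
    val sum = connections.foldLeft(0d)((acc, c) => acc + states(c.i) * c.weight * states(c.j))
    -sum
  }

  // "Changing" a spin returns a new model; the original stays untouched.
  def withState(index: Int, state: Int): IsingModel =
    copy(states = states.updated(index, state))
}
```

Since withState returns a copy, a vibration step could produce a new model instead of mutating shared atoms from the parallel loop in main.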
class IsingModelSmall
In def calcStableState, def getMinMaxWeight and def vibrateInner there are similar calls:
atom.connections.foreach(connection => {
val connectedAtom = connection.connectedAtom
...
if(!(touchedAtoms contains connectedAtom)) {
...
}
})
They can be refactored into a dedicated function:
private def filterNonConnected(atom : Atom, touchedAtoms:Set[Atom]) =
atom.connections.filter(connection =>
!(touchedAtoms.contains(connection.connectedAtom))
).toList
Now the var sum counter in calcStableState can be eliminated:
def calcStableState(atom:Atom, touchedAtoms:Set[Atom] = Set()): Double ={
val a1v = getAtomStateValue(atom.atomState)
val sum = filterNonConnected(atom, touchedAtoms).foldLeft(0d)((sum, connection) => {
val connectedAtom = connection.connectedAtom
val a2v = getAtomStateValue(connectedAtom.atomState)
sum + a1v * a2v * connection.weight + calcStableState(connectedAtom, touchedAtoms + atom)
})
-sum
}
What calcStableState does: 1) it filters the connections to work on with the refactored filterNonConnected; 2) it uses the standard foldLeft function to calculate the sum. foldLeft starts from the initial value 0d and, for each filtered connection, combines the running total in sum with that connection's contribution.
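To make the foldLeft step concrete, here is a minimal, self-contained comparison of the two styles on a flat list of connections. Edge, FoldDemo and both function names are hypothetical stand-ins; the numbers are the ones from the calcStableState doc comment:

```scala
// Hypothetical flat stand-in for a connection: both spins (1 or -1) and the weight.
final case class Edge(s1: Int, s2: Int, weight: Double)

object FoldDemo {
  // Imperative style: a mutable accumulator, as in the original calcStableState.
  def energyWithVar(edges: Seq[Edge]): Double = {
    var sum = 0d
    edges.foreach(e => sum += e.s1 * e.s2 * e.weight)
    -sum
  }

  // Functional style: foldLeft threads the accumulator through every element.
  def energyWithFold(edges: Seq[Edge]): Double = {
    val sum = edges.foldLeft(0d)((acc, e) => acc + e.s1 * e.s2 * e.weight)
    -sum
  }
}
```

For Seq(Edge(1, -1, -50), Edge(-1, -1, 99)) both versions return -149.0, the value from the doc comment example.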
The getMinMaxWeight function can also be simplified using our refactored filterNonConnected and foldLeft:
def getMinMaxWeight(atom:Atom, touchedAtoms:Set[Atom] = Set()): MinMax ={
filterNonConnected(atom, touchedAtoms).foldLeft(MinMax(0,0))((curMinMax, connection) => {
val currentWeight = connection.weight
val provisionalMinMax = getMinMaxWeight(connection.connectedAtom, touchedAtoms + atom)
MinMax(List(curMinMax.min, currentWeight, provisionalMinMax.min).min,
List(curMinMax.max, currentWeight, provisionalMinMax.max).max)
})
}
Please note that there is not a single if-else left in this function.
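BTW, there is a potential bug in the original implementation: MinMax shouldn't be initialized with (0, 0), but rather with (Double.MaxValue, Double.MinValue). Otherwise a model with only positive weights would report a minimum of 0 (and one with only negative weights a maximum of 0).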
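A small sketch of why the starting value matters, assuming the weights have already been collected into a Seq (include, Empty and of are hypothetical helpers, not part of the question's code). Note that Scala's Double.MinValue is the most negative finite Double, unlike Java's Double.MIN_VALUE, which is the smallest positive one:

```scala
final case class MinMax(min: Double, max: Double) {
  // Widen the range so that it also covers `weight`.
  def include(weight: Double): MinMax =
    MinMax(math.min(min, weight), math.max(max, weight))
}

object MinMax {
  // Neutral start: the first real weight replaces both bounds.
  val Empty: MinMax = MinMax(Double.MaxValue, Double.MinValue)

  def of(weights: Seq[Double]): MinMax =
    weights.foldLeft(Empty)((acc, w) => acc.include(w))
}
```

With only positive weights such as Seq(20.0, 99.0, 40.0), MinMax.of returns MinMax(20.0, 99.0), whereas starting from MinMax(0, 0) reports a minimum of 0. For an empty Seq it returns Empty itself, so a caller like vibrate may still want to guard against that case.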
The vibrateInner function can be refactored using the same principle, but two more things are needed: 1) the removeConnection call needs to be moved into a dedicated loop, for example:
innerAtom.connections.foreach(connection => {
connection.connectedAtom.removeConnection(innerAtom)
})
But I'm not sure this is a good solution from a design point of view, because it looks like a violation of the Law of Demeter (LoD). There should be a better approach.
2) the calculation of connectedAtomState should be moved into a separate method, which would eliminate the last var remaining in this part of the code.
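A possible shape for that separate method, assuming it receives the already-drawn isFollowing flag (the name connectedAtomState and the require call are my own choices, not from the question; require throws instead of printing an error and returning). Looking at the two branches, the connected atom keeps the current state exactly when (weight > 0) == isFollowing, so the nested if/else collapses into one comparison:

```scala
object AtomState extends Enumeration {
  type AtomState = Value
  val Up, Down = Value

  def getOppositeState(atomState: AtomState): AtomState =
    if (atomState == Up) Down else Up
}

object VibrationRules {
  import AtomState._

  // Same outcome as the nested if/else in vibrateInner:
  //   weight < 0: following -> opposite state, not following -> same state
  //   weight > 0: following -> same state,     not following -> opposite state
  def connectedAtomState(weight: Double, currentAtomState: AtomState, isFollowing: Boolean): AtomState = {
    require(weight != 0, "Weight should never be 0!")
    val keepsState = (weight > 0) == isFollowing
    if (keepsState) currentAtomState else getOppositeState(currentAtomState)
  }
}
```

In vibrateInner, the var connectedAtomState could then become a val initialized from this call.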
object IsingModelSmall
In Scala, there is no need to define def main(args: Array[String]). Simply extending App does the job:
object IsingModelSmall extends App {
// the body of the main(args) method to be placed here directly
}
Great! Thank you very much already. I will implement your suggestions. You are right about the bug: when there is no negative or no positive weight, the zeros would cause an error in the calculation! (Julian Pieles, commented Dec 1, 2015 at 8:17)
Lists
If you have 8 variables that are almost identical, you should group them into a list to keep them ordered and concise and to simplify future changes. I am looking at:
val x1 = new Atom(Up) .fuse(-50 , new Atom(Up)) .fuse(99, new Atom(Up))
val x2 = new Atom(Down) .fuse(-50 , new Atom(Up)) .fuse(99, new Atom(Up))
val x3 = new Atom(Up) .fuse(-50 , new Atom(Down)) .fuse(99, new Atom(Up))
val x4 = new Atom(Up) .fuse(-50 , new Atom(Up)) .fuse(99, new Atom(Down))
val x5 = new Atom(Down) .fuse(-50 , new Atom(Down)) .fuse(99, new Atom(Up))
val x6 = new Atom(Up) .fuse(-50 , new Atom(Down)) .fuse(99, new Atom(Down))
val x7 = new Atom(Down) .fuse(-50 , new Atom(Up)) .fuse(99, new Atom(Down))
val x8 = new Atom(Down) .fuse(-50 , new Atom(Down)) .fuse(99, new Atom(Down))
Using a list will also simplify printing.
Thank you for pointing that out. I was a bit lazy there. Based on your input, I posted an answer with code changes: codereview.stackexchange.com/a/112079/90723 (Julian Pieles, commented Nov 27, 2015 at 20:35)
Nice question!
In Scala, it is better to use objects as enum values, since they are more powerful and work better with pattern matching. This article is worth reading.
object AtomState{
sealed trait EnumVal
case object Up extends EnumVal
case object Down extends EnumVal
val states = Seq(Up, Down)
}
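As an illustration of the pattern-matching benefit, here is the same sealed hierarchy with the question's two helper methods rewritten as matches (value and opposite are hypothetical names standing in for getAtomStateValue and getOppositeState). Because EnumVal is sealed, the compiler can warn when a match forgets one of the cases:

```scala
object AtomState {
  sealed trait EnumVal
  case object Up extends EnumVal
  case object Down extends EnumVal
  val states: Seq[EnumVal] = Seq(Up, Down)

  // Numerical spin value; a missing case would trigger an exhaustiveness warning.
  def value(state: EnumVal): Int = state match {
    case Up   => 1
    case Down => -1
  }

  def opposite(state: EnumVal): EnumVal = state match {
    case Up   => Down
    case Down => Up
  }
}
```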
Also, to avoid the repetition, you could use an idiomatic for comprehension:
val combinations = for {
as1 <- AtomState.states
as2 <- AtomState.states
as3 <- AtomState.states
} yield (as1,as2,as3)
combinations.foreach{println}
Prints out:
(Up,Up,Up)
(Up,Up,Down)
(Up,Down,Up)
(Up,Down,Down)
(Down,Up,Up)
(Down,Up,Down)
(Down,Down,Up)
(Down,Down,Down)
You can see how this could be applied to yield your 8 variables instead:
val atoms = for {
as1 <- AtomState.states
as2 <- AtomState.states
as3 <- AtomState.states
} yield new Atom(as1).fuse(-50, new Atom(as2)).fuse(99, new Atom(as3))
If a map is needed, although it feels like a bad practice:
val mapped = combinations.zipWithIndex.map{case (combo,index:Int) => ("x"+index,combo) }.toMap
mapped.foreach{println}
You zip with the index, map each (combination, index) pair to a (key, value) tuple, and then convert the tuples to a Map.
Which prints:
(x3,(Up,Down,Down))
(x7,(Down,Down,Down))
(x2,(Up,Down,Up))
(x0,(Up,Up,Up))
(x5,(Down,Up,Down))
(x6,(Down,Down,Up))
(x1,(Up,Up,Down))
(x4,(Down,Up,Up))
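Two details about the output above: a plain immutable Map gives no ordering guarantee, so the entries come out shuffled, and zipWithIndex starts at 0 while the question numbers its variables from 1. If ordered, 1-based names are wanted, a sketch using ListMap (which preserves insertion order) might look like this; plain strings stand in for the AtomState values to keep it self-contained, and NamedCombinations is a hypothetical name. Note that the for comprehension's order differs from the question's hand-written x1 to x8, so x2 here is (Up,Up,Down), not (Down,Up,Up):

```scala
import scala.collection.immutable.ListMap

object NamedCombinations {
  val states: Seq[String] = Seq("Up", "Down")

  val combinations: Seq[(String, String, String)] = for {
    as1 <- states
    as2 <- states
    as3 <- states
  } yield (as1, as2, as3)

  // index + 1 so the keys run x1..x8; ListMap keeps them in insertion order.
  val named: ListMap[String, (String, String, String)] =
    ListMap(combinations.zipWithIndex.map { case (combo, index) => ("x" + (index + 1), combo) }: _*)
}
```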
Here are my edits so far. I do not want to post a new question yet. According to this post, that is acceptable as long as the answer is descriptive and helpful for others.
I put my x1-8 variables in a map instead. This results in cleaner, less duplicated code and simplifies the printing, as @Caridorc suggested.
Map
val atomTemplates: Map[String, Atom] = Map(
"x1" -> new Atom(Up) .fuse(-50 , new Atom(Up)) .fuse(99, new Atom(Up)),
"x2" -> new Atom(Down) .fuse(-50 , new Atom(Up)) .fuse(99, new Atom(Up)),
"x3" -> new Atom(Up) .fuse(-50 , new Atom(Down)) .fuse(99, new Atom(Up)),
"x4" -> new Atom(Up) .fuse(-50 , new Atom(Up)) .fuse(99, new Atom(Down)),
"x5" -> new Atom(Down) .fuse(-50 , new Atom(Down)) .fuse(99, new Atom(Up)),
"x6" -> new Atom(Up) .fuse(-50 , new Atom(Down)) .fuse(99, new Atom(Down)),
"x7" -> new Atom(Down) .fuse(-50 , new Atom(Up)) .fuse(99, new Atom(Down)),
"x8" -> new Atom(Down) .fuse(-50 , new Atom(Down)) .fuse(99, new Atom(Down))
)
Printing
model.atomTemplates.toList.sortBy(_._1).foreach { case (name, atom) =>
  println("E(" + name + ",w) = " + model.calcStableState(atom))
}