I have implemented, for learning purposes, a simple K-Means clustering algorithm in Rust. For those who are not familiar: you are given N points, say in the plane, and you want to group them into n clusters of nearby points. To do so, you start with n random points, for instance the first n of the given ones; call these the centroids. At each iteration:
- you group the N points based on the nearest centroid
- you produce a new set of centroids as the averages of the groups from the preceding step
You can stop after a fixed number of iterations, or once the centroids converge.
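For readers on a current compiler (the code below targets pre-1.0 Rust), one such iteration can be sketched roughly like this; the names and the use of plain tuples are illustrative, not taken from the code below:

```rust
// One Lloyd's-algorithm iteration: assign each point to its nearest
// centroid, then recompute each centroid as the mean of its group.
// Minimal sketch in modern Rust.
fn step(points: &[(f64, f64)], centroids: &[(f64, f64)]) -> Vec<(f64, f64)> {
    // Per-centroid running sums: (sum_x, sum_y, count).
    let mut sums = vec![(0.0, 0.0, 0usize); centroids.len()];
    for &(px, py) in points {
        // Index of the nearest centroid (squared distance suffices).
        let nearest = centroids
            .iter()
            .enumerate()
            .min_by(|(_, a), (_, b)| {
                let da = (a.0 - px).powi(2) + (a.1 - py).powi(2);
                let db = (b.0 - px).powi(2) + (b.1 - py).powi(2);
                da.partial_cmp(&db).unwrap()
            })
            .map(|(i, _)| i)
            .unwrap();
        let s = &mut sums[nearest];
        s.0 += px;
        s.1 += py;
        s.2 += 1;
    }
    // New centroid = group average; keep the old centroid for empty groups.
    sums.iter()
        .zip(centroids)
        .map(|(&(sx, sy, n), &c)| {
            if n == 0 { c } else { (sx / n as f64, sy / n as f64) }
        })
        .collect()
}
```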
Here is my implementation, with the help of SO. For some reason, the code runs slower than an equivalent algorithm written in Scala. I think I might be introducing some unnecessary copying or other hidden overhead, but I am not familiar enough with Rust to tell.
Just to be clear: I am not interested in changing the algorithm (I want to compare apples to apples), and I would rather have idiomatic Rust than hyper-optimized code.
use std::collections::TreeMap;
use point::Point;

fn dist(v: Point, w: Point) -> f64 { (v - w).norm() }

fn avg(points: &Vec<Point>) -> Point {
    let Point(x, y) = points.iter().fold(Point(0.0, 0.0), |p, &q| p + q);
    let k = points.len() as f64;
    Point(x / k, y / k)
}

fn closest(x: Point, ys: &Vec<Point>) -> Point {
    let y0 = ys[0];
    let d0 = dist(y0, x);
    let (_, y) = ys.iter().fold((d0, y0), |(m, p), &q| {
        let d = dist(q, x);
        if d < m { (d, q) } else { (m, p) }
    });
    y
}

fn clusters(xs: &Vec<Point>, centroids: &Vec<Point>) -> Vec<Vec<Point>> {
    let mut groups: TreeMap<Point, Vec<Point>> = TreeMap::new();
    for x in xs.iter() {
        let y = closest(*x, centroids);
        let should_insert = match groups.find_mut(&y) {
            Some(val) => {
                val.push(*x);
                false
            },
            None => true
        };
        if should_insert {
            groups.insert(y, vec![*x]);
        }
    }
    groups.into_iter().map(|(_, v)| v).collect::<Vec<Vec<Point>>>()
}

pub fn run(points: &Vec<Point>, n: uint, iters: uint) -> Vec<Vec<Point>> {
    let mut centroids: Vec<Point> = Vec::from_fn(n, |i| points[i]);
    for _ in range(0, iters) {
        centroids = clusters(points, &centroids).iter().map(|g| avg(g)).collect();
    }
    clusters(points, &centroids)
}
Definition of Point:
use serialize::{Decoder, Decodable};

#[deriving(Show, PartialEq, PartialOrd, Clone)]
pub struct Point(pub f64, pub f64);

fn sq(x: f64) -> f64 { x * x }

impl Point {
    pub fn norm(self: &Point) -> f64 {
        let Point(x, y) = *self;
        (sq(x) + sq(y)).sqrt()
    }
}

impl<E, D: Decoder<E>> Decodable<D, E> for Point {
    fn decode(d: &mut D) -> Result<Point, E> {
        d.read_tuple(|d, n| {
            if n != 2 {
                Err(d.error("invalid number of elements, need 2"))
            } else {
                d.read_tuple_arg(0, |d| d.read_f64()).and_then(|e1|
                    d.read_tuple_arg(1, |d| d.read_f64()).map(|e2|
                        Point(e1, e2)
                    )
                )
            }
        })
    }
}

impl Add<Point, Point> for Point {
    fn add(&self, other: &Point) -> Point {
        let &Point(a, b) = self;
        let &Point(c, d) = other;
        Point(a + c, b + d)
    }
}

impl Sub<Point, Point> for Point {
    fn sub(&self, other: &Point) -> Point {
        let &Point(a, b) = self;
        let &Point(c, d) = other;
        Point(a - c, b - d)
    }
}

impl Eq for Point {}

impl Ord for Point {
    fn cmp(&self, other: &Point) -> Ordering {
        self.partial_cmp(other).unwrap_or(Equal)
    }
}
Is there anything I am doing that justifies the unexpected slowness?
1 Answer
I took your code and got it to compile with my version of Rust (rustc 0.13.0-dev (29ad8539b 2014-12-24 16:21:23 +0000)). I ran with your parameters (I hope I understood them correctly) and got an average time of 207.8 ms.
I made a few changes here and there, but the main thing was changing TreeMap to HashMap. TreeMap doesn't exist anymore, only BTreeMap. At the same time, I switched to using entry, which avoids doing a second lookup in the map when adding a value under a missing key.
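In today's Rust the same single-lookup pattern is usually spelled with `or_insert_with` rather than matching on `Occupied`/`Vacant`; a small self-contained sketch (the function and key choice are just for illustration):

```rust
use std::collections::HashMap;

// Group values under a key with a single hash lookup via `entry`.
// (Modern spelling; the answer's code uses the older
// Occupied/Vacant match, which achieved the same thing.)
fn group_by_parity(xs: &[i32]) -> HashMap<i32, Vec<i32>> {
    let mut groups: HashMap<i32, Vec<i32>> = HashMap::new();
    for &x in xs {
        // One lookup: insert an empty Vec if the key is new, then push.
        groups.entry(x.rem_euclid(2)).or_insert_with(Vec::new).push(x);
    }
    groups
}
```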
The main problem with using a HashMap is that f64 isn't hashable. This is for good reason: floating-point numbers have lots of edge cases (an f64 has quadrillions of distinct NaN bit patterns!). I don't know how Scala or Python deal with these values, so I did the simplest thing: ignored them. We just transmute our 64-bit float to a 64-bit unsigned integer and away we go!
In your case, I think it's safe to assume that all Points will have finite, non-NaN coordinates (and maybe we can ignore every other detail of floating point?). However, it might be worth adding a constructor that validates these assumptions and fails if they don't hold.
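On a current compiler, the same bitwise-hashing idea no longer needs `unsafe`: `f64::to_bits` does the transmute safely. A hedged sketch of such a validating, hashable point type (not the answer's code, which predates these APIs):

```rust
use std::collections::HashMap;
use std::hash::{Hash, Hasher};

// A point keyed on the raw bit patterns of its coordinates.
// The constructor rejects NaN and infinite values, so bitwise
// equality and hashing behave sensibly.
#[derive(Debug, Clone, Copy, PartialEq)]
struct Point(f64, f64);

impl Point {
    fn new(x: f64, y: f64) -> Point {
        assert!(x.is_finite() && y.is_finite(), "coordinates must be finite");
        Point(x, y)
    }
}

// Sound only because the constructor rules out NaN.
impl Eq for Point {}

impl Hash for Point {
    fn hash<H: Hasher>(&self, state: &mut H) {
        // Safe bitwise transform: `to_bits` replaces the unsafe
        // `mem::transmute` used in the 2014-era code below.
        self.0.to_bits().hash(state);
        self.1.to_bits().hash(state);
    }
}
```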
I didn't make any other changes, as running in about 20% of your previous time seems like a good start.
extern crate time;

use std::collections::HashMap;
use std::collections::hash_map::Entry::{Occupied, Vacant};
use std::hash::Hash;
use std::hash::sip::SipState;
use std::mem;
use std::num::Float;
use std::rand::{Rng, StdRng, SeedableRng};

#[deriving(Show, PartialEq, Copy, Clone)]
pub struct Point(pub f64, pub f64);

fn sq(x: f64) -> f64 { x * x }

// This needs a guarantee that x and y will never be an
// Infinity or NaN
impl Point {
    pub fn norm(self: &Point) -> f64 {
        let Point(x, y) = *self;
        (sq(x) + sq(y)).sqrt()
    }
}

impl Hash for Point {
    fn hash(&self, state: &mut SipState) {
        // Perform a bit-wise transform, relying on the fact that we
        // are never Infinity or NaN
        let Point(x, y) = *self;
        let x: u64 = unsafe { mem::transmute(x) };
        let y: u64 = unsafe { mem::transmute(y) };
        x.hash(state);
        y.hash(state);
    }
}

impl Add<Point, Point> for Point {
    fn add(self, other: Point) -> Point {
        let Point(a, b) = self;
        let Point(c, d) = other;
        Point(a + c, b + d)
    }
}

impl Sub<Point, Point> for Point {
    fn sub(self, other: Point) -> Point {
        let Point(a, b) = self;
        let Point(c, d) = other;
        Point(a - c, b - d)
    }
}

impl Eq for Point {}

fn dist(v: Point, w: Point) -> f64 { (v - w).norm() }

fn avg(points: &Vec<Point>) -> Point {
    let Point(x, y) = points.iter().fold(Point(0.0, 0.0), |p, &q| p + q);
    let k = points.len() as f64;
    Point(x / k, y / k)
}

fn closest(x: Point, ys: &Vec<Point>) -> Point {
    let y0 = ys[0];
    let d0 = dist(y0, x);
    let (_, y) = ys.iter().fold((d0, y0), |(m, p), &q| {
        let d = dist(q, x);
        if d < m { (d, q) } else { (m, p) }
    });
    y
}

fn clusters(xs: &Vec<Point>, centroids: &Vec<Point>) -> Vec<Vec<Point>> {
    let mut groups: HashMap<Point, Vec<Point>> = HashMap::new();
    for x in xs.iter() {
        let y = closest(*x, centroids);
        // Notable change: avoid double hash lookups
        match groups.entry(y) {
            Occupied(entry) => entry.into_mut().push(*x),
            Vacant(entry) => { entry.set(vec![*x]); () },
        }
    }
    groups.into_iter().map(|(_, v)| v).collect::<Vec<Vec<Point>>>()
}

pub fn run(points: &Vec<Point>, n: uint, iters: uint) -> Vec<Vec<Point>> {
    let mut centroids: Vec<Point> = Vec::from_fn(n, |i| points[i]);
    for _ in range(0, iters) {
        centroids = clusters(points, &centroids).iter().map(|g| avg(g)).collect();
    }
    clusters(points, &centroids)
}

fn main() {
    let seed: &[_] = &[1, 2, 3, 4];
    let mut rng: StdRng = SeedableRng::from_seed(seed);
    let points = Vec::from_fn(100000, |_| Point(rng.gen(), rng.gen()));
    println!("Made {} points: {}", points.len(), points.slice_to(3));

    let repeat_count = 20u;
    let mut total = 0;
    for _ in range(0, repeat_count) {
        let start = time::precise_time_ns();
        let res = run(&points, 10, 15);
        let end = time::precise_time_ns();
        total += end - start
    }
    let avg_ns: f64 = total as f64 / repeat_count as f64;
    let avg_ms = avg_ns / 1.0e6;
    println!("{} runs, avg {}", repeat_count, avg_ms);
}
Comparison numbers
Here are the numbers I got from running the various versions in your suite. Hopefully this gives a baseline for comparison across machines.
$ cargo run --release
The average time is 209.44
$ cargo run
The average time is 2103.81
$ python kmeans.py
Made 100 iterations with an average of 13504.76 milliseconds
$ pypy kmeans.py
Made 100 iterations with an average of 608.97 milliseconds
$ lein trampoline run
The average time was 2799.54 ms
$ node kmeans.js
Running 100 iterations as required 4741.34 ms
-
\$\begingroup\$ Thank you very much! I suspected that the use of a TreeMap was the culprit, but I was unable to implement Hash for Point. I understand the issues of floating-point arithmetic, but grouping points by their nearest neighbour still seems a quite reasonable task for which a hash map is definitely desired. In fact, your solution requires the use of unsafe code, which is unexpected for this task! I will try this implementation on the same laptop I used for the other benchmarks and let you know the time. \$\endgroup\$ – Andrea, Jan 3, 2015 at 22:41
-
\$\begingroup\$ Turns out, this implementation is even slower than the previous one. On my laptop it runs in 1807 ms. You can see the whole Rust implementation, as well as a few other languages, at github.com/andreaferretti/kmeans \$\endgroup\$ – Andrea, Jan 3, 2015 at 23:34
-
2 \$\begingroup\$ Not to sound like a broken record, but you are compiling with optimizations, right? I grabbed your repo, shoehorned my changes in, and ran: cargo run -> 2103.81; cargo run --release -> 209.44. \$\endgroup\$ – Shepmaster, Jan 4, 2015 at 0:20
-
2 \$\begingroup\$ It turns out that cargo run --release does the trick! What I was doing was cargo build --release and then target/kmeans. Is there a way to build with optimizations on, instead of relying on Cargo to run the program? \$\endgroup\$ – Andrea, Jan 5, 2015 at 9:03
-
3 \$\begingroup\$ Release builds go to target/release/*. Non-release builds go to target/* \$\endgroup\$ – Shepmaster, Jan 5, 2015 at 13:18