I tried to implement a DBScan in C# using kd-trees. I followed the implementation from here.
public class DbscanAlgorithm
{
private readonly Func<PointD, PointD, double> _metricFunc;
public DbscanAlgorithm(Func<PointD, PointD, double> metricFunc)
{
_metricFunc = metricFunc;
}
public void ComputeClusterDbscan(ScanPoint[] allPoints, double epsilon, int minPts, out HashSet<ScanPoint[]> clusters)
{
clusters = null;
var allPointsDbscan = allPoints.Select(x => new DbscanPoint(x)).ToArray();
var tree = new KDTree.KDTree<DbscanPoint>(2);
for (var i = 0; i < allPointsDbscan.Length; ++i)
{
tree.AddPoint(new double[] { allPointsDbscan[i].ClusterPoint.point.X, allPointsDbscan[i].ClusterPoint.point.Y }, allPointsDbscan[i]);
}
var C = 0;
for (int i = 0; i < allPointsDbscan.Length; i++)
{
var p = allPointsDbscan[i];
if (p.IsVisited)
continue;
p.IsVisited = true;
DbscanPoint[] neighborPts = null;
RegionQuery(tree, p.ClusterPoint.point, epsilon, out neighborPts);
if (neighborPts.Length < minPts)
p.ClusterId = (int)ClusterIds.NOISE;
else
{
C++;
ExpandCluster(tree, p, neighborPts, C, epsilon, minPts);
}
}
clusters = new HashSet<ScanPoint[]>(
allPointsDbscan
.Where(x => x.ClusterId > 0)
.GroupBy(x => x.ClusterId)
.Select(x => x.Select(y => y.ClusterPoint).ToArray())
);
return;
}
private void ExpandCluster(KDTree.KDTree<DbscanPoint> tree, DbscanPoint p, DbscanPoint[] neighborPts, int c, double epsilon, int minPts)
{
p.ClusterId = c;
for (int i = 0; i < neighborPts.Length; i++)
{
var pn = neighborPts[i];
if (!pn.IsVisited)
{
pn.IsVisited = true;
DbscanPoint[] neighborPts2 = null;
RegionQuery(tree, pn.ClusterPoint.point, epsilon, out neighborPts2);
if (neighborPts2.Length >= minPts)
{
neighborPts = neighborPts.Union(neighborPts2).ToArray();
}
}
if (pn.ClusterId == (int)ClusterIds.UNCLASSIFIED)
pn.ClusterId = c;
}
}
private void RegionQuery(KDTree.KDTree<DbscanPoint> tree, PointD p, double epsilon, out DbscanPoint[] neighborPts)
{
int totalCount = 0;
var pIter = tree.NearestNeighbors(new double[] { p.X, p.Y }, 10, epsilon);
while (pIter.MoveNext())
{
totalCount++;
}
neighborPts = new DbscanPoint[totalCount];
int currCount = 0;
pIter.Reset();
while (pIter.MoveNext())
{
neighborPts[currCount] = pIter.Current;
currCount++;
}
return;
}
}
//Dbscan clustering identifiers
public enum ClusterIds
{
UNCLASSIFIED = 0,
NOISE = -1
}
//Point container for Dbscan clustering
public class DbscanPoint
{
public bool IsVisited;
public ScanPoint ClusterPoint;
public int ClusterId;
public DbscanPoint(ScanPoint point)
{
ClusterPoint = point;
IsVisited = false;
ClusterId = (int)ClusterIds.UNCLASSIFIED;
}
}
and modifying the regionQuery(P, eps)
to invoke the nearest neighbour function of a kd-tree. To do so, I used the kd-sharp
library for C#, which is one of the fastest kd-tree implementations out there.
However, when given a dataset of about 20000 2d points, its performance is in the region of 40s, as compared to the scikit-learn
Python implementation of DBScan, which given the same parameters, takes about 2s.
Since this algorithm is for a C# program that I am writing, I am stuck using C#. As such, I would like to find out what am I still missing out in terms of optimization of the algorithm?
1 Answer 1
_metricFunc
is unused, which means it can either be removed, or there's a bug in the program.
The first line in ComputeClusterDbscan
, clusters = null;
, is superfluous and can be removed.
The use of out
parameters can be avoided by just returning a value.
Methods that can be marked static
should be marked static
.
In RegionQuery
, it is probably faster to iterate over nearest neighbours just one, like so:
private static DbscanPoint[] RegionQuery(KDTree<DbscanPoint> tree, PointD p, double epsilon)
{
var neighbors = new List<DbscanPoint>();
var e = tree.NearestNeighbors(new[] { p.X, p.Y }, 10, epsilon);
while (e.MoveNext())
{
neighbors.Add(e.Current);
}
return neighbors.ToArray();
}
I believe the bottleneck in your program is this line in ExpandCluster
:
neighborPts = neighborPts.Union(neighborPts2).ToArray();
Try something like this instead:
private static void ExpandCluster(KDTree<DbscanPoint> tree, DbscanPoint p, DbscanPoint[] neighborPts, int c, double epsilon, int minPts)
{
p.ClusterId = c;
var queue = new Queue<DbscanPoint>(neighborPts);
while (queue.Count > 0)
{
var point = queue.Dequeue();
if (point.ClusterId == (int)ClusterIds.UNCLASSIFIED)
{
point.ClusterId = c;
}
if (point.IsVisited)
{
continue;
}
point.IsVisited = true;
var neighbors = RegionQuery(tree, point.ClusterPoint.point, epsilon);
if (neighbors.Length >= minPts)
{
foreach (var neighbor in neighbors.Where(neighbor => !neighbor.IsVisited))
{
queue.Enqueue(neighbor);
}
}
}
}
-
\$\begingroup\$ I think in
RegionQuery
ifNearestNeighbors
returns anIEnumerable
(unfortunately we don't know it) (I guess an enumerator) it would be faster to just callToArray
instead of creating aList<>
first and thenToArray
it (iterating it twice) or still better make theReqionQuery
return anIEnumerable
. \$\endgroup\$t3chb0t– t3chb0t2015年10月28日 09:11:47 +00:00Commented Oct 28, 2015 at 9:11 -
\$\begingroup\$ @t3chb0t Thanks for the comment. You can find the source of
NearestNeighbors<T>
here -- it only implementsIEnumerator
. I wouldn't want to return anIEnumerable<DbscanPoint>
since the calling methods need to know the number of elements. But now you've got me thinking changingRegionQuery
to return anICollection<DbscanPoint>
(orIReadOnlyCollection<DbscanPoint>
) is a good choice. \$\endgroup\$mjolka– mjolka2015年10月28日 09:47:38 +00:00Commented Oct 28, 2015 at 9:47
Explore related questions
See similar questions with these tags.