// dlib C++ Library - kmeans.cpp

// Copyright (C) 2011 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#include <sstream>
#include <string>
#include <cstdlib>
#include <ctime>
#include <dlib/svm.h>
#include <dlib/matrix.h>
#include "tester.h"
namespace 
{
 // Bring the test, dlib and std namespaces into this file-local (anonymous) namespace.
 using namespace test;
 using namespace dlib;
 using namespace std;
 // Logger named "test.kmeans"; perform_test() writes an LINFO line to it before each sub-test.
 logger dlog("test.kmeans");
 // Random number generator passed to randm() in run_test() when synthesizing samples.
 dlib::rand rnd;
 // Counts, for every center, how many samples have that center as their
 // nearest one (by length() of the difference vector).
 //
 // samples: the points to assign.
 // centers: the candidate cluster centers; must not be empty.
 // Returns a vector with one entry per center; entry j is the number of
 // samples whose closest center is centers[j].  Ties go to the lowest index.
 template <typename sample_type>
 std::vector<int> count_nearest_center_hits(
     const std::vector<sample_type>& samples,
     const std::vector<sample_type>& centers
 )
 {
     std::vector<int> hits(centers.size(), 0);
     for (unsigned long i = 0; i < samples.size(); ++i)
     {
         unsigned long best_idx = 0;
         double best_dist = 1e100;
         for (unsigned long j = 0; j < centers.size(); ++j)
         {
             // Compute the distance once and reuse it for both the comparison
             // and the update.
             const double dist = length(samples[i] - centers[j]);
             if (dist < best_dist)
             {
                 best_dist = dist;
                 best_idx = j;
             }
         }
         hits[best_idx]++;
     }
     return hits;
 }

 // Builds a synthetic data set around the given seed centers and checks that
 // both find_clusters_using_kmeans() and find_clusters_using_angular_kmeans()
 // split it into equally sized clusters.
 //
 // seed_centers: one point per expected cluster.  For each of them,
 //               samples_per_center samples are generated by adding a random
 //               offset (randm(...) - 0.5) to the seed center.
 //
 // The test passes when each clustering call leaves seed_centers.size()
 // centers and every center is the nearest one for exactly
 // samples_per_center samples.
 template <typename sample_type>
 void run_test(
     const std::vector<sample_type>& seed_centers
 )
 {
     print_spinner();

     // seed_centers[0] is read below to get the sample dimensionality, so an
     // empty input must be rejected first.  The explicit return keeps us from
     // indexing an empty vector regardless of how DLIB_TEST reports failure.
     DLIB_TEST(seed_centers.size() != 0);
     if (seed_centers.empty())
         return;

     // Number of samples generated around each seed center; also the cluster
     // size every check below expects.
     const int samples_per_center = 250;

     sample_type samp;
     std::vector<sample_type> samples;
     samples.reserve(seed_centers.size() * samples_per_center);
     for (unsigned long j = 0; j < seed_centers.size(); ++j)
     {
         for (int i = 0; i < samples_per_center; ++i)
         {
             samp = randm(seed_centers[0].size(),1,rnd) - 0.5;
             samples.push_back(samp + seed_centers[j]);
         }
     }
     randomize_samples(samples);

     // Regular k-means.
     {
         std::vector<sample_type> centers;
         pick_initial_centers(seed_centers.size(), centers, samples, linear_kernel<sample_type>());
         find_clusters_using_kmeans(samples, centers);
         DLIB_TEST(centers.size() == seed_centers.size());

         const std::vector<int> hits = count_nearest_center_hits(samples, centers);
         for (unsigned long i = 0; i < hits.size(); ++i)
         {
             DLIB_TEST(hits[i] == samples_per_center);
         }
     }

     // Angular k-means, validated with the same nearest-center count.
     {
         std::vector<sample_type> centers;
         pick_initial_centers(seed_centers.size(), centers, samples, linear_kernel<sample_type>());
         find_clusters_using_angular_kmeans(samples, centers);
         DLIB_TEST(centers.size() == seed_centers.size());

         const std::vector<int> hits = count_nearest_center_hits(samples, centers);
         for (unsigned long i = 0; i < hits.size(); ++i)
         {
             DLIB_TEST(hits[i] == samples_per_center);
         }
     }
 }
 // Test case that runs run_test() on several sample types.  A single
 // instance, `a`, is defined at namespace scope right after the class;
 // presumably the tester base class uses that to register the test with the
 // harness (see tester.h to confirm).
 class test_kmeans : public tester
 {
 public:
     // Passes the test's name and a one-line description to the tester base.
     test_kmeans (
     ) :
         tester ("test_kmeans",
                 "Runs tests on the find_clusters_using_kmeans() function.")
     {}

     // Runs the clustering checks on three sample types:
     //   - dlib::vector<double,2> with four seed centers,
     //   - dlib::vector<float,2>  with the same four seed centers,
     //   - dlib::matrix<double,3,1> with three seed centers.
     void perform_test (
     )
     {
         {
             dlog << LINFO << "test dlib::vector<double,2>";
             typedef dlib::vector<double,2> sample_type;
             std::vector<sample_type> seed_centers;
             seed_centers.push_back(sample_type(10,10));
             seed_centers.push_back(sample_type(10,-10));
             seed_centers.push_back(sample_type(-10,10));
             seed_centers.push_back(sample_type(-10,-10));
             run_test(seed_centers);
         }
         {
             // The log line names the type actually under test (float, not double).
             dlog << LINFO << "test dlib::vector<float,2>";
             typedef dlib::vector<float,2> sample_type;
             std::vector<sample_type> seed_centers;
             seed_centers.push_back(sample_type(10,10));
             seed_centers.push_back(sample_type(10,-10));
             seed_centers.push_back(sample_type(-10,10));
             seed_centers.push_back(sample_type(-10,-10));
             run_test(seed_centers);
         }
         {
             dlog << LINFO << "test dlib::matrix<double,3,1>";
             typedef dlib::matrix<double,3,1> sample_type;
             std::vector<sample_type> seed_centers;
             sample_type samp;
             // Each assignment fills the three elements of samp in order.
             samp = 10,10,0; seed_centers.push_back(samp);
             samp = -10,10,1; seed_centers.push_back(samp);
             samp = -10,-10,2; seed_centers.push_back(samp);
             run_test(seed_centers);
         }
     }
 } a;
}

// Page converted by AltStyle (-> original)