// dlib C++ Library - assignment_learning.cpp

// Copyright (C) 2011 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#include <sstream>
#include <string>
#include <cstdlib>
#include <ctime>
#include "tester.h"
#include <dlib/svm_threaded.h>
#include <dlib/rand.h>
// Both sides of every assignment problem in this test are 3-dimensional
// column vectors of doubles, so the right-hand and left-hand element types
// are the same dlib matrix type.
typedef dlib::matrix<double,3,1> rhs_element;
typedef dlib::matrix<double,3,1> lhs_element;
namespace 
{
 // Names from the standard library, dlib, and the test harness are used
 // unqualified throughout this file.
 using namespace std;
 using namespace dlib;
 using namespace test;

 // Logger for this test, named "test.assignment_learning".
 logger dlog("test.assignment_learning");
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
 struct feature_extractor_dense
 {
 typedef matrix<double,3,1> feature_vector_type;
 typedef ::lhs_element lhs_element;
 typedef ::rhs_element rhs_element;
 unsigned long num_features() const
 {
 return 3;
 }
 void get_features (
 const lhs_element& left,
 const rhs_element& right,
 feature_vector_type& feats
 ) const
 {
 feats = squared(left - right);
 }
 };
 // feature_extractor_dense has no data members, so (de)serialization writes and
 // reads nothing. NOTE(review): presumably required so that
 // assignment_function<feature_extractor_dense> can be serialized in test1.
 void serialize (const feature_extractor_dense& /*item*/, std::ostream& /*out*/) { }
 void deserialize (feature_extractor_dense& /*item*/, std::istream& /*in*/) { }
// ----------------------------------------------------------------------------------------
 struct feature_extractor_sparse
 {
 typedef std::vector<std::pair<unsigned long,double> > feature_vector_type;
 typedef ::lhs_element lhs_element;
 typedef ::rhs_element rhs_element;
 unsigned long num_features() const
 {
 return 3;
 }
 void get_features (
 const lhs_element& left,
 const rhs_element& right,
 feature_vector_type& feats
 ) const
 {
 feats.clear();
 feats.push_back(make_pair(0,squared(left-right)(0)));
 feats.push_back(make_pair(1,squared(left-right)(1)));
 feats.push_back(make_pair(2,squared(left-right)(2)));
 }
 };
 // feature_extractor_sparse has no data members, so (de)serialization writes and
 // reads nothing. NOTE(review): presumably required so that
 // assignment_function<feature_extractor_sparse> can be serialized in test1.
 void serialize (const feature_extractor_sparse& /*item*/, std::ostream& /*out*/) { }
 void deserialize (feature_extractor_sparse& /*item*/, std::istream& /*in*/) { }
// ----------------------------------------------------------------------------------------
 // One label entry per left-hand element. In the data sets built in this file,
 // each entry is the index of the matching right-hand element, or -1 when the
 // left-hand element has no match.
 typedef std::vector<long> label_type;

 // One assignment problem: (left-hand elements, right-hand elements).
 typedef std::pair<std::vector<lhs_element>, std::vector<rhs_element> > sample_type;
// ----------------------------------------------------------------------------------------
 void make_data (
 std::vector<sample_type>& samples,
 std::vector<label_type>& labels
 )
 {
 lhs_element a, b, c, d;
 a = 1,0,0;
 b = 0,1,0;
 c = 0,0,1;
 d = 0,1,1;
 std::vector<lhs_element> lhs;
 std::vector<rhs_element> rhs;
 label_type label;
 lhs.push_back(a);
 lhs.push_back(b);
 lhs.push_back(c);
 rhs.push_back(b);
 rhs.push_back(a);
 rhs.push_back(c);
 label.push_back(1);
 label.push_back(0);
 label.push_back(2);
 samples.push_back(make_pair(lhs,rhs));
 labels.push_back(label);
 lhs.clear();
 rhs.clear();
 label.clear();
 lhs.push_back(a);
 lhs.push_back(b);
 lhs.push_back(c);
 rhs.push_back(c);
 rhs.push_back(b);
 rhs.push_back(a);
 rhs.push_back(d);
 label.push_back(2);
 label.push_back(1);
 label.push_back(0);
 samples.push_back(make_pair(lhs,rhs));
 labels.push_back(label);
 lhs.clear();
 rhs.clear();
 label.clear();
 lhs.push_back(a);
 lhs.push_back(b);
 lhs.push_back(c);
 rhs.push_back(c);
 rhs.push_back(a);
 rhs.push_back(d);
 label.push_back(1);
 label.push_back(-1);
 label.push_back(0);
 samples.push_back(make_pair(lhs,rhs));
 labels.push_back(label);
 lhs.clear();
 rhs.clear();
 label.clear();
 lhs.push_back(d);
 lhs.push_back(b);
 lhs.push_back(c);
 label.push_back(-1);
 label.push_back(-1);
 label.push_back(-1);
 samples.push_back(make_pair(lhs,rhs));
 labels.push_back(label);
 lhs.clear();
 rhs.clear();
 label.clear();
 samples.push_back(make_pair(lhs,rhs));
 labels.push_back(label);
 }
// ----------------------------------------------------------------------------------------
 void make_data_force (
 std::vector<sample_type>& samples,
 std::vector<label_type>& labels
 )
 {
 lhs_element a, b, c, d;
 a = 1,0,0;
 b = 0,1,0;
 c = 0,0,1;
 d = 0,1,1;
 std::vector<lhs_element> lhs;
 std::vector<rhs_element> rhs;
 label_type label;
 lhs.push_back(a);
 lhs.push_back(b);
 lhs.push_back(c);
 rhs.push_back(b);
 rhs.push_back(a);
 rhs.push_back(c);
 label.push_back(1);
 label.push_back(0);
 label.push_back(2);
 samples.push_back(make_pair(lhs,rhs));
 labels.push_back(label);
 lhs.clear();
 rhs.clear();
 label.clear();
 lhs.push_back(a);
 lhs.push_back(b);
 lhs.push_back(c);
 rhs.push_back(c);
 rhs.push_back(b);
 rhs.push_back(a);
 rhs.push_back(d);
 label.push_back(2);
 label.push_back(1);
 label.push_back(0);
 samples.push_back(make_pair(lhs,rhs));
 labels.push_back(label);
 lhs.clear();
 rhs.clear();
 label.clear();
 lhs.push_back(a);
 lhs.push_back(c);
 rhs.push_back(c);
 rhs.push_back(a);
 label.push_back(1);
 label.push_back(0);
 samples.push_back(make_pair(lhs,rhs));
 labels.push_back(label);
 lhs.clear();
 rhs.clear();
 label.clear();
 samples.push_back(make_pair(lhs,rhs));
 labels.push_back(label);
 }
// ----------------------------------------------------------------------------------------
 template <typename fe_type, typename F>
 void test1(F make_data, bool force_assignment)
 {
 print_spinner();
 std::vector<sample_type> samples;
 std::vector<label_type> labels;
 make_data(samples, labels);
 make_data(samples, labels);
 make_data(samples, labels);
 randomize_samples(samples, labels);
 structural_assignment_trainer<fe_type> trainer;
 DLIB_TEST(trainer.forces_assignment() == false);
 DLIB_TEST(trainer.get_c() == 100);
 DLIB_TEST(trainer.get_num_threads() == 2);
 DLIB_TEST(trainer.get_max_cache_size() == 5);
 trainer.set_forces_assignment(force_assignment);
 trainer.set_num_threads(3);
 trainer.set_c(50);
 DLIB_TEST(trainer.get_c() == 50);
 DLIB_TEST(trainer.get_num_threads() == 3);
 DLIB_TEST(trainer.forces_assignment() == force_assignment);
 assignment_function<fe_type> ass = trainer.train(samples, labels);
 for (unsigned long i = 0; i < samples.size(); ++i)
 {
 std::vector<long> out = ass(samples[i]);
 dlog << LINFO << "true labels: " << trans(mat(labels[i]));
 dlog << LINFO << "pred labels: " << trans(mat(out));
 DLIB_TEST(trans(mat(labels[i])) == trans(mat(out)));
 }
 double accuracy;
 dlog << LINFO << "samples.size(): "<< samples.size();
 accuracy = test_assignment_function(ass, samples, labels);
 dlog << LINFO << "accuracy: "<< accuracy;
 DLIB_TEST(accuracy == 1);
 accuracy = cross_validate_assignment_trainer(trainer, samples, labels, 3);
 dlog << LINFO << "cv accuracy: "<< accuracy;
 DLIB_TEST(accuracy == 1);
 ostringstream sout;
 serialize(ass, sout);
 istringstream sin(sout.str());
 assignment_function<fe_type> ass2;
 deserialize(ass2, sin);
 DLIB_TEST(ass2.forces_assignment() == ass.forces_assignment());
 DLIB_TEST(length(ass2.get_weights() - ass.get_weights()) < 1e-10);
 for (unsigned long i = 0; i < samples.size(); ++i)
 {
 std::vector<long> out = ass2(samples[i]);
 dlog << LINFO << "true labels: " << trans(mat(labels[i]));
 dlog << LINFO << "pred labels: " << trans(mat(out));
 DLIB_TEST(trans(mat(labels[i])) == trans(mat(out)));
 }
 }
// ----------------------------------------------------------------------------------------
 // Registers the assignment-learning tests with the test harness; the global
 // instance `a` below is what performs the registration.
 class test_assignment_learning : public tester
 {
 public:
     test_assignment_learning (
     ) :
         tester ("test_assignment_learning",
                 "Runs tests on the assignment learning code.")
     {}

     void perform_test (
     )
     {
         run_with_both_extractors(make_data, false);
         run_with_both_extractors(make_data_force, false);
         run_with_both_extractors(make_data_force, true);
     }

 private:
     // Runs test1 with the dense feature extractor, then with the sparse one,
     // on the same data generator and force-assignment setting.
     template <typename F>
     void run_with_both_extractors (F data_maker, bool force_assignment)
     {
         test1<feature_extractor_dense>(data_maker, force_assignment);
         test1<feature_extractor_sparse>(data_maker, force_assignment);
     }
 } a;
// ----------------------------------------------------------------------------------------
}

// Page converted by AltStyle (-> original)