// dlib C++ Library - lspi.cpp

// Copyright (C) 2015 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#include "tester.h"
#include <dlib/control.h>
#include <vector>
#include <sstream>
#include <ctime>
namespace 
{
 using namespace std;
 using namespace dlib;
 using namespace test;

 // Logger named "test.lspi"; the test functions in this file write their
 // diagnostic output (learned weights, chosen actions) through it.
 dlib::logger dlog("test.lspi");
 // Toy model of a 4-state chain used to exercise lspi.  States are 0..3 and in
 // every state the action is either 0 (move left) or 1 (move right).
 //
 // Features: a one-hot indicator over the 8 (state,action) pairs, laid out at
 // index state*2+action.  When have_prior is true a 9th feature is appended that
 // carries a fixed per-pair offset, and force_last_weight_to_1 is set so that lspi
 // is asked to hold that feature's weight at 1 (NOTE(review): exact semantics of
 // the flag live in dlib's lspi; the tests in this file expect the last weight to
 // come back as exactly 1).
 template <bool have_prior>
 struct chain_model
 {
     typedef int state_type;
     typedef int action_type; // 0 is move left, 1 is move right

     const static bool force_last_weight_to_1 = have_prior;
     const static int num_states = 4; // not required in the model interface

     // Fixed offset for each (state,action) pair, indexed by state*2+action.
     // Left at all zeros when the model is built without a prior.
     matrix<double,num_states*2,1> offset;

     chain_model()
     {
         if (have_prior)
         {
             //        left   right
             offset = 2.048, 2.56,   // state 0
                      2.048, 3.2,    // state 1
                      2.56,  4,      // state 2
                      3.2,   5;      // state 3
         }
         else
         {
             offset = 0;
         }
     }

     unsigned long num_features(
     ) const
     {
         // One indicator per (state,action) pair, plus the prior feature if enabled.
         const unsigned long indicator_count = num_states*2;
         return have_prior ? indicator_count + 1 : indicator_count;
     }

     // Picks the action with the larger score w(idx)+offset(idx); ties (and any
     // comparison that is false, e.g. involving NaN) resolve to 0 (move left).
     action_type find_best_action (
         const state_type& state,
         const matrix<double,0,1>& w
     ) const
     {
         const long left_idx = state*2;
         const long right_idx = left_idx + 1;
         const double left_score = w(left_idx) + offset(left_idx);
         const double right_score = w(right_idx) + offset(right_idx);
         return (left_score >= right_score) ? 0 : 1;
     }

     // Writes the feature vector for (state, action) into feats, resizing it to
     // num_features() elements.
     void get_features (
         const state_type& state,
         const action_type& action,
         matrix<double,0,1>& feats
     ) const
     {
         const long pair_idx = state*2 + action;
         feats.set_size(num_features());
         feats = 0;
         feats(pair_idx) = 1;
         if (have_prior)
             feats(num_features()-1) = offset(pair_idx);
     }
 };
 void test_lspi_prior1()
 {
 print_spinner();
 typedef process_sample<chain_model<true> > sample_type;
 std::vector<sample_type> samples;
 samples.push_back(sample_type(0,0,0,0));
 samples.push_back(sample_type(0,1,1,0));
 samples.push_back(sample_type(1,0,0,0));
 samples.push_back(sample_type(1,1,2,0));
 samples.push_back(sample_type(2,0,1,0));
 samples.push_back(sample_type(2,1,3,0));
 samples.push_back(sample_type(3,0,2,0));
 samples.push_back(sample_type(3,1,3,1));
 lspi<chain_model<true> > trainer;
 //trainer.be_verbose();
 trainer.set_lambda(0);
 policy<chain_model<true> > pol = trainer.train(samples);
 dlog << LINFO << pol.get_weights();
 matrix<double,0,1> w = pol.get_weights();
 DLIB_TEST(pol.get_weights().size() == 9);
 DLIB_TEST(w(w.size()-1) == 1);
 w(w.size()-1) = 0;
 DLIB_TEST_MSG(length(w) < 1e-12, length(w));
 dlog << LINFO << "action: " << pol(0);
 dlog << LINFO << "action: " << pol(1);
 dlog << LINFO << "action: " << pol(2);
 dlog << LINFO << "action: " << pol(3);
 DLIB_TEST(pol(0) == 1);
 DLIB_TEST(pol(1) == 1);
 DLIB_TEST(pol(2) == 1);
 DLIB_TEST(pol(3) == 1);
 }
 void test_lspi_prior2()
 {
 print_spinner();
 typedef process_sample<chain_model<true> > sample_type;
 std::vector<sample_type> samples;
 samples.push_back(sample_type(0,0,0,0));
 samples.push_back(sample_type(0,1,1,0));
 samples.push_back(sample_type(1,0,0,0));
 samples.push_back(sample_type(1,1,2,0));
 samples.push_back(sample_type(2,0,1,0));
 samples.push_back(sample_type(2,1,3,1));
 samples.push_back(sample_type(3,0,2,0));
 samples.push_back(sample_type(3,1,3,0));
 lspi<chain_model<true> > trainer;
 //trainer.be_verbose();
 trainer.set_lambda(0);
 policy<chain_model<true> > pol = trainer.train(samples);
 dlog << LINFO << "action: " << pol(0);
 dlog << LINFO << "action: " << pol(1);
 dlog << LINFO << "action: " << pol(2);
 dlog << LINFO << "action: " << pol(3);
 DLIB_TEST(pol(0) == 1);
 DLIB_TEST(pol(1) == 1);
 DLIB_TEST(pol(2) == 1);
 DLIB_TEST(pol(3) == 0);
 }
 void test_lspi_noprior1()
 {
 print_spinner();
 typedef process_sample<chain_model<false> > sample_type;
 std::vector<sample_type> samples;
 samples.push_back(sample_type(0,0,0,0));
 samples.push_back(sample_type(0,1,1,0));
 samples.push_back(sample_type(1,0,0,0));
 samples.push_back(sample_type(1,1,2,0));
 samples.push_back(sample_type(2,0,1,0));
 samples.push_back(sample_type(2,1,3,0));
 samples.push_back(sample_type(3,0,2,0));
 samples.push_back(sample_type(3,1,3,1));
 lspi<chain_model<false> > trainer;
 //trainer.be_verbose();
 trainer.set_lambda(0.01);
 policy<chain_model<false> > pol = trainer.train(samples);
 dlog << LINFO << pol.get_weights();
 DLIB_TEST(pol.get_weights().size() == 8);
 dlog << LINFO << "action: " << pol(0);
 dlog << LINFO << "action: " << pol(1);
 dlog << LINFO << "action: " << pol(2);
 dlog << LINFO << "action: " << pol(3);
 DLIB_TEST(pol(0) == 1);
 DLIB_TEST(pol(1) == 1);
 DLIB_TEST(pol(2) == 1);
 DLIB_TEST(pol(3) == 1);
 }
 void test_lspi_noprior2()
 {
 print_spinner();
 typedef process_sample<chain_model<false> > sample_type;
 std::vector<sample_type> samples;
 samples.push_back(sample_type(0,0,0,0));
 samples.push_back(sample_type(0,1,1,0));
 samples.push_back(sample_type(1,0,0,0));
 samples.push_back(sample_type(1,1,2,1));
 samples.push_back(sample_type(2,0,1,0));
 samples.push_back(sample_type(2,1,3,0));
 samples.push_back(sample_type(3,0,2,0));
 samples.push_back(sample_type(3,1,3,0));
 lspi<chain_model<false> > trainer;
 //trainer.be_verbose();
 trainer.set_lambda(0.01);
 policy<chain_model<false> > pol = trainer.train(samples);
 dlog << LINFO << pol.get_weights();
 DLIB_TEST(pol.get_weights().size() == 8);
 dlog << LINFO << "action: " << pol(0);
 dlog << LINFO << "action: " << pol(1);
 dlog << LINFO << "action: " << pol(2);
 dlog << LINFO << "action: " << pol(3);
 DLIB_TEST(pol(0) == 1);
 DLIB_TEST(pol(1) == 1);
 DLIB_TEST(pol(2) == 0);
 DLIB_TEST(pol(3) == 0);
 }
 // Test driver for the lspi object, selected with the "test_lspi" command line
 // switch.  perform_test() runs every lspi test in this file, in order.
 class lspi_tester : public tester
 {
 public:
     lspi_tester (
     ) :
         tester("test_lspi",                      // the command line argument name for this test
                "Run tests on the lspi object.",  // the command line argument description
                0)                                // the number of command line arguments for this test
     {}

     void perform_test (
     )
     {
         typedef void (*test_function)();
         const test_function tests[] = {
             &test_lspi_prior1,
             &test_lspi_prior2,
             &test_lspi_noprior1,
             &test_lspi_noprior2
         };
         const unsigned long num_tests = sizeof(tests)/sizeof(tests[0]);
         for (unsigned long i = 0; i < num_tests; ++i)
             tests[i]();
     }
 };

 // Single global instance of the driver; presumably the tester base-class
 // constructor registers it with the test harness (TODO confirm in tester.h).
 lspi_tester a;
}

// Web-capture footer (translated from Japanese): "Page converted by AltStyle (-> original) /"