// dlib C++ Library - lspi.h

// Copyright (C) 2015 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_LSPI_Hh_
#define DLIB_LSPI_Hh_
#include "lspi_abstract.h"
#include "approximate_linear_models.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
 // Trainer that fits the weight vector of a linear policy from a batch of
 // samples and returns it wrapped in a policy<feature_extractor>.
 //
 // Requirements visible from train():
 //   - each sample exposes .state, .action, .reward and .next_state
 //   - feature_extractor provides num_features(), get_features(),
 //     find_best_action() and the constant force_last_weight_to_1
 template <
     typename feature_extractor
     >
 class lspi
 {
 public:
     typedef feature_extractor feature_extractor_type;
     typedef typename feature_extractor::state_type state_type;
     typedef typename feature_extractor::action_type action_type;

     // Creates a trainer that holds a copy of fe_ and uses the default
     // settings assigned by init().
     explicit lspi(
         const feature_extractor& fe_
     ) : fe(fe_)
     {
         init();
     }

     // Creates a trainer with a default constructed feature extractor and
     // the default settings assigned by init().
     lspi(
     )
     {
         init();
     }

     // Returns the discount applied to the next-state features in train().
     double get_discount (
     ) const { return discount; }

     // Sets the discount.  Requires 0 < value <= 1.
     void set_discount (
         double value
     )
     {
         // make sure requires clause is not broken
         DLIB_ASSERT(0 < value && value <= 1,
             "\t void lspi::set_discount(value)"
             << "\n\t invalid inputs were given to this function"
             << "\n\t value: " << value 
         );
         discount = value;
     }

     // Returns the feature extractor used by this trainer.
     const feature_extractor& get_feature_extractor (
     ) const { return fe; }

     // Makes train() print the iteration number and weight change to
     // std::cout after every iteration.
     void be_verbose (
     )
     {
         verbose = true;
     }

     // Stops train() from printing progress information.
     void be_quiet (
     )
     {
         verbose = false;
     }

     // Sets the stopping threshold: train() stops iterating once the length
     // of the weight change is no longer greater than eps_.  Requires eps_ > 0.
     void set_epsilon (
         double eps_
     )
     {
         // make sure requires clause is not broken
         DLIB_ASSERT(eps_ > 0,
             "\t void lspi::set_epsilon(eps_)"
             << "\n\t invalid inputs were given to this function"
             << "\n\t eps_: " << eps_ 
         );
         eps = eps_;
     }

     // Returns the stopping threshold used by train().
     double get_epsilon (
     ) const
     { 
         return eps;
     }

     // Sets the value train() multiplies the identity matrix by when it
     // initializes A at the start of every iteration.  Requires lambda_ >= 0.
     void set_lambda (
         double lambda_ 
     ) 
     {
         // make sure requires clause is not broken
         DLIB_ASSERT(lambda_ >= 0,
             "\t void lspi::set_lambda(lambda_)"
             << "\n\t invalid inputs were given to this function"
             << "\n\t lambda_: " << lambda_ 
         );
         lambda = lambda_;
     }

     // Returns the value set by set_lambda() (0.01 by default).
     double get_lambda (
     ) const 
     { 
         return lambda; 
     }

     // Sets the iteration cap used by train().  Since train() checks the cap
     // only after finishing an iteration, at least one iteration always runs.
     void set_max_iterations (
         unsigned long max_iter
     ) { max_iterations = max_iter; }

     // Returns the iteration cap used by train().  Declared const, like the
     // other getters, so it can be called on a const lspi object.
     unsigned long get_max_iterations (
     ) const { return max_iterations; }

     // Fits the policy weights to the given samples and returns them together
     // with a copy of the feature extractor.  Requires samples.size() > 0.
     //
     // Starting from w == 0, each iteration rebuilds A and b from all the
     // samples, using the current w to choose the next-state action, solves
     // for a new w, and measures how far w moved.  Iteration stops when that
     // change is <= eps or max_iterations iterations have been performed.
     template <typename vector_type>
     policy<feature_extractor> train (
         const vector_type& samples
     ) const
     {
         // make sure requires clause is not broken
         DLIB_ASSERT(samples.size() > 0,
             "\t policy lspi::train(samples)"
             << "\n\t invalid inputs were given to this function"
         );

         matrix<double,0,1> w(fe.num_features());
         w = 0;
         matrix<double,0,1> prev_w, b, f1, f2;
         matrix<double> A;
         double change; 
         unsigned long iter = 0;
         do
         {
             A = identity_matrix<double>(fe.num_features())*lambda;
             b = 0;
             for (unsigned long i = 0; i < samples.size(); ++i)
             {
                 // f1: features of the sampled (state, action) pair.
                 fe.get_features(samples[i].state, samples[i].action, f1);
                 // f2: features of the next state paired with the action the
                 // feature extractor picks for it under the current weights.
                 fe.get_features(samples[i].next_state, 
                     fe.find_best_action(samples[i].next_state,w), 
                     f2);
                 A += f1*trans(f1 - discount*f2);
                 b += f1*samples[i].reward;
             }

             prev_w = w;
             if (feature_extractor::force_last_weight_to_1)
                 // Solve only for the leading weights: the last column of A
                 // is moved to the right hand side and a fixed 1 is appended
                 // as the final weight.
                 w = join_cols(pinv(colm(A,range(0,A.nc()-2)))*(b-colm(A,A.nc()-1)),mat(1.0));
             else
                 w = pinv(A)*b;

             change = length(w-prev_w);
             ++iter;
             if (verbose)
                 std::cout << "iteration: " << iter << "\tchange: " << change << std::endl;
         } while(change > eps && iter < max_iterations);

         return policy<feature_extractor>(w,fe);
     }

 private:

     // Assigns the default settings shared by both constructors.
     void init()
     {
         lambda = 0.01;
         discount = 0.8;
         eps = 0.01;
         verbose = false;
         max_iterations = 100;
     }

     double lambda;                  // scale of the identity matrix A starts from
     double discount;                // multiplier on next-state features, in (0,1]
     double eps;                     // stop once the weight change is <= eps
     bool verbose;                   // print per-iteration progress when true
     unsigned long max_iterations;   // iteration cap for train()
     feature_extractor fe;
 };
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_LSPI_Hh_

// Page converted by AltStyle (-> original)