// dlib C++ Library - function.h

// Copyright (C) 2007 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_SVm_FUNCTION
#define DLIB_SVm_FUNCTION
#include "function_abstract.h"
#include <cmath>
#include <limits>
#include <sstream>
#include "../matrix.h"
#include "../algs.h"
#include "../serialize.h"
#include "../rand.h"
#include "../statistics.h"
#include "kernel_matrix.h"
#include "kernel.h"
#include "sparse_kernel.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
 // A kernel-based scoring function.  Evaluating it on a sample x yields
 //     sum over i of alpha(i)*kernel_function(x, basis_vectors(i))   minus   b
 // The evaluation loop runs over alpha.nr() terms and indexes basis_vectors
 // with the same index, so basis_vectors is expected to hold at least that
 // many samples.
 template <typename K>
 struct decision_function
 {
     typedef K kernel_type;
     typedef typename K::scalar_type scalar_type;
     typedef typename K::scalar_type result_type;
     typedef typename K::sample_type sample_type;
     typedef typename K::mem_manager_type mem_manager_type;

     typedef matrix<scalar_type,0,1,mem_manager_type> scalar_vector_type;
     typedef matrix<sample_type,0,1,mem_manager_type> sample_vector_type;

     scalar_vector_type alpha;           // weight applied to each basis vector's kernel value
     scalar_type b;                      // offset subtracted from the weighted sum
     K kernel_function;                  // kernel used to compare a sample with each basis vector
     sample_vector_type basis_vectors;   // samples the kernel is evaluated against

     // Builds an empty function: no terms, zero offset, default-constructed kernel.
     decision_function () : b(0), kernel_function(K()) {}

     // Builds a function from explicit copies of all four components.
     decision_function (
         const scalar_vector_type& alpha_,
         const scalar_type& b_,
         const K& kernel_function_,
         const sample_vector_type& basis_vectors_
     ) :
         alpha(alpha_),
         b(b_),
         kernel_function(kernel_function_),
         basis_vectors(basis_vectors_)
     {}

     // Returns the weighted kernel sum for x, minus the offset b.
     result_type operator() (const sample_type& x) const
     {
         result_type weighted_sum = 0;
         const long num_terms = alpha.nr();
         for (long idx = 0; idx < num_terms; ++idx)
         {
             weighted_sum += alpha(idx) * kernel_function(x,basis_vectors(idx));
         }
         return weighted_sum - b;
     }
 };
 template <
 typename K
 >
 void serialize (
 const decision_function<K>& item,
 std::ostream& out
 )
 {
 try
 {
 serialize(item.alpha, out);
 serialize(item.b, out);
 serialize(item.kernel_function, out);
 serialize(item.basis_vectors, out);
 }
 catch (serialization_error& e)
 { 
 throw serialization_error(e.info + "\n while serializing object of type decision_function"); 
 }
 }
 // Reads a decision_function from in.  Fields are read directly into item in
 // the order alpha, b, kernel_function, basis_vectors, so item may be left
 // partially updated if a read fails part way through.  A serialization_error
 // from a nested deserialize() call is re-thrown with a note naming
 // decision_function appended to its message.
 template <typename K>
 void deserialize (
     decision_function<K>& item,
     std::istream& in
 )
 {
     try
     {
         deserialize(item.alpha, in);
         deserialize(item.b, in);
         deserialize(item.kernel_function, in);
         deserialize(item.basis_vectors, in);
     }
     catch (const serialization_error& err)
     {
         throw serialization_error(err.info + "\n while deserializing object of type decision_function");
     }
 }
// ----------------------------------------------------------------------------------------
 // Wraps an arbitrary scoring function and maps its output through
 //     1/(1 + exp(alpha*f + beta))
 // where f is the value returned by the wrapped function.
 template <typename function_type>
 struct probabilistic_function
 {
     typedef typename function_type::scalar_type scalar_type;
     typedef typename function_type::result_type result_type;
     typedef typename function_type::sample_type sample_type;
     typedef typename function_type::mem_manager_type mem_manager_type;

     scalar_type alpha;              // multiplier applied to the wrapped function's output
     scalar_type beta;               // constant added inside the exponential
     function_type decision_funct;   // the wrapped scoring function

     // Both coefficients start at zero and the wrapped function is default constructed.
     probabilistic_function () :
         alpha(0),
         beta(0),
         decision_funct(function_type())
     {}

     // Stores the two coefficients and a copy of the function to wrap.
     probabilistic_function (
         const scalar_type a_,
         const scalar_type b_,
         const function_type& decision_funct_
     ) :
         alpha(a_),
         beta(b_),
         decision_funct(decision_funct_)
     {}

     // Evaluates the wrapped function on x and returns the transformed score.
     result_type operator() (const sample_type& x) const
     {
         const result_type raw_score = decision_funct(x);
         return 1/(1 + std::exp(alpha*raw_score + beta));
     }
 };
 template <
 typename function_type 
 >
 void serialize (
 const probabilistic_function<function_type>& item,
 std::ostream& out
 )
 {
 try
 {
 serialize(item.alpha, out);
 serialize(item.beta, out);
 serialize(item.decision_funct, out);
 }
 catch (serialization_error& e)
 { 
 throw serialization_error(e.info + "\n while serializing object of type probabilistic_function"); 
 }
 }
 // Reads a probabilistic_function from in, filling alpha, beta and
 // decision_funct of item in that order.  A serialization_error from a nested
 // deserialize() call is re-thrown with a note naming probabilistic_function
 // appended to its message.
 template <typename function_type>
 void deserialize (
     probabilistic_function<function_type>& item,
     std::istream& in
 )
 {
     try
     {
         deserialize(item.alpha, in);
         deserialize(item.beta, in);
         deserialize(item.decision_funct, in);
     }
     catch (const serialization_error& err)
     {
         throw serialization_error(err.info + "\n while deserializing object of type probabilistic_function");
     }
 }
// ----------------------------------------------------------------------------------------
 // A decision_function<K> whose output f is mapped through
 //     1/(1 + exp(alpha*f + beta))
 // It can also be built from a probabilistic_function<decision_function<K> >,
 // whose three fields are copied across.
 template <typename K>
 struct probabilistic_decision_function
 {
     typedef K kernel_type;
     typedef typename K::scalar_type scalar_type;
     typedef typename K::scalar_type result_type;
     typedef typename K::sample_type sample_type;
     typedef typename K::mem_manager_type mem_manager_type;

     scalar_type alpha;                     // multiplier applied to the decision function's output
     scalar_type beta;                      // constant added inside the exponential
     decision_function<K> decision_funct;   // the underlying kernel decision function

     // Zero coefficients and a default-constructed decision function.
     probabilistic_decision_function () :
         alpha(0),
         beta(0),
         decision_funct(decision_function<K>())
     {}

     // Converting constructor: copies alpha, beta and decision_funct out of d.
     probabilistic_decision_function (
         const probabilistic_function<decision_function<K> >& d
     ) :
         alpha(d.alpha),
         beta(d.beta),
         decision_funct(d.decision_funct)
     {}

     // Stores the two coefficients and a copy of the decision function.
     probabilistic_decision_function (
         const scalar_type a_,
         const scalar_type b_,
         const decision_function<K>& decision_funct_
     ) :
         alpha(a_),
         beta(b_),
         decision_funct(decision_funct_)
     {}

     // Evaluates the decision function on x and returns the transformed score.
     result_type operator() (const sample_type& x) const
     {
         const result_type raw_score = decision_funct(x);
         return 1/(1 + std::exp(alpha*raw_score + beta));
     }
 };
 template <
 typename K 
 >
 void serialize (
 const probabilistic_decision_function<K>& item,
 std::ostream& out
 )
 {
 try
 {
 serialize(item.alpha, out);
 serialize(item.beta, out);
 serialize(item.decision_funct, out);
 }
 catch (serialization_error& e)
 { 
 throw serialization_error(e.info + "\n while serializing object of type probabilistic_decision_function"); 
 }
 }
 // Reads a probabilistic_decision_function from in, filling alpha, beta and
 // decision_funct of item in that order.  A serialization_error from a nested
 // deserialize() call is re-thrown with a note naming
 // probabilistic_decision_function appended to its message.
 template <typename K>
 void deserialize (
     probabilistic_decision_function<K>& item,
     std::istream& in
 )
 {
     try
     {
         deserialize(item.alpha, in);
         deserialize(item.beta, in);
         deserialize(item.decision_funct, in);
     }
     catch (const serialization_error& err)
     {
         throw serialization_error(err.info + "\n while deserializing object of type probabilistic_decision_function");
     }
 }
// ----------------------------------------------------------------------------------------
 // Holds a weighted combination of basis samples (weights alpha, samples
 // basis_vectors) together with a kernel and a cached scalar b, exposed as
 // get_squared_norm().  The constructors that derive b themselves set it to
 // trans(alpha)*kernel_matrix(kernel, basis_vectors)*alpha.
 //
 // operator() reports either the distance from this combination to a single
 // sample or the distance to another distance_function; both are computed
 // from kernel evaluations and the cached b values, and are clamped to 0 when
 // the squared quantity is not positive.
 template <typename K>
 class distance_function
 {
 public:
     typedef K kernel_type;
     typedef typename K::scalar_type scalar_type;
     typedef typename K::scalar_type result_type;
     typedef typename K::sample_type sample_type;
     typedef typename K::mem_manager_type mem_manager_type;

     typedef matrix<scalar_type,0,1,mem_manager_type> scalar_vector_type;
     typedef matrix<sample_type,0,1,mem_manager_type> sample_vector_type;

     // Empty combination with a default-constructed kernel and b == 0.
     distance_function () : b(0), kernel_function(K()) {}

     // Empty combination that uses the supplied kernel; b == 0.
     explicit distance_function (const kernel_type& kern) : b(0), kernel_function(kern) {}

     // Combination consisting of the single sample samp with weight 1, so that
     // b is simply kern(samp,samp).
     distance_function (
         const kernel_type& kern,
         const sample_type& samp
     ) :
         alpha(ones_matrix<scalar_type>(1,1)),
         b(kern(samp,samp)),
         kernel_function(kern)
     {
         basis_vectors.set_size(1,1);
         basis_vectors(0) = samp;
     }

     // Takes alpha, the kernel and the basis vectors from f and derives b from
     // them.  The offset stored in f itself is not used.
     distance_function (
         const decision_function<K>& f
     ) :
         alpha(f.alpha),
         b(trans(f.alpha)*kernel_matrix(f.kernel_function,f.basis_vectors)*f.alpha),
         kernel_function(f.kernel_function),
         basis_vectors(f.basis_vectors)
     {
         // make sure requires clause is not broken
         DLIB_ASSERT(f.alpha.size() == f.basis_vectors.size(),
             "\t distance_function(f)"
             << "\n\t The supplied decision_function is invalid."
             << "\n\t f.alpha.size(): " << f.alpha.size()
             << "\n\t f.basis_vectors.size(): " << f.basis_vectors.size()
             );
     }

     // Stores all four components exactly as given; b_ is trusted, not recomputed.
     distance_function (
         const scalar_vector_type& alpha_,
         const scalar_type& b_,
         const K& kernel_function_,
         const sample_vector_type& basis_vectors_
     ) :
         alpha(alpha_),
         b(b_),
         kernel_function(kernel_function_),
         basis_vectors(basis_vectors_)
     {
         // make sure requires clause is not broken
         DLIB_ASSERT(alpha_.size() == basis_vectors_.size(),
             "\t distance_function()"
             << "\n\t The supplied arguments are invalid."
             << "\n\t alpha_.size(): " << alpha_.size()
             << "\n\t basis_vectors_.size(): " << basis_vectors_.size()
             );
     }

     // Stores alpha_, the kernel and the basis vectors, and derives b from them.
     // The initializer for b reads the member alpha, which is already
     // initialized at that point because alpha is declared before b below.
     distance_function (
         const scalar_vector_type& alpha_,
         const K& kernel_function_,
         const sample_vector_type& basis_vectors_
     ) :
         alpha(alpha_),
         b(trans(alpha)*kernel_matrix(kernel_function_,basis_vectors_)*alpha),
         kernel_function(kernel_function_),
         basis_vectors(basis_vectors_)
     {
         // make sure requires clause is not broken
         DLIB_ASSERT(alpha_.size() == basis_vectors_.size(),
             "\t distance_function()"
             << "\n\t The supplied arguments are invalid."
             << "\n\t alpha_.size(): " << alpha_.size()
             << "\n\t basis_vectors_.size(): " << basis_vectors_.size()
             );
     }

     // Read-only access to the stored components.
     const scalar_vector_type& get_alpha () const { return alpha; }
     const scalar_type& get_squared_norm () const { return b; }
     const K& get_kernel () const { return kernel_function; }
     const sample_vector_type& get_basis_vectors () const { return basis_vectors; }

     // Distance from the stored combination to the single sample x:
     //     sqrt(b + k(x,x) - 2*sum_i alpha(i)*k(x, basis_vectors(i)))
     // or 0 when the quantity under the root is not positive.
     result_type operator() (const sample_type& x) const
     {
         result_type cross_term = 0;
         const long num_terms = alpha.nr();
         for (long i = 0; i < num_terms; ++i)
         {
             cross_term += alpha(i) * kernel_function(x,basis_vectors(i));
         }

         const result_type squared_dist = b + kernel_function(x,x) - 2*cross_term;
         if (!(squared_dist > 0))
             return 0;
         return std::sqrt(squared_dist);
     }

     // Distance between the stored combination and the one held by x:
     //     sqrt(b + x.b - 2*sum_ij alpha(i)*x.alpha(j)*k(basis_vectors(i), x.basis_vectors(j)))
     // or 0 when the quantity under the root is not positive.  Only this
     // object's kernel is evaluated.
     result_type operator() (const distance_function& x) const
     {
         result_type cross_term = 0;
         for (long i = 0; i < alpha.nr(); ++i)
         {
             for (long j = 0; j < x.alpha.nr(); ++j)
             {
                 cross_term += alpha(i)*x.alpha(j) * kernel_function(basis_vectors(i), x.basis_vectors(j));
             }
         }

         const result_type squared_dist = b + x.b - 2*cross_term;
         if (!(squared_dist > 0))
             return 0;
         return std::sqrt(squared_dist);
     }

     // Scales every weight by val; b is scaled by val*val.
     distance_function operator* (const scalar_type& val) const
     {
         return distance_function(val*alpha,
                                  val*val*b,
                                  kernel_function,
                                  basis_vectors);
     }

     // Divides every weight by val; b is divided by val twice.
     distance_function operator/ (const scalar_type& val) const
     {
         return distance_function(alpha/val,
                                  b/val/val,
                                  kernel_function,
                                  basis_vectors);
     }

     // Sum of two combinations.  If either side has no weights the other side
     // is returned unchanged; otherwise the weights and basis vectors are
     // concatenated and b is updated with the cross term between the two.
     distance_function operator+ (const distance_function& rhs) const
     {
         // make sure requires clause is not broken
         DLIB_ASSERT(get_kernel() == rhs.get_kernel(),
             "\t distance_function distance_function::operator+()"
             << "\n\t You can only add two distance_functions together if they use the same kernel."
             );

         if (alpha.size() == 0)
             return rhs;
         if (rhs.alpha.size() == 0)
             return *this;

         return distance_function(join_cols(alpha, rhs.alpha),
                                  b + rhs.b + 2*trans(alpha)*kernel_matrix(kernel_function,basis_vectors,rhs.basis_vectors)*rhs.alpha,
                                  kernel_function,
                                  join_cols(basis_vectors, rhs.basis_vectors));
     }

     // Difference of two combinations.  Empty operands are handled up front:
     // both empty gives an empty result with this kernel, an empty rhs gives
     // *this, and an empty lhs gives -1*rhs.  Otherwise the rhs weights are
     // negated and concatenated, and b is updated with the cross term.
     distance_function operator- (const distance_function& rhs) const
     {
         // make sure requires clause is not broken
         DLIB_ASSERT(get_kernel() == rhs.get_kernel(),
             "\t distance_function distance_function::operator-()"
             << "\n\t You can only subtract two distance_functions if they use the same kernel."
             );

         const bool lhs_empty = (alpha.size() == 0);
         const bool rhs_empty = (rhs.alpha.size() == 0);

         if (lhs_empty && rhs_empty)
             return distance_function(kernel_function);
         if (rhs_empty)
             return *this;
         if (lhs_empty)
             return -1*rhs;

         return distance_function(join_cols(alpha, -rhs.alpha),
                                  b + rhs.b - 2*trans(alpha)*kernel_matrix(kernel_function,basis_vectors,rhs.basis_vectors)*rhs.alpha,
                                  kernel_function,
                                  join_cols(basis_vectors, rhs.basis_vectors));
     }

 private:
     // NOTE: declaration order matters.  Members are initialized in this order
     // and one constructor reads alpha while initializing b.
     scalar_vector_type alpha;
     scalar_type b;
     K kernel_function;
     sample_vector_type basis_vectors;
 };
 // Scalar-on-the-left multiplication for distance_function: val*df forwards
 // to the member operator, so it yields the same object as df*val.
 template <typename K>
 distance_function<K> operator* (
     const typename K::scalar_type& val,
     const distance_function<K>& df
 )
 {
     return df.operator*(val);
 }
 template <
 typename K
 >
 void serialize (
 const distance_function<K>& item,
 std::ostream& out
 )
 {
 try
 {
 serialize(item.alpha, out);
 serialize(item.b, out);
 serialize(item.kernel_function, out);
 serialize(item.basis_vectors, out);
 }
 catch (serialization_error& e)
 { 
 throw serialization_error(e.info + "\n while serializing object of type distance_function"); 
 }
 }
 // Reads a distance_function from in.  The stream is expected to hold an
 // alpha vector, a scalar, a kernel and a basis-vector array, in that order.
 //
 // distance_function keeps its data members private and declares no friends,
 // so the four components are read into local objects and item is then
 // replaced by a distance_function built from them with the public
 // four-argument constructor.  As a consequence item is only modified after
 // every component has been read successfully.
 //
 // A serialization_error from a nested deserialize() call is re-thrown with a
 // note naming distance_function appended to its message.
 template <
     typename K
     >
 void deserialize (
     distance_function<K>& item,
     std::istream& in
 )
 {
     try
     {
         typename distance_function<K>::scalar_vector_type alpha;
         typename distance_function<K>::scalar_type b;
         K kernel_function;
         typename distance_function<K>::sample_vector_type basis_vectors;

         deserialize(alpha, in);
         deserialize(b, in);
         deserialize(kernel_function, in);
         deserialize(basis_vectors, in);

         item = distance_function<K>(alpha, b, kernel_function, basis_vectors);
     }
     catch (const serialization_error& e)
     {
         throw serialization_error(e.info + "\n while deserializing object of type distance_function");
     }
 }
// ----------------------------------------------------------------------------------------
 // Pairs a function object with a normalizer.  Calling the pair on a sample x
 // returns function(normalizer(x)).
 //
 // get_labels() and number_of_classes() simply forward to the wrapped
 // function; they only compile when function_type provides members with
 // those names.
 template <
     typename function_type,
     typename normalizer_type = vector_normalizer<typename function_type::sample_type>
     >
 struct normalized_function
 {
     typedef typename function_type::result_type result_type;
     typedef typename function_type::sample_type sample_type;
     typedef typename function_type::mem_manager_type mem_manager_type;

     normalizer_type normalizer;   // applied to every sample before the function sees it
     function_type function;       // the wrapped function

     // Default constructs both the normalizer and the wrapped function.
     normalized_function (
     ){}

     // Forwards to function.get_labels().
     const std::vector<result_type> get_labels(
     ) const { return function.get_labels(); }

     // Forwards to function.number_of_classes().
     unsigned long number_of_classes (
     ) const { return function.number_of_classes(); }

     // Stores copies of the given normalizer and function.  The normalizer is
     // accepted as normalizer_type so that this constructor works for every
     // choice of the second template argument; with the default argument this
     // is the same type as vector_normalizer<sample_type>.
     normalized_function (
         const normalizer_type& normalizer_,
         const function_type& funct
     ) : normalizer(normalizer_), function(funct) {}

     // Normalizes x and returns the wrapped function's result for it.
     result_type operator() (
         const sample_type& x
     ) const { return function(normalizer(x)); }
 };
 template <
 typename function_type,
 typename normalizer_type 
 >
 void serialize (
 const normalized_function<function_type,normalizer_type>& item,
 std::ostream& out
 )
 {
 try
 {
 serialize(item.normalizer, out);
 serialize(item.function, out);
 }
 catch (serialization_error& e)
 { 
 throw serialization_error(e.info + "\n while serializing object of type normalized_function"); 
 }
 }
 // Reads a normalized_function from in: first its normalizer, then the
 // wrapped function.  A serialization_error from a nested deserialize() call
 // is re-thrown with a note naming normalized_function appended to its message.
 template <typename function_type, typename normalizer_type>
 void deserialize (
     normalized_function<function_type,normalizer_type>& item,
     std::istream& in
 )
 {
     try
     {
         deserialize(item.normalizer, in);
         deserialize(item.function, in);
     }
     catch (const serialization_error& err)
     {
         throw serialization_error(err.info + "\n while deserializing object of type normalized_function");
     }
 }
// ----------------------------------------------------------------------------------------
 // Maps a sample to a vector: the sample is compared with every basis vector
 // through the kernel, and the resulting column of kernel values is multiplied
 // by the weights matrix.
 //
 // operator() returns a reference to a mutable member of this object, so the
 // referenced vector is overwritten by the next call on the same object.
 template <typename K>
 struct projection_function
 {
     typedef K kernel_type;
     typedef typename K::scalar_type scalar_type;
     typedef typename K::sample_type sample_type;
     typedef typename K::mem_manager_type mem_manager_type;

     typedef matrix<scalar_type,0,1,mem_manager_type> scalar_vector_type;
     typedef matrix<scalar_type,0,0,mem_manager_type> scalar_matrix_type;
     typedef matrix<sample_type,0,1,mem_manager_type> sample_vector_type;
     typedef scalar_vector_type result_type;

     scalar_matrix_type weights;         // one row per output dimension
     K kernel_function;                  // kernel used against each basis vector
     sample_vector_type basis_vectors;   // samples the input is compared with

     // Leaves every member default constructed.
     projection_function () {}

     // Stores copies of the weights, kernel and basis vectors.
     projection_function (
         const scalar_matrix_type& weights_,
         const K& kernel_function_,
         const sample_vector_type& basis_vectors_
     ) :
         weights(weights_),
         kernel_function(kernel_function_),
         basis_vectors(basis_vectors_)
     {}

     // Number of elements in the vectors produced by operator().
     long out_vector_size () const { return weights.nr(); }

     // Projects x.  The two scratch members are reused between calls so their
     // memory does not have to be reallocated on every invocation.
     const result_type& operator() (const sample_type& x) const
     {
         kernel_values_ = kernel_matrix(kernel_function, basis_vectors, x);
         projected_ = weights*kernel_values_;
         return projected_;
     }

 private:
     // Scratch space for operator(): kernel values first, final result second.
     mutable result_type kernel_values_, projected_;
 };
 template <
 typename K
 >
 void serialize (
 const projection_function<K>& item,
 std::ostream& out
 )
 {
 try
 {
 serialize(item.weights, out);
 serialize(item.kernel_function, out);
 serialize(item.basis_vectors, out);
 }
 catch (serialization_error& e)
 { 
 throw serialization_error(e.info + "\n while serializing object of type projection_function"); 
 }
 }
 // Reads a projection_function from in, filling weights, kernel_function and
 // basis_vectors of item in that order.  A serialization_error from a nested
 // deserialize() call is re-thrown with a note naming projection_function
 // appended to its message.
 template <typename K>
 void deserialize (
     projection_function<K>& item,
     std::istream& in
 )
 {
     try
     {
         deserialize(item.weights, in);
         deserialize(item.kernel_function, in);
         deserialize(item.basis_vectors, in);
     }
     catch (const serialization_error& err)
     {
         throw serialization_error(err.info + "\n while deserializing object of type projection_function");
     }
 }
// ----------------------------------------------------------------------------------------
 // A linear multiclass classifier.  Class i is scored as
 //     dot(rowm(weights,i), x) - b(i)
 // and the label stored at the index of the highest score is the prediction.
 // On ties the lowest index wins because a later class must score strictly
 // higher to replace the current best.
 template <
     typename K,
     typename result_type_ = typename K::scalar_type
     >
 struct multiclass_linear_decision_function
 {
     typedef result_type_ result_type;

     typedef K kernel_type;
     typedef typename K::scalar_type scalar_type;
     typedef typename K::sample_type sample_type;
     typedef typename K::mem_manager_type mem_manager_type;

     typedef matrix<scalar_type,0,1,mem_manager_type> scalar_vector_type;
     typedef matrix<scalar_type,0,0,mem_manager_type> scalar_matrix_type;

     // You are getting a compiler error on this line because you supplied a non-linear kernel
     // to the multiclass_linear_decision_function object.  You have to use one of the linear
     // kernels with this object.
     COMPILE_TIME_ASSERT((is_same_type<K, linear_kernel<sample_type> >::value ||
                          is_same_type<K, sparse_linear_kernel<sample_type> >::value ));

     scalar_matrix_type weights;       // row i holds the weight vector of class i
     scalar_vector_type b;             // b(i) is subtracted from the score of class i
     std::vector<result_type> labels;  // labels[i] is the label reported for class i

     // The label reported for each class, in class-index order.
     const std::vector<result_type>& get_labels () const { return labels; }

     // Number of entries in labels.
     unsigned long number_of_classes () const { return labels.size(); }

     // Scores every class for x and returns the label of the best scoring
     // class together with that score.
     std::pair<result_type, scalar_type> predict (const sample_type& x) const
     {
         // make sure requires clause is not broken
         DLIB_ASSERT(weights.size() > 0 &&
                     weights.nr() == static_cast<long>(number_of_classes()) &&
                     weights.nr() == b.size(),
             "\t pair<result_type,scalar_type> multiclass_linear_decision_function::predict(x)"
             << "\n\t This object must be properly initialized before you can use it."
             << "\n\t weights.size(): " << weights.size()
             << "\n\t weights.nr(): " << weights.nr()
             << "\n\t number_of_classes(): " << number_of_classes()
             );

         // Each row is dotted with x individually, instead of evaluating
         // something like index_of_max(weights*x-b), because this form works
         // for both sparse and dense samples.
         unsigned long winner = 0;
         scalar_type winning_score = dot(rowm(weights,0),x) - b(0);

         for (unsigned long cls = 1; cls < labels.size(); ++cls)
         {
             const scalar_type score = dot(rowm(weights,cls),x) - b(cls);
             if (score > winning_score)
             {
                 winning_score = score;
                 winner = cls;
             }
         }

         return std::make_pair(labels[winner], winning_score);
     }

     // Returns only the label part of predict(x).
     result_type operator() (const sample_type& x) const
     {
         // make sure requires clause is not broken
         DLIB_ASSERT(weights.size() > 0 &&
                     weights.nr() == static_cast<long>(number_of_classes()) &&
                     weights.nr() == b.size(),
             "\t result_type multiclass_linear_decision_function::operator()(x)"
             << "\n\t This object must be properly initialized before you can use it."
             << "\n\t weights.size(): " << weights.size()
             << "\n\t weights.nr(): " << weights.nr()
             << "\n\t number_of_classes(): " << number_of_classes()
             );

         return predict(x).first;
     }
 };
 template <
 typename K,
 typename result_type_
 >
 void serialize (
 const multiclass_linear_decision_function<K,result_type_>& item,
 std::ostream& out
 )
 {
 try
 {
 serialize(item.weights, out);
 serialize(item.b, out);
 serialize(item.labels, out);
 }
 catch (serialization_error& e)
 { 
 throw serialization_error(e.info + "\n while serializing object of type multiclass_linear_decision_function"); 
 }
 }
 // Reads a multiclass_linear_decision_function from in, filling weights, b
 // and labels of item in that order.  A serialization_error from a nested
 // deserialize() call is re-thrown with a note naming
 // multiclass_linear_decision_function appended to its message.
 template <typename K, typename result_type_>
 void deserialize (
     multiclass_linear_decision_function<K,result_type_>& item,
     std::istream& in
 )
 {
     try
     {
         deserialize(item.weights, in);
         deserialize(item.b, in);
         deserialize(item.labels, in);
     }
     catch (const serialization_error& err)
     {
         throw serialization_error(err.info + "\n while deserializing object of type multiclass_linear_decision_function");
     }
 }
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_SVm_FUNCTION

// Web-proxy footer, not part of function.h: "Page converted by AltStyle (-> original)"