// dlib C++ Library - find_max_factor_graph_viterbi.cpp

// Copyright (C) 2011 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#include <sstream>
#include <string>
#include <cstdlib>
#include <ctime>
#include <dlib/optimization.h>
#include <dlib/rand.h>
#include "tester.h"
namespace 
{
 using namespace test;
 using namespace dlib;
 using namespace std;
 // File-scope logger for this test; do_test_() writes one LINFO line per
 // tested configuration through it.
 logger dlog("test.find_max_factor_graph_viterbi");
// ----------------------------------------------------------------------------------------
 // Single random number generator shared by the whole file.  map_problem's
 // constructor passes it to randm() to fill its factor table, so every problem
 // instance draws from this one generator.
 dlib::rand rnd;
// ----------------------------------------------------------------------------------------
 // A randomly generated MAP problem used to exercise
 // find_max_factor_graph_viterbi().
 //
 // Template parameters:
 //   O            - value returned by order()
 //   NS           - value returned by num_states()
 //   num_nodes    - value returned by number_of_nodes()
 //   all_negative - if true, the random factor table is negated after it is
 //                  generated
 //
 // The factor table `data` has one row per node and NS^(O+1) columns, i.e. one
 // column for every joint assignment of up to O+1 states.
 template <
     unsigned long O,
     unsigned long NS,
     unsigned long num_nodes,
     bool all_negative
     >
 class map_problem
 {
 public:
     unsigned long order() const { return O; }
     unsigned long num_states() const { return NS; }

     // Fills `data` with values drawn from the file-scope generator `rnd`.
     map_problem()
     {
         // NS^(O+1), computed with integer arithmetic.  Going through a
         // floating point pow() and truncating to long could lose a column if
         // the result came back slightly below the exact integer.
         long num_columns = 1;
         for (unsigned long i = 0; i <= O; ++i)
             num_columns *= static_cast<long>(NS);

         data = randm(number_of_nodes(), num_columns, rnd);
         if (all_negative)
             data = -data;
     }

     unsigned long number_of_nodes (
     ) const
     {
         return num_nodes;
     }

     // Returns the table entry for node_id under the given states.
     //
     // node_states must contain at least one element.  Every element takes
     // part in selecting the column, so a problem of order O really depends on
     // all O+1 states it is handed (previously only the first four states were
     // used, which made the order-4 tests ignore the fifth state).
     //
     // The column layout for 1 to 4 states is the same as it has always been:
     //   1 state : s0
     //   2 states: s0 + s1*NS
     //   k states: (column for the first k-1 states)*NS + s(k-1)
     template <
         typename EXP
         >
     double factor_value (
         unsigned long node_id,
         const matrix_exp<EXP>& node_states
     ) const
     {
         const long num_given = node_states.size();

         unsigned long column = node_states(0);
         if (num_given >= 2)
             column += node_states(1)*NS;
         for (long k = 2; k < num_given; ++k)
             column = column*NS + node_states(k);

         return data(node_id, column);
     }

     // data(n, c) is the factor value of node n for the state combination
     // encoded by column c (see factor_value()).
     matrix<double> data;
 };
// ----------------------------------------------------------------------------------------
 template <
 typename map_problem
 >
 void brute_force_find_max_factor_graph_viterbi (
 const map_problem& prob,
 std::vector<unsigned long>& map_assignment
 )
 {
 using namespace dlib::impl;
 const int order = prob.order();
 const int num_states = prob.num_states();
 map_assignment.resize(prob.number_of_nodes());
 double best_score = -std::numeric_limits<double>::infinity();
 matrix<unsigned long,1,0> node_states;
 node_states.set_size(prob.number_of_nodes());
 node_states = 0;
 do
 {
 double score = 0;
 for (unsigned long i = 0; i < prob.number_of_nodes(); ++i)
 {
 score += prob.factor_value(i, (colm(node_states,range(i,i-std::min<int>(order,i)))));
 }
 if (score > best_score)
 {
 for (unsigned long i = 0; i < map_assignment.size(); ++i)
 map_assignment[i] = node_states(i);
 best_score = score;
 }
 } while(advance_state(node_states,num_states));
 }
// ----------------------------------------------------------------------------------------
 template <
 unsigned long order,
 unsigned long num_states,
 unsigned long num_nodes,
 bool all_negative
 >
 void do_test_()
 {
 dlog << LINFO << "order: "<< order 
 << " num_states: " << num_states
 << " num_nodes: " << num_nodes
 << " all_negative: " << all_negative;
 for (int i = 0; i < 25; ++i)
 {
 print_spinner();
 map_problem<order,num_states,num_nodes,all_negative> prob;
 std::vector<unsigned long> assign, assign2;
 brute_force_find_max_factor_graph_viterbi(prob, assign);
 find_max_factor_graph_viterbi(prob, assign2);
 DLIB_TEST_MSG(mat(assign) == mat(assign2),
 trans(mat(assign))
 << trans(mat(assign2))
 );
 }
 }
 // Runs do_test_() for the given configuration with all_negative == false.
 template <
     unsigned long Order,
     unsigned long NumStates,
     unsigned long NumNodes
     >
 void do_test()
 {
     const bool negate_factors = false;
     do_test_<Order,NumStates,NumNodes,negate_factors>();
 }
 // Runs do_test_() for the given configuration with all_negative == true.
 template <
     unsigned long Order,
     unsigned long NumStates,
     unsigned long NumNodes
     >
 void do_test_negative()
 {
     const bool negate_factors = true;
     do_test_<Order,NumStates,NumNodes,negate_factors>();
 }
// ----------------------------------------------------------------------------------------
 // Test case that compares find_max_factor_graph_viterbi() against the brute
 // force reference solver over a range of problem configurations.
 class test_find_max_factor_graph_viterbi : public tester
 {
 public:
 // Passes this test's name and description to the tester base class.
 test_find_max_factor_graph_viterbi (
 ) :
 tester ("test_find_max_factor_graph_viterbi",
 "Runs tests on the find_max_factor_graph_viterbi routine.")
 {}
 // Runs every configuration below.  The template arguments are
 // <order, num_states, num_nodes>; do_test uses the random factor values as
 // generated and do_test_negative negates them.  All cases draw from the same
 // file-scope generator, so the order of the calls determines which random
 // values each case sees.
 void perform_test (
 )
 {
 // Small and empty chains (num_nodes of 0, 1 and 2).
 do_test<1,3,0>();
 do_test<1,3,1>();
 do_test<1,3,2>();
 do_test<0,3,2>();
 do_test_negative<0,3,2>();
 // Three states per node, orders 0 through 4.
 do_test<1,3,8>();
 do_test<2,3,7>();
 do_test_negative<2,3,7>();
 do_test<3,3,8>();
 do_test<4,3,8>();
 do_test_negative<4,3,8>();
 do_test<0,3,8>();
 // Chains with at most 2 nodes, i.e. no longer than the order.
 do_test<4,3,1>();
 do_test<4,3,0>();
 do_test<3,2,1>();
 do_test<3,2,0>();
 do_test<3,2,2>();
 do_test<2,2,1>();
 do_test_negative<3,2,1>();
 do_test_negative<3,2,0>();
 do_test_negative<3,2,2>();
 do_test_negative<2,2,1>();
 do_test<0,3,0>();
 // Two states per node.
 do_test<1,2,8>();
 do_test<2,2,7>();
 do_test<3,2,8>();
 do_test<0,2,8>();
 // One state per node: only a single assignment exists.
 do_test<1,1,8>();
 do_test<2,1,8>();
 do_test<3,1,8>();
 do_test<0,1,8>();
 }
 // File-scope instance of the test.  NOTE(review): presumably constructing it
 // is what registers the test with the tester framework -- confirm in tester.h.
 } a;
}

// Page converted by AltStyle (-> original)