// dlib C++ Library - bsp.cpp

// Copyright (C) 2012 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#include <dlib/bsp.h>
#include <dlib/threads.h>
#include <dlib/pipe.h>
#include <dlib/matrix.h>
#include "tester.h"
namespace 
{
 // Bring the test, dlib, and std namespaces into scope for the rest of this
 // anonymous namespace.
 using namespace test;
 using namespace dlib;
 using namespace std;
 // Logger shared by every test in this file; it is created with the name "test.bsp".
 logger dlog("test.bsp");
 template <typename funct>
 struct callfunct_helper
 {
 callfunct_helper (
 funct f_,
 int port_,
 bool& error_occurred_
 ) :f(f_), port(port_), error_occurred(error_occurred_) {}
 funct f;
 int port;
 bool& error_occurred;
 void operator() (
 ) const
 {
 try
 {
 bsp_listen(port, f);
 }
 catch (exception& e)
 {
 dlog << LERROR << "error calling bsp_listen(): " << e.what();
 error_occurred = true;
 }
 }
 };
 // Factory for callfunct_helper: lets the compiler deduce funct from the
 // argument so callers do not have to spell out the template parameter.
 template <typename funct>
 callfunct_helper<funct> callfunct(funct f, int port, bool& error_occurred)
 {
     callfunct_helper<funct> listener(f, port, error_occurred);
     return listener;
 }
// ----------------------------------------------------------------------------------------
 template <typename funct>
 struct callfunct_helper_pn
 {
 callfunct_helper_pn (
 funct f_,
 int port_,
 bool& error_occurred_,
 dlib::pipe<unsigned short>& port_pipe_
 ) :f(f_), port(port_), error_occurred(error_occurred_), port_pipe(port_pipe_) {}
 funct f;
 int port;
 bool& error_occurred;
 dlib::pipe<unsigned short>& port_pipe;
 struct helper
 {
 helper (
 dlib::pipe<unsigned short>& port_pipe_
 ) : port_pipe(port_pipe_) {}
 dlib::pipe<unsigned short>& port_pipe;
 void operator() (unsigned short p) { port_pipe.enqueue(p); }
 };
 void operator() (
 ) const
 {
 try
 {
 bsp_listen_dynamic_port(port, helper(port_pipe), f);
 }
 catch (exception& e)
 {
 dlog << LERROR << "error calling bsp_listen_dynamic_port(): " << e.what();
 error_occurred = true;
 }
 }
 };
 // Factory for callfunct_helper_pn: same idea as the three-argument
 // callfunct() overload, but the resulting functor also reports port numbers
 // through port_pipe.
 template <typename funct>
 callfunct_helper_pn<funct> callfunct(funct f, int port, bool& error_occurred, dlib::pipe<unsigned short>& port_pipe)
 {
     callfunct_helper_pn<funct> listener(f, port, error_occurred, port_pipe);
     return listener;
 }
// ----------------------------------------------------------------------------------------
 // Driving side of the array-sum job.  Broadcasts v, resets result to 0 and
 // then adds every int obtained from try_receive() to result until
 // try_receive() returns false.
 void sum_array_driver (
     bsp_context& obj,
     const std::vector<int>& v,
     int& result
 )
 {
     obj.broadcast(v);
     result = 0;
     for (int partial; obj.try_receive(partial); )
         result += partial;
 }
 // Listening side of the array-sum job.  Receives one vector of ints, adds up
 // its elements and sends the total to node 0.
 void sum_array_other (
     bsp_context& obj
 )
 {
     std::vector<int> values;
     obj.receive(values);

     int total = 0;
     for (std::vector<int>::const_iterator it = values.begin(); it != values.end(); ++it)
         total += *it;

     obj.send(total, 0);
 }
 void dotest1()
 {
 dlog << LINFO << "start dotest1()";
 print_spinner();
 bool error_occurred = false;
 {
 thread_function t1(callfunct(sum_array_other, 12345, error_occurred));
 thread_function t2(callfunct(sum_array_other, 12346, error_occurred));
 thread_function t3(callfunct(sum_array_other, 12347, error_occurred));
 std::vector<int> v;
 int true_value = 0;
 for (int i = 0; i < 10; ++i)
 {
 v.push_back(i);
 true_value += i;
 }
 // wait a little bit for the threads to start up
 dlib::sleep(200);
 try
 {
 int result;
 std::vector<network_address> hosts;
 hosts.push_back("127.0.0.1:12345");
 hosts.push_back("localhost:12346");
 hosts.push_back("127.0.0.1:12347");
 bsp_connect(hosts, sum_array_driver, dlib::ref(v), dlib::ref(result));
 dlog << LINFO << "result: "<< result;
 dlog << LINFO << "should be: "<< 3*true_value; 
 DLIB_TEST(result == 3*true_value);
 }
 catch (std::exception& e)
 {
 dlog << LERROR << "error during bsp_context: " << e.what();
 DLIB_TEST(false);
 }
 }
 DLIB_TEST(error_occurred == false);
 }
// ----------------------------------------------------------------------------------------
 // Job that does nothing except on the node whose id equals the template
 // argument; that node sleeps for 100 via dlib::sleep() before returning.
 template <unsigned long id>
 void test2_job(bsp_context& obj)
 {
     const bool is_selected_node = (obj.node_id() == id);
     if (is_selected_node)
         dlib::sleep(100);
 }
 template <unsigned long id>
 void dotest2()
 {
 dlog << LINFO << "start dotest2()";
 print_spinner();
 bool error_occurred = false;
 {
 thread_function t1(callfunct(test2_job<id>, 12345, error_occurred));
 thread_function t2(callfunct(test2_job<id>, 12346, error_occurred));
 thread_function t3(callfunct(test2_job<id>, 12347, error_occurred));
 // wait a little bit for the threads to start up
 dlib::sleep(200);
 try
 {
 std::vector<network_address> hosts;
 hosts.push_back("127.0.0.1:12345");
 hosts.push_back("127.0.0.1:12346");
 hosts.push_back("127.0.0.1:12347");
 bsp_connect(hosts, test2_job<id>);
 }
 catch (std::exception& e)
 {
 dlog << LERROR << "error during bsp_context: " << e.what();
 DLIB_TEST(false);
 }
 }
 DLIB_TEST(error_occurred == false);
 }
// ----------------------------------------------------------------------------------------
 // Job run by every node in dotest3().  Each node broadcasts its id, sums
 // what it receives, forwards that sum to node 1 (which adds the incoming
 // sums to its own), and then the values are passed around the ring of nodes.
 // Only node 0 writes to result.
 void test3_job_driver(bsp_context& obj, int& result)
 {
     obj.broadcast(obj.node_id());

     int accum = 0;
     int incoming = 0;
     while (obj.try_receive(incoming))
         accum += incoming;

     // Every node except node 1 sends its sum to node 1 so it can total everything.
     if (obj.node_id() != 1)
         obj.send(accum, 1);

     while (obj.try_receive(incoming))
         accum += incoming;

     // Hop the accum values along the nodes, three times, so the value from
     // node 1 gets to node 0.
     for (int hop = 0; hop < 3; ++hop)
     {
         obj.send(accum, (obj.node_id()+1)%obj.number_of_nodes());
         obj.receive(accum);
     }

     // Extra laps around the ring.  Per the original note this block is a
     // no-op as far as the final value is concerned (four hops per lap;
     // assumes four nodes, as the expected value in dotest3() implies).
     for (int lap = 0; lap < 100; ++lap)
     {
         dlog << LINFO << "k: " << lap;
         for (int hop = 0; hop < 4; ++hop)
         {
             obj.send(accum, (obj.node_id()+1)%obj.number_of_nodes());
             obj.receive(accum);
         }
     }

     dlog << LINFO << "TERMINATE";

     if (obj.node_id() == 0)
         result = accum;
 }
 // Listener-side entry point for dotest3(): runs test3_job_driver() and
 // throws away whatever it writes to its result argument.
 void test3_job(bsp_context& obj)
 {
     int discarded;
     test3_job_driver(obj, discarded);
 }
 void dotest3()
 {
 dlog << LINFO << "start dotest3()";
 print_spinner();
 bool error_occurred = false;
 {
 dlib::pipe<unsigned short> ports(5);
 thread_function t1(callfunct(test3_job, 12345, error_occurred, ports));
 thread_function t2(callfunct(test3_job, 0, error_occurred, ports));
 thread_function t3(callfunct(test3_job, 12347, error_occurred, ports));
 try
 {
 std::vector<network_address> hosts;
 unsigned short port;
 ports.dequeue(port); hosts.push_back(network_address("127.0.0.1",port)); dlog << LINFO << "PORT: " << port;
 ports.dequeue(port); hosts.push_back(network_address("127.0.0.1",port)); dlog << LINFO << "PORT: " << port;
 ports.dequeue(port); hosts.push_back(network_address("127.0.0.1",port)); dlog << LINFO << "PORT: " << port;
 int result = 0;
 const int expected = 1+2+3 + 0+2+3 + 0+1+3 + 0+1+2;
 bsp_connect(hosts, test3_job_driver, dlib::ref(result));
 dlog << LINFO << "result: " << result;
 dlog << LINFO << "should be: " << expected;
 DLIB_TEST(result == expected);
 }
 catch (std::exception& e)
 {
 dlog << LERROR << "error during bsp_context: " << e.what();
 DLIB_TEST(false);
 }
 }
 DLIB_TEST(error_occurred == false);
 }
// ----------------------------------------------------------------------------------------
 // Job run by every node in dotest4().  Same exchange as test3_job_driver()
 // (broadcast ids, total on node 1, pass values around the ring), except the
 // extra laps run 40 times and each hop is followed by an argument-less
 // obj.receive().  Only node 0 writes to result.
 void test4_job_driver(bsp_context& obj, int& result)
 {
     obj.broadcast(obj.node_id());

     int accum = 0;
     int incoming = 0;
     while (obj.try_receive(incoming))
         accum += incoming;

     // Every node except node 1 sends its sum to node 1 so it can total everything.
     if (obj.node_id() != 1)
         obj.send(accum, 1);

     while (obj.try_receive(incoming))
         accum += incoming;

     // Hop the accum values along the nodes, three times, so the value from
     // node 1 gets to node 0.
     for (int hop = 0; hop < 3; ++hop)
     {
         obj.send(accum, (obj.node_id()+1)%obj.number_of_nodes());
         obj.receive(accum);
     }

     // Extra laps around the ring.  Per the original note this block is a
     // no-op as far as the final value is concerned.
     for (int lap = 0; lap < 40; ++lap)
     {
         dlog << LINFO << "k: " << lap;
         for (int hop = 0; hop < 4; ++hop)
         {
             obj.send(accum, (obj.node_id()+1)%obj.number_of_nodes());
             obj.receive(accum);
             // NOTE(review): receive() with no arguments presumably waits
             // without expecting a message -- confirm against the
             // bsp_context documentation.
             obj.receive();
         }
     }

     dlog << LINFO << "TERMINATE";

     if (obj.node_id() == 0)
         result = accum;
 }
 // Listener-side entry point for dotest4(): runs test4_job_driver() and
 // throws away whatever it writes to its result argument.
 void test4_job(bsp_context& obj)
 {
     int discarded;
     test4_job_driver(obj, discarded);
 }
 void dotest4()
 {
 dlog << LINFO << "start dotest4()";
 print_spinner();
 bool error_occurred = false;
 {
 dlib::pipe<unsigned short> ports(5);
 thread_function t1(callfunct(test4_job, 0, error_occurred, ports));
 thread_function t2(callfunct(test4_job, 0, error_occurred, ports));
 thread_function t3(callfunct(test4_job, 0, error_occurred, ports));
 try
 {
 std::vector<network_address> hosts;
 unsigned short port;
 ports.dequeue(port); hosts.push_back(network_address("127.0.0.1",port)); dlog << LINFO << "PORT: " << port;
 ports.dequeue(port); hosts.push_back(network_address("127.0.0.1",port)); dlog << LINFO << "PORT: " << port;
 ports.dequeue(port); hosts.push_back(network_address("127.0.0.1",port)); dlog << LINFO << "PORT: " << port;
 int result = 0;
 const int expected = 1+2+3 + 0+2+3 + 0+1+3 + 0+1+2;
 bsp_connect(hosts, test4_job_driver, dlib::ref(result));
 dlog << LINFO << "result: " << result;
 dlog << LINFO << "should be: " << expected;
 DLIB_TEST(result == expected);
 }
 catch (std::exception& e)
 {
 dlog << LERROR << "error during bsp_context: " << e.what();
 DLIB_TEST(false);
 }
 }
 DLIB_TEST(error_occurred == false);
 }
// ----------------------------------------------------------------------------------------
 // Trivial job for dotest5(): ignores the context and stores 25 in val so the
 // caller can tell that the job ran.
 void test5_job(
     bsp_context& ,
     int& val
 )
 {
     const int marker = 25;
     val = marker;
 }
 // Calls bsp_connect() with an empty host list and checks that test5_job
 // still ran, i.e. that it wrote 25 into val.
 void dotest5()
 {
     dlog << LINFO << "start dotest5()";
     print_spinner();

     int val = 0;
     std::vector<network_address> no_hosts;
     bsp_connect(no_hosts, test5_job, dlib::ref(val));

     DLIB_TEST(val == 25);
 }
// ----------------------------------------------------------------------------------------
 // Objective minimised by the grid-search test: (x - 2)^2, which is 0 at
 // x == 2.
 double f ( double x)
 {
     const double offset = x - 2.0;
     return std::pow(offset, 2.0);
 }
 // Node 0's half of a BSP grid search for the minimum of f().  On each of 100
 // rounds it broadcasts the current search interval, collects one
 // (x, f(x)) pair from every other node, keeps the best pair seen so far in
 // optimal_x / min_value, and then halves the interval width and re-centres
 // the interval on optimal_x.
 void bsp_job_node_0 (
     bsp_context& context,
     double& min_value,
     double& optimal_x
 )
 {
     double lo = -100;
     double hi = 100;
     min_value = std::numeric_limits<double>::infinity();
     double interval_width = std::abs(hi-lo);

     for (int round = 0; round < 100; ++round)
     {
         context.broadcast(lo);
         context.broadcast(hi);

         // One reply is read per non-zero node id.
         for (unsigned int node = 1; node < context.number_of_nodes(); ++node)
         {
             std::pair<double,double> candidate;
             context.receive(candidate);
             if (candidate.second < min_value)
             {
                 min_value = candidate.second;
                 optimal_x = candidate.first;
             }
         }

         // Shrink the grid around the best x found so far.
         interval_width *= 0.5;
         lo = optimal_x - interval_width/2;
         hi = optimal_x + interval_width/2;
     }
 }
 // Worker half of the BSP grid search.  For as long as try_receive() yields
 // a left bound, it receives the matching right bound, evaluates f() at 100
 // evenly spaced points in this node's share of [left, right] and sends the
 // best (x, f(x)) pair to node 0.
 void bsp_job_other_nodes (
     bsp_context& context
 )
 {
     double left, right;
     while (context.try_receive(left))
     {
         context.receive(right);

         // Fractions of the interval that bound this node's slice; the
         // slices are indexed by node_id()-1 across number_of_nodes()-1
         // workers.
         const double slice_begin = (context.node_id()-1)/(context.number_of_nodes()-1.0);
         const double slice_end = context.node_id() /(context.number_of_nodes()-1.0);
         const double width = right-left;

         matrix<double> samples = linspace(left +slice_begin*width, left + slice_end*width, 100);

         double best_x = 0;
         double best_val = std::numeric_limits<double>::infinity();
         for (long idx = 0; idx < samples.size(); ++idx)
         {
             const double value = f(samples(idx));
             if (value < best_val)
             {
                 best_val = value;
                 best_x = samples(idx);
             }
         }

         context.send(make_pair(best_x, best_val), 0);
     }
 }
 void dotest6()
 {
 dlog << LINFO << "start dotest6()";
 print_spinner();
 bool error_occurred = false;
 {
 dlib::pipe<unsigned short> ports(5);
 thread_function t1(callfunct(bsp_job_other_nodes, 0, error_occurred, ports));
 thread_function t2(callfunct(bsp_job_other_nodes, 0, error_occurred, ports));
 thread_function t3(callfunct(bsp_job_other_nodes, 0, error_occurred, ports));
 try
 {
 std::vector<network_address> hosts;
 unsigned short port;
 ports.dequeue(port); hosts.push_back(network_address("127.0.0.1",port)); dlog << LINFO << "PORT: " << port;
 ports.dequeue(port); hosts.push_back(network_address("127.0.0.1",port)); dlog << LINFO << "PORT: " << port;
 ports.dequeue(port); hosts.push_back(network_address("127.0.0.1",port)); dlog << LINFO << "PORT: " << port;
 double min_value = 10, optimal_x = 0;
 bsp_connect(hosts, bsp_job_node_0, dlib::ref(min_value), dlib::ref(optimal_x));
 dlog << LINFO << "min_value: " << min_value;
 dlog << LINFO << "optimal_x: " << optimal_x;
 DLIB_TEST(std::abs(min_value - 0) < 1e-14);
 DLIB_TEST(std::abs(optimal_x - 2) < 1e-14);
 }
 catch (std::exception& e)
 {
 dlog << LERROR << "error during bsp_context: " << e.what();
 DLIB_TEST(false);
 }
 }
 DLIB_TEST(error_occurred == false);
 }
// ----------------------------------------------------------------------------------------
 // Test-suite entry for the BSP tests, constructed with the name "test_bsp".
 // NOTE(review): the single instance `a` defined together with the class is
 // presumably what registers the suite with the harness -- confirm in tester.h.
 class bsp_tester : public tester
 {
 public:
     bsp_tester () :
         tester ("test_bsp",
                 "Runs tests on the BSP components.")
     {}

     // Runs the complete set of BSP tests three times back to back.
     void perform_test ()
     {
         for (int round = 0; round < 3; ++round)
         {
             dotest1();

             // One run per choice of which node delays.
             dotest2<0>();
             dotest2<1>();
             dotest2<2>();

             dotest3();
             dotest4();
             dotest5();
             dotest6();
         }
     }
 } a;
}

// Page converted by AltStyle (-> original)