// dlib C++ Library - bsp.cpp

// Copyright (C) 2012 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_BSP_CPph_
#define DLIB_BSP_CPph_
#include "bsp.h"
#include <memory>
#include <stack>
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
namespace dlib
{
 namespace impl1
 {
 // Opens one connection per entry of hosts and stores it in cons under the id
 // (position in hosts + 1), so the stored ids run 1..hosts.size().  Over each new
 // connection this node's id (node_id) is serialized so the other end learns who
 // is calling; the stream is not flushed here.  Previous contents of cons are
 // discarded first.  Connection failures propagate as exceptions.
 void connect_all (
     map_id_to_con& cons,
     const std::vector<network_address>& hosts,
     unsigned long node_id
 )
 {
     cons.clear();
     unsigned long assigned = 0;
     for (const network_address& addr : hosts)
     {
         std::unique_ptr<bsp_con> link(new bsp_con(addr));
         // tell the other end our node_id
         dlib::serialize(node_id, link->stream);
         // add() takes its arguments by non-const reference (presumably swapping them
         // in, as dlib containers do), so hand it a throwaway copy of the id rather
         // than the running counter.
         unsigned long key = ++assigned;
         cons.add(key, link);
     }
 }
 // Connects to every host in hosts.  Over each new connection this node's id
 // (node_id) is serialized and flushed immediately, and the connection is stored in
 // cons under the node id carried by the corresponding hostinfo entry.  cons is
 // cleared first.
 //
 // Failures are reported, not thrown: the first exception stops the loop and a
 // description of the failing address, including the caught exception's message,
 // is written to error_string.  Connections made before the failure stay in cons.
 // error_string is not modified when every connection succeeds.
 void connect_all_hostinfo (
     map_id_to_con& cons,
     const std::vector<hostinfo>& hosts,
     unsigned long node_id,
     std::string& error_string 
 )
 {
     cons.clear();
     for (unsigned long i = 0; i < hosts.size(); ++i)
     {
         try
         {
             std::unique_ptr<bsp_con> con(new bsp_con(hosts[i].addr));
             dlib::serialize(node_id, con->stream); // tell the other end our node_id
             con->stream.flush();
             unsigned long id = hosts[i].node_id;
             cons.add(id, con);
         }
         catch (const std::exception& e)
         {
             // Keep the underlying reason: previously the exception was discarded, so
             // the caller could only learn which host failed, not why.
             std::ostringstream sout;
             sout << "Could not connect to " << hosts[i].addr << " (" << e.what() << ")";
             error_string = sout.str();
             break;
         }
     }
 }
 // First sends every peer in cons its own node id (the key it is stored under).
 // Then, for k = 1..hosts.size(), sends node k the hostinfo list of nodes 1..k-1
 // (the nodes it must connect to), followed by the number of incoming connections
 // it should expect (hosts.size() - k), and flushes that node's stream.
 //
 // Assumes cons has an entry for every id 1..hosts.size(), matching the ids that
 // connect_all() assigns.
 void send_out_connection_orders (
     map_id_to_con& cons,
     const std::vector<network_address>& hosts
 )
 {
     // tell everyone their node ids
     for (cons.reset(); cons.move_next(); )
         dlib::serialize(cons.element().key(), cons.element().value()->stream);

     // now tell them who to connect to: node k is told about every node with a
     // smaller id
     std::vector<hostinfo> earlier_nodes;
     for (unsigned long idx = 0; idx < hosts.size(); ++idx)
     {
         hostinfo self_info(hosts[idx], idx + 1);
         auto& out = cons[self_info.node_id]->stream;

         dlib::serialize(earlier_nodes, out);
         earlier_nodes.push_back(self_info);

         // let the other host know how many incoming connections to expect
         const unsigned long incoming = hosts.size() - earlier_nodes.size();
         dlib::serialize(incoming, out);
         out.flush();
     }
 }
 // ------------------------------------------------------------------------------------
 }
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
 namespace impl2
 {
 // Control bytes.  One of these is sent in front of every message exchanged between
 // nodes.  Several of them only ever travel between the controller (node 0) and the
 // other nodes, because the controller is responsible for the synchronization that
 // must happen when all nodes are blocked in receive_data() at the same time.

 // A normal content message.
 constexpr char MESSAGE_HEADER = 0;
 // Sent to the controller when a node receives a message via receive_data().
 constexpr char GOT_MESSAGE = 1;
 // Sent to the controller when a node sends a message via send().
 constexpr char SENT_MESSAGE = 2;
 // Sent to the controller when a node enters a call to receive_data().
 constexpr char IN_WAITING_STATE = 3;
 // (no control byte uses the value 4)
 // Broadcast when a node terminates itself.
 constexpr char NODE_TERMINATE = 5;
 // Broadcast by the controller once it determines that all nodes are blocked in
 // receive_data() and no messages are in flight.  This is also what advances
 // everyone to the next epoch.
 constexpr char SEE_ALL_IN_WAITING_STATE = 6;
 // Never transmitted between nodes; used internally to signal that an error
 // occurred while reading from a connection.
 constexpr char READ_ERROR = 7;
 // ------------------------------------------------------------------------------------
 void read_thread (
 impl1::bsp_con* con,
 unsigned long node_id,
 unsigned long sender_id,
 impl1::thread_safe_message_queue& msg_buffer
 )
 {
 try
 {
 while(true)
 {
 impl1::msg_data msg;
 deserialize(msg.msg_type, con->stream);
 msg.sender_id = sender_id;
 if (msg.msg_type == MESSAGE_HEADER)
 {
 msg.data.reset(new std::vector<char>);
 deserialize(msg.epoch, con->stream);
 deserialize(*msg.data, con->stream);
 }
 msg_buffer.push_and_consume(msg);
 if (msg.msg_type == NODE_TERMINATE)
 break;
 }
 }
 catch (std::exception& e)
 {
 impl1::msg_data msg;
 msg.data.reset(new std::vector<char>);
 vectorstream sout(*msg.data);
 sout << "An exception was thrown while attempting to receive a message from processing node " << sender_id << ".\n";
 sout << " Sending processing node address: " << con->con->get_foreign_ip() << ":" << con->con->get_foreign_port() << std::endl;
 sout << " Receiving processing node address: " << con->con->get_local_ip() << ":" << con->con->get_local_port() << std::endl;
 sout << " Receiving processing node id: " << node_id << std::endl;
 sout << " Error message in the exception: " << e.what() << std::endl;
 msg.sender_id = sender_id;
 msg.msg_type = READ_ERROR;
 msg_buffer.push_and_consume(msg);
 }
 catch (...)
 {
 impl1::msg_data msg;
 msg.data.reset(new std::vector<char>);
 vectorstream sout(*msg.data);
 sout << "An exception was thrown while attempting to receive a message from processing node " << sender_id << ".\n";
 sout << " Sending processing node address: " << con->con->get_foreign_ip() << ":" << con->con->get_foreign_port() << std::endl;
 sout << " Receiving processing node address: " << con->con->get_local_ip() << ":" << con->con->get_local_port() << std::endl;
 sout << " Receiving processing node id: " << node_id << std::endl;
 msg.sender_id = sender_id;
 msg.msg_type = READ_ERROR;
 msg_buffer.push_and_consume(msg);
 }
 }
 // ------------------------------------------------------------------------------------
 }
 
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// IMPLEMENTATION OF bsp_context OBJECT MEMBERS
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
 void bsp_context::
 close_all_connections_gracefully(
 )
 {
 if (node_id() != 0)
 {
 _cons.reset();
 while (_cons.move_next())
 {
 // tell the other end that we are intentionally dropping the connection
 serialize(impl2::NODE_TERMINATE,_cons.element().value()->stream);
 _cons.element().value()->stream.flush();
 }
 }
 impl1::msg_data msg;
 // now wait for all the other nodes to terminate
 while (num_terminated_nodes < _cons.size() )
 {
 if (node_id() == 0 && num_waiting_nodes + num_terminated_nodes == _cons.size() && outstanding_messages == 0)
 {
 num_waiting_nodes = 0;
 broadcast_byte(impl2::SEE_ALL_IN_WAITING_STATE);
 ++current_epoch;
 }
 if (!msg_buffer.pop(msg))
 throw dlib::socket_error("Error reading from msg_buffer in dlib::bsp_context.");
 if (msg.msg_type == impl2::NODE_TERMINATE)
 {
 ++num_terminated_nodes;
 _cons[msg.sender_id]->terminated = true;
 }
 else if (msg.msg_type == impl2::READ_ERROR)
 {
 throw dlib::socket_error(msg.data_to_string());
 }
 else if (msg.msg_type == impl2::MESSAGE_HEADER)
 {
 throw dlib::socket_error("A BSP node received a message after it has terminated.");
 }
 else if (msg.msg_type == impl2::GOT_MESSAGE)
 {
 --num_waiting_nodes;
 --outstanding_messages;
 }
 else if (msg.msg_type == impl2::SENT_MESSAGE)
 {
 ++outstanding_messages;
 }
 else if (msg.msg_type == impl2::IN_WAITING_STATE)
 {
 ++num_waiting_nodes;
 }
 }
 if (node_id() == 0)
 {
 _cons.reset();
 while (_cons.move_next())
 {
 // tell the other end that we are intentionally dropping the connection
 serialize(impl2::NODE_TERMINATE,_cons.element().value()->stream);
 _cons.element().value()->stream.flush();
 }
 if (outstanding_messages != 0)
 {
 std::ostringstream sout;
 sout << "A BSP job was allowed to terminate before all sent messages have been received.\n";
 sout << "There are at least " << outstanding_messages << " messages still in flight. Make sure all sent messages\n";
 sout << "have a corresponding call to receive().";
 throw dlib::socket_error(sout.str());
 }
 }
 }
// ----------------------------------------------------------------------------------------
 // Tears the context down in three steps:
 //   1. shut down every connection;
 //   2. disable msg_buffer;
 //   3. clear `threads`, which waits for all the read threads to terminate.
 bsp_context::
 ~bsp_context()
 {
     // Shutting the sockets down first presumably makes any read_thread still
     // blocked in deserialize() fail and exit. TODO confirm.
     for (_cons.reset(); _cons.move_next(); )
         _cons.element().value()->con->shutdown();

     msg_buffer.disable();

     // this will wait for all the threads to terminate
     threads.clear();
 }
// ----------------------------------------------------------------------------------------
 // Builds a context for the node with id node_id_ over the already established
 // connections in cons_.  The epoch starts at 1 and every counter at 0.  One reader
 // thread running impl2::read_thread is started per connection; it feeds whatever
 // arrives on that connection into msg_buffer.
 bsp_context::
 bsp_context(
     unsigned long node_id_,
     impl1::map_id_to_con& cons_
 ) :
     outstanding_messages(0),
     num_waiting_nodes(0),
     num_terminated_nodes(0),
     current_epoch(1),
     _cons(cons_),
     _node_id(node_id_)
 {
     for (_cons.reset(); _cons.move_next(); )
     {
         impl1::bsp_con* const link = _cons.element().value().get();
         const unsigned long peer_id = _cons.element().key();
         std::unique_ptr<thread_function> reader(new thread_function(&impl2::read_thread,
             link,
             _node_id,
             peer_id,
             ref(msg_buffer)));
         threads.push_back(reader);
     }
 }
// ----------------------------------------------------------------------------------------
 bool bsp_context::
 receive_data (
 std::shared_ptr<std::vector<char> >& item,
 unsigned long& sending_node_id
 ) 
 {
 notify_control_node(impl2::IN_WAITING_STATE);
 while (true)
 {
 // If there aren't any nodes left to give us messages then return right now.
 // We need to check the msg_buffer size to make sure there aren't any
 // unprocessed message there. Recall that this can happen because status
 // messages always jump to the front of the message buffer. So we might have
 // learned about the node terminations before processing their messages for us.
 if (num_terminated_nodes == _cons.size() && msg_buffer.size() == 0)
 {
 return false;
 }
 // if all running nodes are currently blocking forever on receive_data()
 if (node_id() == 0 && outstanding_messages == 0 && num_terminated_nodes + num_waiting_nodes == _cons.size())
 {
 num_waiting_nodes = 0;
 broadcast_byte(impl2::SEE_ALL_IN_WAITING_STATE);
 // Note that the reason we have this epoch counter is so we can tell if a
 // sent message is from before or after one of these "all nodes waiting"
 // synchronization events. If we didn't have the epoch count we would have
 // a race condition where one node gets the SEE_ALL_IN_WAITING_STATE
 // message before others and then sends out a message to another node
 // before that node got the SEE_ALL_IN_WAITING_STATE message. Then that
 // node would think the normal message came before SEE_ALL_IN_WAITING_STATE
 // which would be bad.
 ++current_epoch;
 return false;
 }
 impl1::msg_data data;
 if (!msg_buffer.pop(data, current_epoch))
 throw dlib::socket_error("Error reading from msg_buffer in dlib::bsp_context.");
 switch(data.msg_type)
 {
 case impl2::MESSAGE_HEADER: {
 item = data.data;
 sending_node_id = data.sender_id;
 notify_control_node(impl2::GOT_MESSAGE);
 return true;
 } break;
 case impl2::IN_WAITING_STATE: {
 ++num_waiting_nodes;
 } break;
 case impl2::GOT_MESSAGE: {
 --outstanding_messages;
 --num_waiting_nodes;
 } break;
 case impl2::SENT_MESSAGE: {
 ++outstanding_messages;
 } break;
 case impl2::NODE_TERMINATE: {
 ++num_terminated_nodes;
 _cons[data.sender_id]->terminated = true;
 } break;
 case impl2::SEE_ALL_IN_WAITING_STATE: {
 ++current_epoch;
 return false;
 } break;
 case impl2::READ_ERROR: {
 throw dlib::socket_error(data.data_to_string());
 } break;
 default: {
 throw dlib::socket_error("Unknown message received by dlib::bsp_context");
 } break;
 } // end switch()
 } // end while (true)
 }
// ----------------------------------------------------------------------------------------
 void bsp_context::
 notify_control_node (
 char val
 )
 {
 if (node_id() == 0)
 {
 using namespace impl2;
 switch(val)
 {
 case SENT_MESSAGE: {
 ++outstanding_messages;
 } break;
 case GOT_MESSAGE: {
 --outstanding_messages;
 } break;
 case IN_WAITING_STATE: {
 // nothing to do in this case
 } break;
 default:
 DLIB_CASSERT(false,"This should never happen");
 }
 }
 else
 {
 serialize(val, _cons[0]->stream);
 _cons[0]->stream.flush();
 }
 }
// ----------------------------------------------------------------------------------------
 // Sends the single control byte val, flushed, to every node except this one and
 // except nodes already marked as terminated.
 void bsp_context::
 broadcast_byte (
     char val
 )
 {
     const auto total = number_of_nodes();
     for (unsigned long peer = 0; peer < total; ++peer)
     {
         // short-circuit matters: there is no _cons entry for ourselves
         const bool skip = (peer == node_id()) || _cons[peer]->terminated;
         if (!skip)
         {
             serialize(val, _cons[peer]->stream);
             _cons[peer]->stream.flush();
         }
     }
 }
// ----------------------------------------------------------------------------------------
 void bsp_context::
 send_data(
 const std::vector<char>& item,
 unsigned long target_node_id
 ) 
 {
 using namespace impl2;
 if (_cons[target_node_id]->terminated)
 throw socket_error("Attempt to send a message to a node that has terminated.");
 serialize(MESSAGE_HEADER, _cons[target_node_id]->stream);
 serialize(current_epoch, _cons[target_node_id]->stream);
 serialize(item, _cons[target_node_id]->stream);
 _cons[target_node_id]->stream.flush();
 notify_control_node(SENT_MESSAGE);
 }
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_BSP_CPph_

AltStyle によって変換されたページ (->オリジナル) /