// dlib C++ Library - learning_to_track.cpp

// Copyright (C) 2014 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#include <sstream>
#include "tester.h"
#include <dlib/svm_threaded.h>
#include <dlib/rand.h>
namespace 
{
 using namespace test;
 using namespace dlib;
 using namespace std;
 // Logger for this test file; the LINFO progress messages emitted by the
 // tester are tagged with the name "test.learning_to_track".
 logger dlog("test.learning_to_track");
// ----------------------------------------------------------------------------------------
 // Forward declaration so that detection_dense can name its associated track
 // type before track_dense itself is defined further down.
 struct track_dense;

 // A single simulated detection whose similarity features are produced as a
 // dense column vector (see track_dense).
 struct detection_dense
 {
     // The track type the association code pairs with this detection; the test
     // driver retrieves it through detection::track_type.
     typedef track_dense track_type;

     // Noisy sensor measurements of the detected object's properties.
     matrix<double,0,1> measurements;
 };
 // Track type paired with detection_dense.  It remembers the measurements of the
 // last detection assigned to it and counts how many propagate_track() calls have
 // happened since then.
 struct track_dense
 {
     typedef matrix<double,0,1> feature_vector_type;

     track_dense() : time_since_last_association(0) {}

     // Writes into feats the element-wise absolute difference between this track's
     // last measurements and the candidate detection's measurements.  The detection
     // is taken by const reference so its measurement matrix is not copied on
     // every similarity evaluation.
     void get_similarity_features(const detection_dense& det, feature_vector_type& feats) const
     {
         feats = abs(last_measurements - det.measurements);
     }

     // Absorbs an associated detection: remember its measurements and reset the
     // miss counter.
     void update_track(const detection_dense& det)
     {
         last_measurements = det.measurements;
         time_since_last_association = 0;
     }

     // Counts one more step since this track last received a detection.
     void propagate_track()
     {
         ++time_since_last_association;
     }

     matrix<double,0,1> last_measurements;
     unsigned long time_since_last_association;
 };
// ----------------------------------------------------------------------------------------
 // Forward declaration so that detection_sparse can name its associated track
 // type before track_sparse itself is defined further down.
 struct track_sparse;

 // A single simulated detection whose similarity features are produced as a
 // sparse (index, value) vector (see track_sparse).
 struct detection_sparse
 {
     // The track type the association code pairs with this detection; the test
     // driver retrieves it through detection::track_type.
     typedef track_sparse track_type;

     // Noisy sensor measurements of the detected object's properties.
     matrix<double,0,1> measurements;
 };
 // Track type paired with detection_sparse.  Identical to the dense track except
 // that similarity features are emitted as a sparse vector of (index, value)
 // pairs, one entry per measurement.
 struct track_sparse
 {
     typedef std::vector<std::pair<unsigned long,double> > feature_vector_type;

     track_sparse() : time_since_last_association(0) {}

     // Fills feats with (i, |last_measurements(i) - det.measurements(i)|) for every
     // measurement index i.  The detection is taken by const reference to avoid
     // copying its measurement matrix on each call.
     void get_similarity_features(const detection_sparse& det, feature_vector_type& feats) const
     {
         const matrix<double,0,1> temp = abs(last_measurements - det.measurements);
         feats.clear();
         feats.reserve(temp.size());
         for (long i = 0; i < temp.size(); ++i)
             feats.push_back(std::make_pair(static_cast<unsigned long>(i), temp(i)));
     }

     // Absorbs an associated detection: remember its measurements and reset the
     // miss counter.
     void update_track(const detection_sparse& det)
     {
         last_measurements = det.measurements;
         time_since_last_association = 0;
     }

     // Counts one more step since this track last received a detection.
     void propagate_track()
     {
         ++time_since_last_association;
     }

     matrix<double,0,1> last_measurements;
     unsigned long time_since_last_association;
 };
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
 // Size of the simulated world: how many distinct objects exist and how many
 // measured properties each object has.
 const long num_objects = 4;
 const long num_properties = 6;

 // Shared random number generator used by every sampling/shuffling helper in
 // this file.
 dlib::rand rnd;

 // The "true" property vector of each simulated object, one entry per object;
 // populated by initialize_object_properties().
 std::vector<matrix<double,0,1> > object_properties(num_objects);
 void initialize_object_properties()
 {
 rnd.set_seed("23ja2oirfjaf");
 for (unsigned long i = 0; i < object_properties.size(); ++i)
 object_properties[i] = randm(num_properties,1,rnd);
 }
 // Simulates one sensor reading of object `object_id`: a detection whose
 // measurements are that object's true property vector plus a small random
 // perturbation (0.1 * randm(...) drawn from the shared generator).
 //
 // object_id must be in [0, num_objects); it is used to index object_properties
 // directly, so negative ids are rejected as well as ids that are too large.
 template <typename detection>
 detection sample_detection_from_sensor(long object_id)
 {
     DLIB_CASSERT(0 <= object_id && object_id < num_objects, 
         "You can't ask to sample a detection from an object that doesn't exist."); 

     detection temp;
     temp.measurements = object_properties[object_id] + randm(num_properties,1,rnd)*0.1;
     return temp;
 }
// ----------------------------------------------------------------------------------------
 // Builds one time step of labeled data: for each id in ids[0..num_ids) draws a
 // noisy detection of that object and labels it with the id, then shuffles the
 // detections so their order carries no information about their labels.
 template <typename detection>
 std::vector<labeled_detection<detection> > make_labeled_time_step(
     const long* ids,
     unsigned long num_ids
 )
 {
     std::vector<labeled_detection<detection> > step(num_ids);
     for (unsigned long k = 0; k < num_ids; ++k)
     {
         step[k].det = sample_detection_from_sensor<detection>(ids[k]);
         step[k].label = ids[k];
     }
     randomize_samples(step, rnd);
     return step;
 }

 // Produces one labeled training sequence of 200 time steps:
 //   - the first 100 steps contain one detection each from objects 0, 1 and 2;
 //   - the last 100 steps contain detections only from object 0 and a newly
 //     appeared object 3 (objects 1 and 2 have left the scene).
 template <typename detection>
 std::vector<std::vector<labeled_detection<detection> > > make_random_tracking_data_for_training()
 {
     std::vector<std::vector<labeled_detection<detection> > > history;
     const int num_time_steps = 100;

     const long phase_one_ids[] = {0, 1, 2};
     const unsigned long phase_one_count = sizeof(phase_one_ids)/sizeof(phase_one_ids[0]);
     for (int t = 0; t < num_time_steps; ++t)
         history.push_back(make_labeled_time_step<detection>(phase_one_ids, phase_one_count));

     const long phase_two_ids[] = {0, 3};
     const unsigned long phase_two_count = sizeof(phase_two_ids)/sizeof(phase_two_ids[0]);
     for (int t = 0; t < num_time_steps; ++t)
         history.push_back(make_labeled_time_step<detection>(phase_two_ids, phase_two_count));

     return history;
 }
// ----------------------------------------------------------------------------------------
 // Returns num_dets unlabeled detections, one from each of objects
 // 0, 1, ..., num_dets-1, in shuffled order.
 //
 // num_dets must be in [0, num_objects].  A negative count is rejected too,
 // since it would otherwise be converted into an enormous vector size.
 template <typename detection>
 std::vector<detection> make_random_detections(long num_dets)
 {
     DLIB_CASSERT(0 <= num_dets && num_dets <= num_objects, 
         "You can't ask for more detections than there are objects in our little simulation."); 

     std::vector<detection> dets(num_dets);
     for (long i = 0; i < num_dets; ++i)
     {
         dets[i] = sample_detection_from_sensor<detection>(i);
     }
     randomize_samples(dets, rnd);
     return dets;
 }
// ----------------------------------------------------------------------------------------
 template <typename detection>
 void test_tracking_stuff()
 {
 print_spinner();
 typedef std::vector<labeled_detection<detection> > detections_at_single_time_step;
 typedef std::vector<detections_at_single_time_step> track_history;
 std::vector<track_history> data;
 data.push_back(make_random_tracking_data_for_training<detection>());
 data.push_back(make_random_tracking_data_for_training<detection>());
 data.push_back(make_random_tracking_data_for_training<detection>());
 data.push_back(make_random_tracking_data_for_training<detection>());
 data.push_back(make_random_tracking_data_for_training<detection>());
 structural_track_association_trainer trainer;
 trainer.set_c(1000);
 track_association_function<detection> assoc = trainer.train(data);
 double test_val = test_track_association_function(assoc, data); 
 DLIB_TEST_MSG( test_val == 1, test_val);
 test_val = cross_validate_track_association_trainer(trainer, data, 5); 
 DLIB_TEST_MSG ( test_val == 1, test_val);
 typedef typename detection::track_type track;
 std::vector<track> tracks;
 std::vector<detection> dets = make_random_detections<detection>(3);
 assoc(tracks, dets);
 DLIB_TEST(tracks.size() == 3);
 dets = make_random_detections<detection>(3);
 assoc(tracks, dets);
 DLIB_TEST(tracks.size() == 3);
 dets = make_random_detections<detection>(3);
 assoc(tracks, dets);
 DLIB_TEST(tracks.size() == 3);
 dets = make_random_detections<detection>(4);
 assoc(tracks, dets);
 DLIB_TEST(tracks.size() == 4);
 dets = make_random_detections<detection>(3);
 assoc(tracks, dets);
 DLIB_TEST(tracks.size() == 4);
 unsigned long total_miss = 0;
 for (unsigned long i = 0; i < tracks.size(); ++i)
 total_miss += tracks[i].time_since_last_association;
 DLIB_TEST(total_miss == 1);
 dets = make_random_detections<detection>(3);
 assoc(tracks, dets);
 DLIB_TEST(tracks.size() == 4);
 total_miss = 0;
 unsigned long num_zero = 0;
 for (unsigned long i = 0; i < tracks.size(); ++i)
 {
 total_miss += tracks[i].time_since_last_association;
 if (tracks[i].time_since_last_association == 0)
 ++num_zero;
 }
 DLIB_TEST(total_miss == 2);
 DLIB_TEST(num_zero == 3);
 ostringstream sout; 
 serialize(assoc, sout);
 istringstream sin(sout.str());
 deserialize(assoc, sin);
 DLIB_TEST( test_track_association_function(assoc, data) == 1);
 }
// ----------------------------------------------------------------------------------------
 // Test-harness entry for the learning-to-track tests.  The global instance `a`
 // declared with the class is what exposes it to the harness (assumed: the
 // tester base-class constructor registers the test; confirm in tester.h).
 class test_learning_to_track : public tester
 {
 public:
     test_learning_to_track (
     ) :
         tester ("test_learning_to_track",
                 "Runs tests on the track association learning code.")
     {}

     // Seeds the object simulation once, then runs the full tracking test three
     // times for both the dense and the sparse feature representations.  The
     // shared generator is not re-seeded between repetitions, so each repetition
     // sees different simulated data.
     void perform_test (
     )
     {
         initialize_object_properties();
         for (int i = 0; i < 3; ++i)
         {
             dlog << LINFO << "run test_tracking_stuff<detection_dense>()";
             test_tracking_stuff<detection_dense>();
             dlog << LINFO << "run test_tracking_stuff<detection_sparse>()";
             test_tracking_stuff<detection_sparse>();
         }
     }
 } a;
// ----------------------------------------------------------------------------------------
}

// Page converted by AltStyle (-> original)