// dlib C++ Library - validation.h

// Copyright (C) 2016 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_DNn_VALIDATION_H_
#define DLIB_DNn_VALIDATION_H_
#include "../svm/cross_validate_object_detection_trainer_abstract.h"
#include "../svm/cross_validate_object_detection_trainer.h"
#include "layers.h"
#include <set>
namespace dlib
{
 namespace impl
 {
 // Returns the set of distinct label strings found on the rectangles in
 // rects1 and rects2 combined.  Duplicates collapse because the result is a
 // std::set.
 inline std::set<std::string> get_labels (
     const std::vector<mmod_rect>& rects1,
     const std::vector<mmod_rect>& rects2
 )
 {
     std::set<std::string> unique_labels;
     // Both inputs are handled identically, so route them through one helper.
     auto collect_from = [&unique_labels](const std::vector<mmod_rect>& rects)
     {
         for (const auto& r : rects)
             unique_labels.insert(r.label);
     };
     collect_from(rects1);
     collect_from(rects2);
     return unique_labels;
 }
 }
 // Runs the given loss_mmod network on every image, compares its detections
 // against truth_dets one label at a time, and returns a 1x3 matrix holding:
 //   - precision: correct hits / number of entries recorded in the detection
 //     list (1 when that list is empty)
 //   - recall:    correct hits / number of non-ignored truth boxes
 //     (1 when there are none)
 //   - the value of average_precision() over the sorted detection list and
 //     the missing-detection count
 //
 // Parameters:
 //   detector               - network used to produce detections
 //   images                 - indexable collection of input images
 //   truth_dets             - truth_dets[i] lists the truth boxes for images[i]
 //   overlap_tester         - forwarded to impl::number_of_truth_hits()
 //   adjust_threshold       - forwarded to the loss layer's to_label()
 //   overlaps_ignore_tester - forwarded to impl::number_of_truth_hits()
 template <
     typename SUBNET,
     typename image_array_type
     >
 const matrix<double,1,3> test_object_detection_function (
     loss_mmod<SUBNET>& detector,
     const image_array_type& images,
     const std::vector<std::vector<mmod_rect>>& truth_dets,
     const test_box_overlap& overlap_tester = test_box_overlap(),
     const double adjust_threshold = 0,
     const test_box_overlap& overlaps_ignore_tester = test_box_overlap()
 )
 {
     // Reject inputs that violate this function's preconditions.
     DLIB_CASSERT( is_learning_problem(images,truth_dets) == true , 
         "\t matrix test_object_detection_function()"
         << "\n\t invalid inputs were given to this function"
         << "\n\t is_learning_problem(images,truth_dets): " << is_learning_problem(images,truth_dets)
         << "\n\t images.size(): " << images.size() 
         );

     // Running totals accumulated over every image and every label.
     double num_correct = 0;
     double num_truth = 0;
     unsigned long num_missed = 0;
     std::vector<std::pair<double,bool> > scored_dets;

     // Tensor buffer reused for each image.
     resizable_tensor input_tensor;

     for (unsigned long img = 0; img < images.size(); ++img)
     {
         // Run the network on this single image and decode its detections.
         std::vector<mmod_rect> detections;
         detector.to_tensor(&images[img], &images[img]+1, input_tensor);
         detector.subnet().forward(input_tensor);
         detector.loss_details().to_label(input_tensor, detector.subnet(), &detections, adjust_threshold);

         const auto& image_truth = truth_dets[img];
         const auto labels = impl::get_labels(image_truth, detections);

         // Score each label separately so a detection is only matched
         // against truth boxes that carry the same label.
         for (const auto& label : labels)
         {
             std::vector<full_object_detection> truth_boxes;
             std::vector<rectangle> ignored;
             std::vector<std::pair<double,rectangle>> scored_boxes;

             for (const auto& truth : image_truth)
             {
                 // Ignore-flagged boxes are collected regardless of label
                 // and never count as truth targets.
                 if (truth.ignore)
                 {
                     ignored.push_back(truth);
                     continue;
                 }
                 if (truth.label == label)
                 {
                     truth_boxes.emplace_back(truth.rect);
                     num_truth += 1;
                 }
             }

             for (const auto& det : detections)
             {
                 if (det.label == label)
                     scored_boxes.emplace_back(det.detection_confidence, det.rect);
             }

             num_correct += impl::number_of_truth_hits(truth_boxes, ignored, scored_boxes, overlap_tester, scored_dets, num_missed, overlaps_ignore_tester);
         }
     }

     // Reverse iterators make std::sort order the list from largest to smallest.
     std::sort(scored_dets.rbegin(), scored_dets.rend());

     const double num_dets = scored_dets.size();
     // With nothing to divide by, the corresponding ratio is reported as 1.
     const double precision = (num_dets == 0) ? 1.0 : num_correct / num_dets;
     const double recall = (num_truth == 0) ? 1.0 : num_correct / num_truth;

     matrix<double, 1, 3> res;
     res = precision, recall, average_precision(scored_dets, num_missed);
     return res;
 }
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_DNn_VALIDATION_H_

// Page converted by AltStyle (-> original)