// Copyright (C) 2014  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_STRUCTURAL_TRACK_ASSOCIATION_TRAnER_Hh_
#define DLIB_STRUCTURAL_TRACK_ASSOCIATION_TRAnER_Hh_

#include "structural_track_association_trainer_abstract.h"
#include "../algs.h"
#include "svm.h"
#include <utility>
#include "track_association_function.h"
#include "structural_assignment_trainer.h"
#include <map>

namespace dlib
{

// ----------------------------------------------------------------------------------------

    namespace impl
    {
        /*!
            get_unlabeled_dets(dets)
                Returns a vector holding the .det member of every element of dets, in
                the same order as dets.  The labels are discarded.
        !*/
        template <
            typename detection_type,
            typename label_type
            >
        std::vector<detection_type> get_unlabeled_dets (
            const std::vector<labeled_detection<detection_type,label_type> >& dets
        )
        {
            typedef typename std::vector<labeled_detection<detection_type,label_type> >::const_iterator det_iterator;

            std::vector<detection_type> unlabeled;
            // One slot per input detection, so no reallocation happens in the loop.
            unlabeled.reserve(dets.size());
            for (det_iterator it = dets.begin(); it != dets.end(); ++it)
                unlabeled.push_back(it->det);
            return unlabeled;
        }

    }

// ----------------------------------------------------------------------------------------

    class structural_track_association_trainer
    {
        /*!
            WHAT THIS OBJECT REPRESENTS
                A trainer that produces a track_association_function.  Each labeled
                detection history is turned into a sequence of assignment problems
                (detections at time t versus the tracks built from times 0..t-1).
                Those problems are handed to a structural_assignment_trainer that
                uses feature_extractor_track_association.  The configuration values
                held here (C, epsilon, cache size, thread count, oca solver, losses)
                are forwarded to that trainer by train().
        !*/
    public:

        structural_track_association_trainer (
        )  
        {
            set_defaults();
        }

        // Number of threads forwarded to the underlying structural_assignment_trainer.
        void set_num_threads (
            unsigned long num
        )
        {
            num_threads = num;
        }

        unsigned long get_num_threads (
        ) const
        {
            return num_threads;
        }

        // Solver stopping tolerance forwarded to the underlying trainer.  Requires eps_ > 0.
        void set_epsilon (
            double eps_
        )
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(eps_ > 0,
                "\t void structural_track_association_trainer::set_epsilon()"
                << "\n\t eps_ must be greater than 0"
                << "\n\t eps_: " << eps_ 
                << "\n\t this: " << this
                );

            eps = eps_;
        }

        double get_epsilon (
        ) const { return eps; }

        // Cache size forwarded to the underlying trainer's set_max_cache_size().
        void set_max_cache_size (
            unsigned long max_size
        )
        {
            max_cache_size = max_size;
        }

        unsigned long get_max_cache_size (
        ) const
        {
            return max_cache_size; 
        }

        // Loss charged when a detection is associated with the wrong track.
        // Forwarded to the underlying trainer's set_loss_per_false_association().
        // Requires loss > 0.
        void set_loss_per_false_association (
            double loss
        )
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(loss > 0, 
                "\t void structural_track_association_trainer::set_loss_per_false_association(loss)"
                << "\n\t Invalid inputs were given to this function "
                << "\n\t loss: " << loss
                << "\n\t this: " << this
                );

            loss_per_false_association = loss;
        }

        double get_loss_per_false_association (
        ) const
        {
            return loss_per_false_association;
        }

        // Loss charged when a true association is missed (the track is broken).
        // Forwarded to the underlying trainer's set_loss_per_missed_association().
        // Requires loss > 0.
        void set_loss_per_track_break (
            double loss
        )
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(loss > 0, 
                "\t void structural_track_association_trainer::set_loss_per_track_break(loss)"
                << "\n\t Invalid inputs were given to this function "
                << "\n\t loss: " << loss
                << "\n\t this: " << this
                );

            loss_per_track_break = loss;
        }

        double get_loss_per_track_break (
        ) const
        {
            return loss_per_track_break;
        }

        // When verbose, train() calls be_verbose() on the underlying trainer.
        void be_verbose (
        )
        {
            verbose = true;
        }

        void be_quiet (
        )
        {
            verbose = false;
        }

        // Optimizer object forwarded to the underlying trainer's set_oca().
        void set_oca (
            const oca& item
        )
        {
            solver = item;
        }

        const oca get_oca (
        ) const
        {
            return solver;
        }

        // Regularization parameter forwarded to the underlying trainer's set_c().
        // Requires C_ > 0.
        void set_c (
            double C_ 
        )
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(C_ > 0,
                "\t void structural_track_association_trainer::set_c()"
                << "\n\t C_ must be greater than 0"
                << "\n\t C_:    " << C_ 
                << "\n\t this: " << this
                );

            C = C_;
        }

        double get_c (
        ) const
        {
            return C;
        }

        bool learns_nonnegative_weights (
        ) const { return learn_nonnegative_weights; }
       
        // When true, the feature extractor is built with num_nonnegative equal to
        // the full feature dimensionality, so every learned weight is constrained
        // to be non-negative.
        void set_learns_nonnegative_weights (
            bool value
        )
        {
            learn_nonnegative_weights = value;
        }

        /*!
            requires
                - is_track_association_problem(samples) == true
            ensures
                - samples[i] is a detection history and samples[i][t] holds the labeled
                  detections observed at time step t.  Every history becomes one
                  assignment problem per time step after the first.  The problems
                  are trained with structural_assignment_trainer, and the learned
                  assignment function is returned wrapped in a
                  track_association_function.
        !*/
        template <
            typename detection_type,
            typename label_type
            >
        const track_association_function<detection_type> train (  
            const std::vector<std::vector<std::vector<labeled_detection<detection_type,label_type> > > >& samples
        ) const
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(is_track_association_problem(samples),
                        "\t track_association_function structural_track_association_trainer::train()"
                        << "\n\t invalid inputs were given to this function"
                        << "\n\t is_track_association_problem(samples): " << is_track_association_problem(samples)
            );

            typedef typename detection_type::track_type track_type;

            const unsigned long num_dims = find_num_dims(samples);

            feature_extractor_track_association<detection_type> fe(num_dims, learn_nonnegative_weights?num_dims:0);
            structural_assignment_trainer<feature_extractor_track_association<detection_type> > trainer(fe);


            if (verbose)
                trainer.be_verbose();

            trainer.set_c(C);
            trainer.set_epsilon(eps);
            trainer.set_max_cache_size(max_cache_size);
            trainer.set_num_threads(num_threads);
            trainer.set_oca(solver);
            // A broken track is what the assignment trainer calls a missed association.
            trainer.set_loss_per_missed_association(loss_per_track_break);
            trainer.set_loss_per_false_association(loss_per_false_association);

            std::vector<std::pair<std::vector<detection_type>, std::vector<track_type> > > assignment_samples;
            std::vector<std::vector<long> > labels;
            for (unsigned long i = 0; i < samples.size(); ++i)
                convert_dets_to_association_sets(samples[i], assignment_samples, labels);


            return track_association_function<detection_type>(trainer.train(assignment_samples, labels));
        }

        // Convenience overload: trains on a single detection history.
        template <
            typename detection_type,
            typename label_type
            >
        const track_association_function<detection_type> train (  
            const std::vector<std::vector<labeled_detection<detection_type,label_type> > >& sample
        ) const
        {
            std::vector<std::vector<std::vector<labeled_detection<detection_type,label_type> > > > samples;
            samples.push_back(sample);
            return train(samples);
        }

    private:

        /*!
            Finds the first detection in samples, builds a fresh track from it, and
            returns the length of the similarity feature vector that the track
            produces for that detection.  If samples contains no detections at all,
            the DLIB_CASSERT reports an error.
        !*/
        template <
            typename detection_type,
            typename label_type
            >
        static unsigned long find_num_dims (
            const std::vector<std::vector<std::vector<labeled_detection<detection_type,label_type> > > >& samples
        )
        {
            typedef typename detection_type::track_type track_type;
            // find a detection_type object so we can call get_similarity_features() and
            // find out how big the feature vectors are.

            // for all detection histories 
            for (unsigned long i = 0; i < samples.size(); ++i)
            {
                // for all time instances in the detection history
                for (unsigned long j = 0; j < samples[i].size(); ++j)
                {
                    if (samples[i][j].size() > 0)
                    {
                        track_type new_track;
                        new_track.update_track(samples[i][j][0].det);
                        typename track_type::feature_vector_type feats;
                        new_track.get_similarity_features(samples[i][j][0].det, feats);
                        return feats.size();
                    }
                }
            }

            DLIB_CASSERT(false, 
                "No detection objects were given in the call to dlib::structural_track_association_trainer::train()");
            // The assertion above reports the error.  This return only makes sure the
            // function never flows off the end without a value.
            return 0;
        }

        /*!
            Converts one detection history into assignment training data.  Tracks are
            seeded from det_history[0].  Then, for every later time step i:
              - data gets (unlabeled detections at i, snapshot of current tracks)
              - labels gets, per detection, the index of the track with the same
                label, or -1 if there is none
              - the detections at i are folded into the tracks
            Histories with no time steps contribute nothing.
        !*/
        template <
            typename detections_at_single_time_step,
            typename detection_type,
            typename track_type
            >
        static void convert_dets_to_association_sets (
            const std::vector<detections_at_single_time_step>& det_history,
            std::vector<std::pair<std::vector<detection_type>, std::vector<track_type> > >& data,
            std::vector<std::vector<long> >& labels
        ) 
        {
            if (det_history.size() < 1)
                return;

            typedef typename detections_at_single_time_step::value_type::label_type label_type;
            std::vector<track_type> tracks;
            // track_labels maps from detection labels to the index in tracks.  So track
            // with detection label X is at tracks[track_labels[X]].
            std::map<label_type,unsigned long> track_labels;
            add_dets_to_tracks(tracks, track_labels, det_history[0]);

            for (unsigned long i = 1; i < det_history.size(); ++i)
            {
                // Snapshot the tracks as they were *before* time step i is applied.
                data.push_back(std::make_pair(impl::get_unlabeled_dets(det_history[i]), tracks));
                labels.push_back(get_association_labels(det_history[i], track_labels));
                add_dets_to_tracks(tracks, track_labels, det_history[i]);
            }
        }

        /*!
            Returns a vector with one entry per detection.  Entry i is the index of
            the track whose label equals dets[i].label, or -1 when no such track
            exists yet.
        !*/
        template <
            typename labeled_det_type,
            typename label_type
            >
        static std::vector<long> get_association_labels(
            const std::vector<labeled_det_type>& dets,
            const std::map<label_type,unsigned long>& track_labels
        )
        {
            std::vector<long> assoc(dets.size(),-1);
            // find out which detections associate to what tracks
            for (unsigned long i = 0; i < dets.size(); ++i)
            {
                const typename std::map<label_type,unsigned long>::const_iterator j = track_labels.find(dets[i].label);
                // If this detection matches one of the tracks then record which track it
                // matched with.
                if (j != track_labels.end())
                    assoc[i] = j->second;
            }
            return assoc;
        }

        /*!
            Folds one time step of labeled detections into tracks:
              - a detection whose label already has a track updates that track
              - a detection with an unseen label starts a new track, and its label
                is registered in track_labels
              - every track that existed on entry and received no detection is
                propagated
        !*/
        template <
            typename track_type,
            typename label_type,
            typename labeled_det_type
            >
        static void add_dets_to_tracks (
            std::vector<track_type>& tracks,
            std::map<label_type,unsigned long>& track_labels,
            const std::vector<labeled_det_type>& dets
        )
        {
            // Only tracks that exist on entry are candidates for propagation.
            std::vector<bool> updated_track(tracks.size(), false);

            // first assign the dets to the tracks
            for (unsigned long i = 0; i < dets.size(); ++i)
            {
                const label_type& label = dets[i].label;
                const typename std::map<label_type,unsigned long>::iterator existing = track_labels.find(label);
                if (existing != track_labels.end())
                {
                    const unsigned long track_idx = existing->second;
                    tracks[track_idx].update_track(dets[i].det);
                    // A track created earlier in this same call (possible only when dets
                    // repeats a label) has no slot in updated_track.  It doesn't need
                    // propagating either, so only mark tracks that existed on entry.
                    // This avoids an out-of-bounds write.
                    if (track_idx < updated_track.size())
                        updated_track[track_idx] = true;
                }
                else
                {
                    // this detection creates a new track
                    track_type new_track;
                    new_track.update_track(dets[i].det);
                    tracks.push_back(new_track);
                    track_labels[label] = tracks.size()-1;
                }

            }

            // Now propagate all the tracks that didn't get any detections.
            for (unsigned long i = 0; i < updated_track.size(); ++i)
            {
                if (!updated_track[i])
                    tracks[i].propagate_track();
            }
        }

        double C;
        oca solver;
        double eps;
        bool verbose;
        unsigned long num_threads;
        unsigned long max_cache_size;
        bool learn_nonnegative_weights;
        double loss_per_track_break;
        double loss_per_false_association;

        void set_defaults ()
        {
            C = 100;
            verbose = false;
            eps = 0.001;
            num_threads = 2;
            max_cache_size = 5;
            learn_nonnegative_weights = false;
            loss_per_track_break = 1;
            loss_per_false_association = 1;
        }
    };

}

#endif // DLIB_STRUCTURAL_TRACK_ASSOCIATION_TRAnER_Hh_