// Copyright (C) 2014  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_SHAPE_PREDICToR_H_
#define DLIB_SHAPE_PREDICToR_H_

#include "shape_predictor_abstract.h"
#include "full_object_detection.h"
#include "../algs.h"
#include "../matrix.h"
#include "../geometry.h"
#include "../pixel.h"
#include "../console_progress_indicator.h"
#include "../statistics.h"
#include "../threads.h"
#include <utility>

namespace dlib
{

// ----------------------------------------------------------------------------------------

    namespace impl
    {
        struct split_feature
        {
            unsigned long idx1;
            unsigned long idx2;
            float thresh;

            friend inline void serialize (const split_feature& item, std::ostream& out)
            {
                dlib::serialize(item.idx1, out);
                dlib::serialize(item.idx2, out);
                dlib::serialize(item.thresh, out);
            }
            friend inline void deserialize (split_feature& item, std::istream& in)
            {
                dlib::deserialize(item.idx1, in);
                dlib::deserialize(item.idx2, in);
                dlib::deserialize(item.thresh, in);
            }
        };


        // A tree's split nodes are stored in a std::vector<impl::split_feature> laid
        // out as an implicit complete binary tree (see regression_tree below).  We use
        // these functions to navigate between the tree nodes.
        inline unsigned long left_child (unsigned long idx) { return 2*idx + 1; }
        /*!
            ensures
                - returns the index of the left child of the binary tree node idx
        !*/
        inline unsigned long right_child (unsigned long idx) { return 2*idx + 2; }
        /*!
            ensures
                - returns the index of the right child of the binary tree node idx
        !*/
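
        // For example, a depth-2 tree stores its split nodes at indices 0,1,2 and its
        // leaves at (implicit) indices 3..6:  left_child(0)==1, right_child(0)==2,
        // left_child(1)==3, right_child(1)==4, left_child(2)==5, right_child(2)==6.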

        struct regression_tree
        {
            std::vector<split_feature> splits;
            std::vector<matrix<float,0,1> > leaf_values;

            unsigned long num_leaves() const { return leaf_values.size(); }

            inline const matrix<float,0,1>& operator()(
                const std::vector<float>& feature_pixel_values,
                unsigned long& i
            ) const
            /*!
                requires
                    - All the index values in splits are less than feature_pixel_values.size()
                    - leaf_values.size() is a power of 2.
                      (i.e. we require a tree with all the levels fully filled out.)
                    - leaf_values.size() == splits.size()+1
                      (i.e. there needs to be the right number of leaves given the number of splits in the tree)
                ensures
                    - runs through the tree and returns the vector at the leaf we end up in.
                    - #i == the selected leaf node index.
            !*/
            {
                i = 0;
                while (i < splits.size())
                {
                    if ((float)feature_pixel_values[splits[i].idx1] - (float)feature_pixel_values[splits[i].idx2] > splits[i].thresh)
                        i = left_child(i);
                    else
                        i = right_child(i);
                }
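                // When the while loop exits, i is the index of a leaf node.  Subtracting
                // the number of split nodes turns it into an index into leaf_values.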
                i = i - splits.size();
                return leaf_values[i];
            }

            friend void serialize (const regression_tree& item, std::ostream& out)
            {
                dlib::serialize(item.splits, out);
                dlib::serialize(item.leaf_values, out);
            }
            friend void deserialize (regression_tree& item, std::istream& in)
            {
                dlib::deserialize(item.splits, in);
                dlib::deserialize(item.leaf_values, in);
            }
        };
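
        // To make the traversal concrete:  for a tree with splits.size() == 3 and
        // leaf_values.size() == 4, we start at split node 0.  If the test at node 0
        // sends us right and the test at node 2 then sends us left, the loop stops at
        // node index 5 and we return leaf_values[5 - splits.size()] == leaf_values[2].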

    // ------------------------------------------------------------------------------------

        inline vector<float,2> location (
            const matrix<float,0,1>& shape,
            unsigned long idx
        )
        /*!
            requires
                - idx < shape.size()/2
                - shape.size()%2 == 0
            ensures
                - returns the idx-th point from the shape vector.
        !*/
        {
            return vector<float,2>(shape(idx*2), shape(idx*2+1));
        }
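
        // E.g. shapes store landmarks as interleaved coordinates [x0 y0 x1 y1 ...], so
        // location(shape,1) returns the point (shape(2), shape(3)).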

    // ------------------------------------------------------------------------------------

        inline unsigned long nearest_shape_point (
            const matrix<float,0,1>& shape,
            const dlib::vector<float,2>& pt
        )
        {
            // find the nearest part of the shape to this pixel
            float best_dist = std::numeric_limits<float>::infinity();
            const unsigned long num_shape_parts = shape.size()/2;
            unsigned long best_idx = 0;
            for (unsigned long j = 0; j < num_shape_parts; ++j)
            {
                const float dist = length_squared(location(shape,j)-pt);
                if (dist < best_dist)
                {
                    best_dist = dist;
                    best_idx = j;
                }
            }
            return best_idx;
        }

    // ------------------------------------------------------------------------------------

        inline void create_shape_relative_encoding (
            const matrix<float,0,1>& shape,
            const std::vector<dlib::vector<float,2> >& pixel_coordinates,
            std::vector<unsigned long>& anchor_idx, 
            std::vector<dlib::vector<float,2> >& deltas
        )
        /*!
            requires
                - shape.size()%2 == 0 
                - shape.size() > 0
            ensures
                - #anchor_idx.size() == pixel_coordinates.size()
                - #deltas.size()     == pixel_coordinates.size()
                - for all valid i:
                    - pixel_coordinates[i] == location(shape,#anchor_idx[i]) + #deltas[i]
        !*/
        {
            anchor_idx.resize(pixel_coordinates.size());
            deltas.resize(pixel_coordinates.size());


            for (unsigned long i = 0; i < pixel_coordinates.size(); ++i)
            {
                anchor_idx[i] = nearest_shape_point(shape, pixel_coordinates[i]);
                deltas[i] = pixel_coordinates[i] - location(shape,anchor_idx[i]);
            }
        }
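
        // For example, if pixel_coordinates[i] == (0.6,0.45) and the landmark in shape
        // nearest to it is landmark 3 at (0.55,0.40), then #anchor_idx[i] == 3 and
        // #deltas[i] == (0.05,0.05).  Encoding feature pixels relative to their nearest
        // landmark lets us reposition them against whatever the current shape estimate is.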

    // ------------------------------------------------------------------------------------

        inline point_transform_affine find_tform_between_shapes (
            const matrix<float,0,1>& from_shape,
            const matrix<float,0,1>& to_shape
        )
        {
            DLIB_ASSERT(from_shape.size() == to_shape.size() && (from_shape.size()%2) == 0 && from_shape.size() > 0,"");
            std::vector<vector<float,2> > from_points, to_points;
            const unsigned long num = from_shape.size()/2;
            if (num == 1)
            {
                // Just use an identity transform if there is only one landmark.
                return point_transform_affine();
            }
            from_points.reserve(num);
            to_points.reserve(num);

            for (unsigned long i = 0; i < num; ++i)
            {
                from_points.push_back(location(from_shape,i));
                to_points.push_back(location(to_shape,i));
            }
            return find_similarity_transform(from_points, to_points);
        }

    // ------------------------------------------------------------------------------------

        inline point_transform_affine normalizing_tform (
            const rectangle& rect
        )
        /*!
            ensures
                - returns a transform that maps rect.tl_corner() to (0,0) and rect.br_corner()
                  to (1,1).
        !*/
        {
            std::vector<vector<float,2> > from_points, to_points;
            from_points.push_back(rect.tl_corner()); to_points.push_back(point(0,0));
            from_points.push_back(rect.tr_corner()); to_points.push_back(point(1,0));
            from_points.push_back(rect.br_corner()); to_points.push_back(point(1,1));
            return find_affine_transform(from_points, to_points);
        }

    // ------------------------------------------------------------------------------------

        inline point_transform_affine unnormalizing_tform (
            const rectangle& rect
        )
        /*!
            ensures
                - returns a transform that maps (0,0) to rect.tl_corner() and (1,1) to
                  rect.br_corner().
        !*/
        {
            std::vector<vector<float,2> > from_points, to_points;
            to_points.push_back(rect.tl_corner()); from_points.push_back(point(0,0));
            to_points.push_back(rect.tr_corner()); from_points.push_back(point(1,0));
            to_points.push_back(rect.br_corner()); from_points.push_back(point(1,1));
            return find_affine_transform(from_points, to_points);
        }
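
        // For example, with rect == rectangle(10,20,110,220), normalizing_tform(rect)
        // maps (10,20) to (0,0) and (110,220) to (1,1), and unnormalizing_tform(rect)
        // maps those points back.  All shapes are represented in this normalized
        // [0,1]x[0,1] coordinate system.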

    // ------------------------------------------------------------------------------------

        template <typename image_type, typename feature_type>
        void extract_feature_pixel_values (
            const image_type& img_,
            const rectangle& rect,
            const matrix<float,0,1>& current_shape,
            const matrix<float,0,1>& reference_shape,
            const std::vector<unsigned long>& reference_pixel_anchor_idx,
            const std::vector<dlib::vector<float,2> >& reference_pixel_deltas,
            std::vector<feature_type>& feature_pixel_values
        )
        /*!
            requires
                - image_type == an image object that implements the interface defined in
                  dlib/image_processing/generic_image.h 
                - reference_pixel_anchor_idx.size() == reference_pixel_deltas.size()
                - current_shape.size() == reference_shape.size()
                - reference_shape.size()%2 == 0
                - max(mat(reference_pixel_anchor_idx)) < reference_shape.size()/2
            ensures
                - #feature_pixel_values.size() == reference_pixel_deltas.size()
                - for all valid i:
                    - #feature_pixel_values[i] == the value of the pixel in img_ that
                      corresponds to the pixel identified by reference_pixel_anchor_idx[i]
                      and reference_pixel_deltas[i] when the pixel is located relative to
                      current_shape rather than reference_shape.
        !*/
        {
            const matrix<float,2,2> tform = matrix_cast<float>(find_tform_between_shapes(reference_shape, current_shape).get_m());
            const point_transform_affine tform_to_img = unnormalizing_tform(rect);

            const rectangle area = get_rect(img_);

            const_image_view<image_type> img(img_);
            feature_pixel_values.resize(reference_pixel_deltas.size());
            for (unsigned long i = 0; i < feature_pixel_values.size(); ++i)
            {
                // Compute the point in the current shape corresponding to the i-th pixel and
                // then map it from the normalized shape space into pixel space.
                point p = tform_to_img(tform*reference_pixel_deltas[i] + location(current_shape, reference_pixel_anchor_idx[i]));
                if (area.contains(p))
                    feature_pixel_values[i] = get_pixel_intensity(img[p.y()][p.x()]);
                else
                    feature_pixel_values[i] = 0;
            }
        }
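
        // In other words, each feature pixel value is sampled at
        //     tform_to_img( R*reference_pixel_deltas[i] + location(current_shape, reference_pixel_anchor_idx[i]) )
        // where R is the rotation and scaling that best maps reference_shape onto
        // current_shape, and pixels falling outside the image read as 0.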

    } // end namespace impl

// ----------------------------------------------------------------------------------------

    class shape_predictor
    {
    public:


        shape_predictor (
        ) 
        {}

        shape_predictor (
            const matrix<float,0,1>& initial_shape_,
            const std::vector<std::vector<impl::regression_tree> >& forests_,
            const std::vector<std::vector<dlib::vector<float,2> > >& pixel_coordinates
        ) : initial_shape(initial_shape_), forests(forests_)
        /*!
            requires
                - initial_shape.size()%2 == 0
                - forests.size() == pixel_coordinates.size() == the number of cascades
                - for all valid i:
                    - all the index values in forests[i] are less than pixel_coordinates[i].size()
                - for all valid i and j: 
                    - forests[i][j].leaf_values.size() is a power of 2.
                      (i.e. we require a tree with all the levels fully filled out.)
                    - forests[i][j].leaf_values.size() == forests[i][j].splits.size()+1
                      (i.e. there needs to be the right number of leaves given the number of splits in the tree)
        !*/
        {
            anchor_idx.resize(pixel_coordinates.size());
            deltas.resize(pixel_coordinates.size());
            // Each cascade uses a different set of pixels for its features.  We compute
            // their representations relative to the initial shape now and save them.
            for (unsigned long i = 0; i < pixel_coordinates.size(); ++i)
                impl::create_shape_relative_encoding(initial_shape, pixel_coordinates[i], anchor_idx[i], deltas[i]);
        }

        unsigned long num_parts (
        ) const
        {
            return initial_shape.size()/2;
        }

        unsigned long num_features (
        ) const
        {
            unsigned long num = 0;
            for (unsigned long iter = 0; iter < forests.size(); ++iter)
                for (unsigned long i = 0; i < forests[iter].size(); ++i)
                    num += forests[iter][i].num_leaves();
            return num;
        }

        template <typename image_type>
        full_object_detection operator()(
            const image_type& img,
            const rectangle& rect
        ) const
        {
            using namespace impl;
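            // current_shape starts out as the mean (initial) shape, expressed in the
            // normalized [0,1]x[0,1] coordinate system, and is refined by each level of
            // the cascade.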
            matrix<float,0,1> current_shape = initial_shape;
            std::vector<float> feature_pixel_values;
            for (unsigned long iter = 0; iter < forests.size(); ++iter)
            {
                extract_feature_pixel_values(img, rect, current_shape, initial_shape,
                                             anchor_idx[iter], deltas[iter], feature_pixel_values);
                unsigned long leaf_idx;
                // evaluate all the trees at this level of the cascade.
                for (unsigned long i = 0; i < forests[iter].size(); ++i)
                    current_shape += forests[iter][i](feature_pixel_values, leaf_idx);
            }

            // convert the current_shape into a full_object_detection
            const point_transform_affine tform_to_img = unnormalizing_tform(rect);
            std::vector<point> parts(current_shape.size()/2);
            for (unsigned long i = 0; i < parts.size(); ++i)
                parts[i] = tform_to_img(location(current_shape, i));
            return full_object_detection(rect, parts);
        }

        template <typename image_type, typename T, typename U>
        full_object_detection operator()(
            const image_type& img,
            const rectangle& rect,
            std::vector<std::pair<T,U> >& feats
        ) const
        {
            feats.clear();
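            // As the cascade runs we record which leaf each tree selects.  Taken
            // together, the (index, value) pairs pushed into feats form a sparse binary
            // feature vector describing the path the sample took through every tree.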
            using namespace impl;
            matrix<float,0,1> current_shape = initial_shape;
            std::vector<float> feature_pixel_values;
            unsigned long feat_offset = 0;
            for (unsigned long iter = 0; iter < forests.size(); ++iter)
            {
                extract_feature_pixel_values(img, rect, current_shape, initial_shape,
                                             anchor_idx[iter], deltas[iter], feature_pixel_values);
                // evaluate all the trees at this level of the cascade.
                for (unsigned long i = 0; i < forests[iter].size(); ++i)
                {
                    unsigned long leaf_idx;
                    current_shape += forests[iter][i](feature_pixel_values, leaf_idx);

                    feats.push_back(std::make_pair(feat_offset+leaf_idx, 1));
                    feat_offset += forests[iter][i].num_leaves();
                }
            }

            // convert the current_shape into a full_object_detection
            const point_transform_affine tform_to_img = unnormalizing_tform(rect);
            std::vector<point> parts(current_shape.size()/2);
            for (unsigned long i = 0; i < parts.size(); ++i)
                parts[i] = tform_to_img(location(current_shape, i));
            return full_object_detection(rect, parts);
        }

        friend void serialize (const shape_predictor& item, std::ostream& out);

        friend void deserialize (shape_predictor& item, std::istream& in);

    private:
        matrix<float,0,1> initial_shape;
        std::vector<std::vector<impl::regression_tree> > forests;
        std::vector<std::vector<unsigned long> > anchor_idx; 
        std::vector<std::vector<dlib::vector<float,2> > > deltas;
    };

    inline void serialize (const shape_predictor& item, std::ostream& out)
    {
        int version = 1;
        dlib::serialize(version, out);
        dlib::serialize(item.initial_shape, out);
        dlib::serialize(item.forests, out);
        dlib::serialize(item.anchor_idx, out);
        dlib::serialize(item.deltas, out);
    }

    inline void deserialize (shape_predictor& item, std::istream& in)
    {
        int version = 0;
        dlib::deserialize(version, in);
        if (version != 1)
            throw serialization_error("Unexpected version found while deserializing dlib::shape_predictor.");
        dlib::deserialize(item.initial_shape, in);
        dlib::deserialize(item.forests, in);
        dlib::deserialize(item.anchor_idx, in);
        dlib::deserialize(item.deltas, in);
    }
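
    /*
        A minimal usage sketch (hypothetical file names; assumes a model saved with the
        serialize() routine above, a face box from some object detector, and the image
        I/O tools from dlib/image_io.h):

            dlib::shape_predictor sp;
            std::ifstream fin("shape_predictor.dat", std::ios::binary);
            dlib::deserialize(sp, fin);

            dlib::array2d<unsigned char> img;
            dlib::load_image(img, "face.jpg");
            dlib::full_object_detection det = sp(img, dlib::rectangle(100,100,300,300));
            for (unsigned long k = 0; k < det.num_parts(); ++k)
                std::cout << "part " << k << ": " << det.part(k) << std::endl;
    */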
// ----------------------------------------------------------------------------------------

    class shape_predictor_trainer
    {
        /*!
            This trainer really only works with unsigned char or rgb_pixel images (since
            we assume the pixel difference features, and therefore the split thresholds,
            fall in the range [-128,128]).
        !*/
    public:

        shape_predictor_trainer (
        )
        {
            _cascade_depth = 10;
            _tree_depth = 4;
            _num_trees_per_cascade_level = 500;
            _nu = 0.1;
            _oversampling_amount = 20;
            _feature_pool_size = 400;
            _lambda = 0.1;
            _num_test_splits = 20;
            _feature_pool_region_padding = 0;
            _verbose = false;
            _num_threads = 0;
        }

        unsigned long get_cascade_depth (
        ) const { return _cascade_depth; }

        void set_cascade_depth (
            unsigned long depth
        )
        {
            DLIB_CASSERT(depth > 0, 
                "\t void shape_predictor_trainer::set_cascade_depth()"
                << "\n\t Invalid inputs were given to this function. "
                << "\n\t depth:  " << depth
            );

            _cascade_depth = depth;
        }

        unsigned long get_tree_depth (
        ) const { return _tree_depth; }

        void set_tree_depth (
            unsigned long depth
        )
        {
            DLIB_CASSERT(depth > 0, 
                "\t void shape_predictor_trainer::set_tree_depth()"
                << "\n\t Invalid inputs were given to this function. "
                << "\n\t depth:  " << depth
            );

            _tree_depth = depth;
        }

        unsigned long get_num_trees_per_cascade_level (
        ) const { return _num_trees_per_cascade_level; }

        void set_num_trees_per_cascade_level (
            unsigned long num
        )
        {
            DLIB_CASSERT( num > 0,
                "\t void shape_predictor_trainer::set_num_trees_per_cascade_level()"
                << "\n\t Invalid inputs were given to this function. "
                << "\n\t num:  " << num
            );
            _num_trees_per_cascade_level = num;
        }

        double get_nu (
        ) const { return _nu; } 
        void set_nu (
            double nu
        )
        {
            DLIB_CASSERT(0 < nu && nu <= 1,
                "\t void shape_predictor_trainer::set_nu()"
                << "\n\t Invalid inputs were given to this function. "
                << "\n\t nu:  " << nu 
            );

            _nu = nu;
        }

        std::string get_random_seed (
        ) const { return rnd.get_seed(); }
        void set_random_seed (
            const std::string& seed
        ) { rnd.set_seed(seed); }

        unsigned long get_oversampling_amount (
        ) const { return _oversampling_amount; }
        void set_oversampling_amount (
            unsigned long amount
        )
        {
            DLIB_CASSERT(amount > 0, 
                "\t void shape_predictor_trainer::set_oversampling_amount()"
                << "\n\t Invalid inputs were given to this function. "
                << "\n\t amount: " << amount 
            );

            _oversampling_amount = amount;
        }

        unsigned long get_feature_pool_size (
        ) const { return _feature_pool_size; }
        void set_feature_pool_size (
            unsigned long size
        ) 
        {
            DLIB_CASSERT(size > 1, 
                "\t void shape_predictor_trainer::set_feature_pool_size()"
                << "\n\t Invalid inputs were given to this function. "
                << "\n\t size: " << size 
            );

            _feature_pool_size = size;
        }

        double get_lambda (
        ) const { return _lambda; }
        void set_lambda (
            double lambda
        )
        {
            DLIB_CASSERT(lambda > 0,
                "\t void shape_predictor_trainer::set_lambda()"
                << "\n\t Invalid inputs were given to this function. "
                << "\n\t lambda: " << lambda 
            );

            _lambda = lambda;
        }

        unsigned long get_num_test_splits (
        ) const { return _num_test_splits; }
        void set_num_test_splits (
            unsigned long num
        )
        {
            DLIB_CASSERT(num > 0, 
                "\t void shape_predictor_trainer::set_num_test_splits()"
                << "\n\t Invalid inputs were given to this function. "
                << "\n\t num: " << num 
            );

            _num_test_splits = num;
        }


        double get_feature_pool_region_padding (
        ) const { return _feature_pool_region_padding; }
        void set_feature_pool_region_padding (
            double padding 
        )
        {
            _feature_pool_region_padding = padding;
        }

        void be_verbose (
        )
        {
            _verbose = true;
        }

        void be_quiet (
        )
        {
            _verbose = false;
        }

        unsigned long get_num_threads (
        ) const { return _num_threads; }
        void set_num_threads (
                unsigned long num
        )
        {
            _num_threads = num;
        }

        template <typename image_array>
        shape_predictor train (
            const image_array& images,
            const std::vector<std::vector<full_object_detection> >& objects
        ) const
        {
            using namespace impl;
            DLIB_CASSERT(images.size() == objects.size() && images.size() > 0,
                "\t shape_predictor shape_predictor_trainer::train()"
                << "\n\t Invalid inputs were given to this function. "
                << "\n\t images.size():  " << images.size() 
                << "\n\t objects.size(): " << objects.size() 
            );
            // make sure the objects agree on the number of parts and that there is at
            // least one full_object_detection. 
            unsigned long num_parts = 0;
            std::vector<int> part_present;
            for (unsigned long i = 0; i < objects.size(); ++i)
            {
                for (unsigned long j = 0; j < objects[i].size(); ++j)
                {
                    if (num_parts == 0)
                    {
                        num_parts = objects[i][j].num_parts();
                        DLIB_CASSERT(objects[i][j].num_parts() != 0,
                            "\t shape_predictor shape_predictor_trainer::train()"
                            << "\n\t You can't give objects that don't have any parts to the trainer."
                        );
                        part_present.resize(num_parts);
                    }
                    else
                    {
                        DLIB_CASSERT(objects[i][j].num_parts() == num_parts,
                            "\t shape_predictor shape_predictor_trainer::train()"
                            << "\n\t All the objects must agree on the number of parts. "
                            << "\n\t objects["<<i<<"]["<<j<<"].num_parts(): " << objects[i][j].num_parts()
                            << "\n\t num_parts:  " << num_parts 
                        );
                    }
                    for (unsigned long p = 0; p < objects[i][j].num_parts(); ++p)
                    {
                        if (objects[i][j].part(p) != OBJECT_PART_NOT_PRESENT)
                            part_present[p] = 1;
                    }
                }
            }
            DLIB_CASSERT(num_parts != 0,
                "\t shape_predictor shape_predictor_trainer::train()"
                << "\n\t You must give at least one full_object_detection if you want to train a shape model and it must have parts."
            );
            DLIB_CASSERT(sum(mat(part_present)) == (long)num_parts,
                "\t shape_predictor shape_predictor_trainer::train()"
                << "\n\t Each part must appear at least once in this training data.  That is, "
                << "\n\t you can't have a part that is always set to OBJECT_PART_NOT_PRESENT."
            );

            // Create the thread pool.  If _num_threads <= 1 then all the work happens
            // in the calling thread.
            thread_pool tp(_num_threads > 1 ? _num_threads : 0);

            // Determine the feature type appropriate for this kind of image.
            typedef typename std::remove_const<typename std::remove_reference<decltype(images[0])>::type>::type image_type;
            typedef typename image_traits<image_type>::pixel_type pixel_type;
            typedef typename pixel_traits<pixel_type>::basic_pixel_type feature_type;

            rnd.set_seed(get_random_seed());

            std::vector<training_sample<feature_type>> samples;
            const matrix<float,0,1> initial_shape = populate_training_sample_shapes(objects, samples);
            const std::vector<std::vector<dlib::vector<float,2> > > pixel_coordinates = randomly_sample_pixel_coordinates(initial_shape);

            unsigned long trees_fit_so_far = 0;
            console_progress_indicator pbar(get_cascade_depth()*get_num_trees_per_cascade_level());
            if (_verbose)
                std::cout << "Fitting trees..." << std::endl;

            std::vector<std::vector<impl::regression_tree> > forests(get_cascade_depth());
            // Now start doing the actual training by filling in the forests
            for (unsigned long cascade = 0; cascade < get_cascade_depth(); ++cascade)
            {
                // Each cascade uses a different set of pixels for its features.  We compute
                // their representations relative to the initial shape first.
                std::vector<unsigned long> anchor_idx; 
                std::vector<dlib::vector<float,2> > deltas;
                create_shape_relative_encoding(initial_shape, pixel_coordinates[cascade], anchor_idx, deltas);

                // First compute the feature_pixel_values for each training sample at this
                // level of the cascade.
                parallel_for(tp, 0, samples.size(), [&](unsigned long i)
                {
                    impl::extract_feature_pixel_values(images[samples[i].image_idx], samples[i].rect,
                                                 samples[i].current_shape, initial_shape, anchor_idx,
                                                 deltas, samples[i].feature_pixel_values);
                }, 1);

                // Now start building the trees at this cascade level.
                for (unsigned long i = 0; i < get_num_trees_per_cascade_level(); ++i)
                {
                    forests[cascade].push_back(make_regression_tree(tp, samples, pixel_coordinates[cascade]));

                    if (_verbose)
                    {
                        ++trees_fit_so_far;
                        pbar.print_status(trees_fit_so_far);
                    }
                }
            }

            if (_verbose)
                std::cout << "Training complete                          " << std::endl;

            return shape_predictor(initial_shape, forests, pixel_coordinates);
        }
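
        /*
            A minimal training sketch (hypothetical dataset; assumes the images and
            their landmark annotations were loaded elsewhere, e.g. with the dataset
            tools in dlib/data_io.h):

                dlib::array<dlib::array2d<unsigned char> > images;
                std::vector<std::vector<dlib::full_object_detection> > objects;
                // ... fill images and objects ...

                shape_predictor_trainer trainer;
                trainer.set_num_threads(4);
                trainer.be_verbose();
                shape_predictor sp = trainer.train(images, objects);
        */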

    private:

        static void object_to_shape (
            const full_object_detection& obj,
            matrix<float,0,1>& shape,
            matrix<float,0,1>& present // a mask telling which elements of #shape are present.
        )
        {
            shape.set_size(obj.num_parts()*2);
            present.set_size(obj.num_parts()*2);
            const point_transform_affine tform_from_img = impl::normalizing_tform(obj.get_rect());
            for (unsigned long i = 0; i < obj.num_parts(); ++i)
            {
                if (obj.part(i) != OBJECT_PART_NOT_PRESENT)
                {
                    vector<float,2> p = tform_from_img(obj.part(i));
                    shape(2*i)   = p.x();
                    shape(2*i+1) = p.y();
                    present(2*i)   = 1;
                    present(2*i+1) = 1;
                }
                else
                {
                    shape(2*i)   = 0;
                    shape(2*i+1) = 0;
                    present(2*i)   = 0;
                    present(2*i+1) = 0;
                }
            }
        }

        template<typename feature_type>
        struct training_sample
        {
            /*!

            CONVENTION
                - feature_pixel_values.size() == get_feature_pool_size()
                - feature_pixel_values[j] == the value of the j-th feature pool
                  pixel when you look it up relative to the shape in current_shape.

                - target_shape == The truth shape.  Stays constant during the whole
                  training process (except for the parts that are not present, those are
                  always equal to the current_shape values).
                - present == 0/1 mask saying which parts of target_shape are present.
                - rect == the position of the object in the image_idx-th image.  All shape
                  coordinates are coded relative to this rectangle.
                - diff_shape == temporary value for holding difference between current
                  shape and target shape
            !*/

            unsigned long image_idx;
            rectangle rect;
            matrix<float,0,1> target_shape;
            matrix<float,0,1> present;

            matrix<float,0,1> current_shape;
            matrix<float,0,1> diff_shape;
            std::vector<feature_type> feature_pixel_values;

            void swap(training_sample& item)
            {
                std::swap(image_idx, item.image_idx);
                std::swap(rect, item.rect);
                target_shape.swap(item.target_shape);
                present.swap(item.present);
                current_shape.swap(item.current_shape);
                diff_shape.swap(item.diff_shape);
                feature_pixel_values.swap(item.feature_pixel_values);
            }
        };

        template<typename feature_type>
        impl::regression_tree make_regression_tree (
            thread_pool& tp,
            std::vector<training_sample<feature_type>>& samples,
            const std::vector<dlib::vector<float,2> >& pixel_coordinates
        ) const
        {
            using namespace impl;
            std::deque<std::pair<unsigned long, unsigned long> > parts;
            parts.push_back(std::make_pair(0, (unsigned long)samples.size()));

            impl::regression_tree tree;

            // walk the tree in breadth first order
            const unsigned long num_split_nodes = static_cast<unsigned long>(std::pow(2.0, (double)get_tree_depth())-1);
            std::vector<matrix<float,0,1> > sums(num_split_nodes*2+1);
            if (tp.num_threads_in_pool() > 1)
            {
                // We need to compute each sample's shape difference and accumulate the
                // sum of all the differences into sums[0].  To do this in parallel we
                // split the samples into blocks, process each block on a separate
                // thread, and store each block's partial sum in block_sums.

                const unsigned long num_workers = std::max(1UL, tp.num_threads_in_pool());
                const unsigned long num =  samples.size();
                const unsigned long block_size = std::max(1UL, (num + num_workers - 1) / num_workers);
                std::vector<matrix<float,0,1> > block_sums(num_workers);

                parallel_for(tp, 0, num_workers, [&](unsigned long block)
                {
                    const unsigned long block_begin = block * block_size;
                    const unsigned long block_end =  std::min(num, block_begin + block_size);
                    for (unsigned long i = block_begin; i < block_end; ++i)
                    {
                        samples[i].diff_shape = samples[i].target_shape - samples[i].current_shape;
                        block_sums[block] += samples[i].diff_shape;
                    }
                }, 1);

                // now calculate the total result from separate blocks
                for (unsigned long i = 0; i < block_sums.size(); ++i)
                    sums[0] += block_sums[i];
            }
            else
            {
                // synchronous implementation
                for (unsigned long i = 0; i < samples.size(); ++i)
                {
                    samples[i].diff_shape = samples[i].target_shape - samples[i].current_shape;
                    sums[0] += samples[i].diff_shape;
                }
            }

            for (unsigned long i = 0; i < num_split_nodes; ++i)
            {
                std::pair<unsigned long,unsigned long> range = parts.front();
                parts.pop_front();

                const impl::split_feature split = generate_split(tp, samples, range.first,
                    range.second, pixel_coordinates, sums[i], sums[left_child(i)],
                    sums[right_child(i)]);
                tree.splits.push_back(split);
                const unsigned long mid = partition_samples(split, samples, range.first, range.second);

                parts.push_back(std::make_pair(range.first, mid));
                parts.push_back(std::make_pair(mid, range.second));
            }

            // Now all the parts contain the ranges for the leaves so we can use them to
            // compute the average leaf values.
            matrix<float,0,1> present_counts(samples[0].target_shape.size());
            tree.leaf_values.resize(parts.size());
            for (unsigned long i = 0; i < parts.size(); ++i)
            {
                // Get the present counts for each dimension so we can divide each
                // dimension by the number of observations we have on it to find the mean
                // displacement in each leaf.
                present_counts = 0;
                for (unsigned long j = parts[i].first; j < parts[i].second; ++j)
                    present_counts += samples[j].present;
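                // Note that dlib::reciprocal() maps 0 to 0, so any dimension that was
                // never observed in this leaf ends up with a 0 displacement below.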
                present_counts = dlib::reciprocal(present_counts);

                if (parts[i].second != parts[i].first)
                    tree.leaf_values[i] = pointwise_multiply(present_counts,sums[num_split_nodes+i]*get_nu());
                else
                    tree.leaf_values[i] = zeros_matrix(samples[0].target_shape);

                // now adjust the current shape based on these predictions
                parallel_for(tp, parts[i].first, parts[i].second, [&](unsigned long j)
                {
                    samples[j].current_shape += tree.leaf_values[i];
                    // For parts that aren't present in the training data, we just make
                    // sure that the target shape always matches and therefore gives zero
                    // error.  So this makes the algorithm simply ignore non-present
                    // landmarks.
                    for (long k = 0; k < samples[j].present.size(); ++k)
                    {
                        // if this part is not present
                        if (samples[j].present(k) == 0)
                            samples[j].target_shape(k) = samples[j].current_shape(k);
                    }
                }, 1);
            }

            return tree;
        }
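
        /*
            Note that each leaf stores get_nu() times the mean residual
            (target_shape - current_shape) of the training samples landing in it,
            counting only observed landmarks.  That is, the tree fitting above is a
            gradient boosting step with shrinkage factor nu: smaller values of nu slow
            down learning and reduce overfitting, at the cost of requiring more trees.
        */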

        impl::split_feature randomly_generate_split_feature (
            const std::vector<dlib::vector<float,2> >& pixel_coordinates
        ) const
        {
            const double lambda = get_lambda(); 
            impl::split_feature feat;
            double accept_prob;
            do 
            {
                feat.idx1   = rnd.get_random_32bit_number()%get_feature_pool_size();
                feat.idx2   = rnd.get_random_32bit_number()%get_feature_pool_size();
                const double dist = length(pixel_coordinates[feat.idx1]-pixel_coordinates[feat.idx2]);
                accept_prob = std::exp(-dist/lambda);
            }
            while(feat.idx1 == feat.idx2 || !(accept_prob > rnd.get_random_double()));

            feat.thresh = (rnd.get_random_double()*256 - 128)/2.0;

            return feat;
        }
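
        /*
            The exp(-dist/lambda) acceptance test above makes nearby pixel pairs more
            likely to be chosen as split features.  For example, with the default
            lambda of 0.1, two pixels 0.05 apart (in the normalized shape coordinate
            system) are accepted with probability exp(-0.5) ~= 0.61, while two pixels
            0.5 apart are accepted with probability exp(-5) ~= 0.007.
        */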

        template<typename feature_type>
        impl::split_feature generate_split (
            thread_pool& tp,
            const std::vector<training_sample<feature_type>>& samples,
            unsigned long begin,
            unsigned long end,
            const std::vector<dlib::vector<float,2> >& pixel_coordinates,
            const matrix<float,0,1>& sum,
            matrix<float,0,1>& left_sum,
            matrix<float,0,1>& right_sum 
        ) const
        {
            // generate a bunch of random splits and test them and return the best one.

            const unsigned long num_test_splits = get_num_test_splits();  

            // sample the random features we test in this function
            std::vector<impl::split_feature> feats;
            feats.reserve(num_test_splits);
            for (unsigned long i = 0; i < num_test_splits; ++i)
                feats.push_back(randomly_generate_split_feature(pixel_coordinates));

            std::vector<matrix<float,0,1> > left_sums(num_test_splits);
            std::vector<unsigned long> left_cnt(num_test_splits);

            const unsigned long num_workers = std::max(1UL, tp.num_threads_in_pool());
            const unsigned long block_size = std::max(1UL, (num_test_splits + num_workers - 1) / num_workers);

            // now compute the sums of vectors that go left for each feature
            parallel_for(tp, 0, num_workers, [&](unsigned long block)
            {
                const unsigned long block_begin = block * block_size;
                const unsigned long block_end   = std::min(block_begin + block_size, num_test_splits);

                for (unsigned long j = begin; j < end; ++j)
                {
                    for (unsigned long i = block_begin; i < block_end; ++i)
                    {
                        if ((float)samples[j].feature_pixel_values[feats[i].idx1] - (float)samples[j].feature_pixel_values[feats[i].idx2] > feats[i].thresh)
                        {
                            left_sums[i] += samples[j].diff_shape;
                            ++left_cnt[i];
                        }
                    }
                }

            }, 1);
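
            // A split's quality is the reduction in squared error we get from fitting a
            // separate constant to each side.  Up to terms that don't depend on the
            // split, this reduction equals
            //     ||left_sum||^2/left_cnt + ||right_sum||^2/right_cnt
            // which is the score computed below (with right_sum == sum - left_sum).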

            // now figure out which feature is the best
            double best_score = -1;
            unsigned long best_feat = 0;
            matrix<float,0,1> temp;
            for (unsigned long i = 0; i < num_test_splits; ++i)
            {
                // check how well the feature splits the space.
                double score = 0;
                unsigned long right_cnt = end-begin-left_cnt[i];
                if (left_cnt[i] != 0 && right_cnt != 0)
                {
                    temp = sum - left_sums[i];
                    score = dot(left_sums[i],left_sums[i])/left_cnt[i] + dot(temp,temp)/right_cnt;
                    if (score > best_score)
                    {
                        best_score = score;
                        best_feat = i;
                    }
                }
            }

            left_sums[best_feat].swap(left_sum);
            if (left_sum.size() != 0)
            {
                right_sum = sum - left_sum;
            }
            else
            {
                right_sum = sum;
                left_sum = zeros_matrix(sum);
            }
            return feats[best_feat];
        }

        template<typename feature_type>
        unsigned long partition_samples (
            const impl::split_feature& split,
            std::vector<training_sample<feature_type>>& samples,
            unsigned long begin,
            unsigned long end
        ) const
        {
            // Partition the samples in [begin, end) based on split (like the partition
            // step of quicksort) and return the index of the first sample that goes to
            // the right child.  The returned midpoint must be consistent with how
            // make_regression_tree() walks through the tree.

            unsigned long i = begin;
            for (unsigned long j = begin; j < end; ++j)
            {
                if ((float)samples[j].feature_pixel_values[split.idx1] - (float)samples[j].feature_pixel_values[split.idx2] > split.thresh)
                {
                    samples[i].swap(samples[j]);
                    ++i;
                }
            }
            return i;
        }



        template<typename feature_type>
        matrix<float,0,1> populate_training_sample_shapes(
            const std::vector<std::vector<full_object_detection> >& objects,
            std::vector<training_sample<feature_type>>& samples
        ) const
        {
            samples.clear();
            matrix<float,0,1> mean_shape;
            matrix<float,0,1> count;
            // first fill out the target shapes
            for (unsigned long i = 0; i < objects.size(); ++i)
            {
                for (unsigned long j = 0; j < objects[i].size(); ++j)
                {
                    training_sample<feature_type> sample;
                    sample.image_idx = i;
                    sample.rect = objects[i][j].get_rect();
                    object_to_shape(objects[i][j], sample.target_shape, sample.present);
                    for (unsigned long itr = 0; itr < get_oversampling_amount(); ++itr)
                        samples.push_back(sample);
                    mean_shape += sample.target_shape;
                    count += sample.present;
                }
            }

            mean_shape = pointwise_multiply(mean_shape,reciprocal(count));

            // now go pick random initial shapes
            for (unsigned long i = 0; i < samples.size(); ++i)
            {
                if ((i%get_oversampling_amount()) == 0)
                {
                    // The mean shape is what we really use as an initial shape so always
                    // include it in the training set as an example starting shape.
                    samples[i].current_shape = mean_shape;
                }
                else
                {
                    samples[i].current_shape.set_size(0);

                    matrix<float,0,1> hits(mean_shape.size());
                    hits = 0;

                    int iter = 0;
                    // Pick a few samples at random and randomly average them together to
                    // make the initial shape.  Note that we make sure we get at least one
                    // observation (i.e. non-OBJECT_PART_NOT_PRESENT) on each part
                    // location.
                    while(min(hits) == 0 || iter < 2)
                    {
                        ++iter;
                        const unsigned long rand_idx = rnd.get_random_32bit_number()%samples.size();
                        const double alpha = rnd.get_random_double()+0.1;
                        samples[i].current_shape += alpha*samples[rand_idx].target_shape;
                        hits += alpha*samples[rand_idx].present;
                    }
                    samples[i].current_shape = pointwise_multiply(samples[i].current_shape, reciprocal(hits));
                }

            }
            for (unsigned long i = 0; i < samples.size(); ++i)
            {
                for (long k = 0; k < samples[i].present.size(); ++k)
                {
                    // if this part is not present
                    if (samples[i].present(k) == 0)
                        samples[i].target_shape(k) = samples[i].current_shape(k);
                }
            }


            return mean_shape;
        }


        void randomly_sample_pixel_coordinates (
            std::vector<dlib::vector<float,2> >& pixel_coordinates,
            const double min_x,
            const double min_y,
            const double max_x,
            const double max_y
        ) const
        /*!
            ensures
                - #pixel_coordinates.size() == get_feature_pool_size() 
                - for all valid i:
                    - #pixel_coordinates[i] == a random point in the box defined by the min/max x/y arguments.
        !*/
        {
            pixel_coordinates.resize(get_feature_pool_size());
            for (unsigned long i = 0; i < get_feature_pool_size(); ++i)
            {
                pixel_coordinates[i].x() = rnd.get_random_double()*(max_x-min_x) + min_x;
                pixel_coordinates[i].y() = rnd.get_random_double()*(max_y-min_y) + min_y;
            }
        }

        std::vector<std::vector<dlib::vector<float,2> > > randomly_sample_pixel_coordinates (
            const matrix<float,0,1>& initial_shape
        ) const
        {
            const double padding = get_feature_pool_region_padding();
            // Figure out the bounds on the object shapes.  We will sample uniformly
            // from this box.
            matrix<float> temp = reshape(initial_shape, initial_shape.size()/2, 2);
            const double min_x = min(colm(temp,0))-padding;
            const double min_y = min(colm(temp,1))-padding;
            const double max_x = max(colm(temp,0))+padding;
            const double max_y = max(colm(temp,1))+padding;

            std::vector<std::vector<dlib::vector<float,2> > > pixel_coordinates;
            pixel_coordinates.resize(get_cascade_depth());
            for (unsigned long i = 0; i < get_cascade_depth(); ++i)
                randomly_sample_pixel_coordinates(pixel_coordinates[i], min_x, min_y, max_x, max_y);
            return pixel_coordinates;
        }



        mutable dlib::rand rnd;

        unsigned long _cascade_depth;
        unsigned long _tree_depth;
        unsigned long _num_trees_per_cascade_level;
        double _nu;
        unsigned long _oversampling_amount;
        unsigned long _feature_pool_size;
        double _lambda;
        unsigned long _num_test_splits;
        double _feature_pool_region_padding;
        bool _verbose;
        unsigned long _num_threads;
    };

// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------

    template <
        typename image_array
        >
    double test_shape_predictor (
        const shape_predictor& sp,
        const image_array& images,
        const std::vector<std::vector<full_object_detection> >& objects,
        const std::vector<std::vector<double> >& scales
    )
    {
        // make sure requires clause is not broken
#ifdef ENABLE_ASSERTS
        DLIB_CASSERT( images.size() == objects.size() ,
            "\t double test_shape_predictor()"
            << "\n\t Invalid inputs were given to this function. "
            << "\n\t images.size():  " << images.size() 
            << "\n\t objects.size(): " << objects.size() 
        );
        for (unsigned long i = 0; i < objects.size(); ++i)
        {
            for (unsigned long j = 0; j < objects[i].size(); ++j)
            {
                DLIB_CASSERT(objects[i][j].num_parts() == sp.num_parts(), 
                    "\t double test_shape_predictor()"
                    << "\n\t Invalid inputs were given to this function. "
                    << "\n\t objects["<<i<<"]["<<j<<"].num_parts(): " << objects[i][j].num_parts()
                    << "\n\t sp.num_parts(): " << sp.num_parts()
                );
            }
            if (scales.size() != 0)
            {
                DLIB_CASSERT(objects[i].size() == scales[i].size(), 
                    "\t double test_shape_predictor()"
                    << "\n\t Invalid inputs were given to this function. "
                    << "\n\t objects["<<i<<"].size(): " << objects[i].size()
                    << "\n\t scales["<<i<<"].size(): " << scales[i].size()
                );

            }
        }
#endif

        running_stats<double> rs;
        for (unsigned long i = 0; i < objects.size(); ++i)
        {
            for (unsigned long j = 0; j < objects[i].size(); ++j)
            {
                // Just use a scale of 1 (i.e. no scale at all) if the caller didn't supply
                // any scales.
                const double scale = scales.size()==0 ? 1 : scales[i][j]; 

                full_object_detection det = sp(images[i], objects[i][j].get_rect());

                for (unsigned long k = 0; k < det.num_parts(); ++k)
                {
                    if (objects[i][j].part(k) != OBJECT_PART_NOT_PRESENT)
                    {
                        double score = length(det.part(k) - objects[i][j].part(k))/scale;
                        rs.add(score);
                    }
                }
            }
        }
        return rs.mean();
    }

// ----------------------------------------------------------------------------------------

    template <
        typename image_array
        >
    double test_shape_predictor (
        const shape_predictor& sp,
        const image_array& images,
        const std::vector<std::vector<full_object_detection> >& objects
    )
    {
        std::vector<std::vector<double> > no_scales;
        return test_shape_predictor(sp, images, objects, no_scales);
    }
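
    /*
        A minimal testing sketch (hypothetical data; the scales would typically hold a
        per-object normalizer such as the interocular distance, making the returned
        value a normalized mean landmark error):

            std::vector<std::vector<double> > scales;
            // ... fill scales[i][j] with, e.g., the interocular distance of objects[i][j] ...
            double err = test_shape_predictor(sp, images, objects, scales);
            std::cout << "mean normalized error: " << err << std::endl;
    */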

// ----------------------------------------------------------------------------------------

}

#endif // DLIB_SHAPE_PREDICToR_H_