// Copyright (C) 2013  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_SEQUENCE_SeGMENTER_H_h_
#define DLIB_SEQUENCE_SeGMENTER_H_h_

#include "sequence_segmenter_abstract.h"
#include "../matrix.h"
#include "sequence_labeler.h"
#include <vector>

namespace dlib
{
    // This namespace contains implementation details for the sequence_segmenter.
    namespace impl_ss
    {

    // ------------------------------------------------------------------------------------

        // BIO/BILOU labels
        const unsigned int BEGIN   = 0;
        const unsigned int INSIDE  = 1;
        const unsigned int OUTSIDE = 2;
        const unsigned int LAST    = 3;
        const unsigned int UNIT    = 4;


    // ------------------------------------------------------------------------------------

        template <typename ss_feature_extractor>
        class feature_extractor
        {
            /*!
                WHAT THIS OBJECT REPRESENTS
                    This is a feature extractor for a sequence_labeler.  It serves to map
                    the interface defined by a sequence_labeler into the kind of interface
                    defined for a sequence_segmenter.
            !*/

        public:
            typedef typename ss_feature_extractor::sequence_type sequence_type;

            ss_feature_extractor fe;

            feature_extractor() {}
            feature_extractor(const ss_feature_extractor& ss_fe_) : fe(ss_fe_) {}

            unsigned long num_nonnegative_weights (
            ) const
            {
                const unsigned long NL = ss_feature_extractor::use_BIO_model ? 3 : 5;
                if (ss_feature_extractor::allow_negative_weights)
                {
                    return 0;
                }
                else
                {
                    // We make everything non-negative except for the label transition
                    // and bias features.
                    return num_features() - NL*NL - NL;
                }
            }

            friend void serialize(const feature_extractor& item, std::ostream& out) 
            {
                serialize(item.fe, out);
            }

            friend void deserialize(feature_extractor& item, std::istream& in) 
            {
                deserialize(item.fe, in);
            }

            unsigned long num_features() const
            {
                const unsigned long NL = ss_feature_extractor::use_BIO_model ? 3 : 5;
                if (ss_feature_extractor::use_high_order_features)
                    return NL + NL*NL + (NL*NL+NL)*fe.num_features()*fe.window_size();
                else
                    return NL + NL*NL + NL*fe.num_features()*fe.window_size();
            }

            unsigned long order() const 
            { 
                return 1; 
            }

            unsigned long num_labels() const 
            { 
                if (ss_feature_extractor::use_BIO_model)
                    return 3;
                else
                    return 5;
            }

        private:

            template <typename feature_setter>
            struct dot_functor
            {
                /*!
                    WHAT THIS OBJECT REPRESENTS
                        This class wraps the feature_setter used by a sequence_labeler
                        and turns it into the kind needed by a sequence_segmenter.
                !*/

                dot_functor(feature_setter& set_feature_, unsigned long offset_) : 
                    set_feature(set_feature_), offset(offset_) {}

                feature_setter& set_feature;
                unsigned long offset;

                inline void operator() (
                    unsigned long feat_index
                )
                {
                    set_feature(offset+feat_index);
                }

                inline void operator() (
                    unsigned long feat_index,
                    double feat_value
                )
                {
                    set_feature(offset+feat_index, feat_value);
                }
            };

        public:

            template <typename EXP>
            bool reject_labeling (
                const sequence_type& x,
                const matrix_exp<EXP>& y,
                unsigned long pos
            ) const
            {
                if (ss_feature_extractor::use_BIO_model)
                {
                    // Don't allow BIO label patterns that don't correspond to a sensical
                    // segmentation. 
                    if (y.size() > 1 && y(0) == INSIDE && y(1) == OUTSIDE)
                        return true;
                    if (y.size() == 1 && y(0) == INSIDE)
                        return true;
                }
                else
                {
                    // Don't allow BILOU label patterns that don't correspond to a sensical
                    // segmentation. 
                    if (y.size() > 1)
                    {
                        if (y(1) == BEGIN && y(0) == OUTSIDE)
                            return true;
                        if (y(1) == BEGIN && y(0) == UNIT)
                            return true;
                        if (y(1) == BEGIN && y(0) == BEGIN)
                            return true;

                        if (y(1) == INSIDE && y(0) == BEGIN)
                            return true;
                        if (y(1) == INSIDE && y(0) == OUTSIDE)
                            return true;
                        if (y(1) == INSIDE && y(0) == UNIT)
                            return true;

                        if (y(1) == OUTSIDE && y(0) == INSIDE)
                            return true;
                        if (y(1) == OUTSIDE && y(0) == LAST)
                            return true;

                        if (y(1) == LAST && y(0) == INSIDE)
                            return true;
                        if (y(1) == LAST && y(0) == LAST)
                            return true;

                        if (y(1) == UNIT && y(0) == INSIDE)
                            return true;
                        if (y(1) == UNIT && y(0) == LAST)
                            return true;

                        // if at the end of the sequence
                        if (pos == x.size()-1)
                        {
                            if (y(0) == BEGIN)
                                return true;
                            if (y(0) == INSIDE)
                                return true;
                        }
                    }
                    else
                    {
                        if (y(0) == INSIDE)
                            return true;
                        if (y(0) == LAST)
                            return true;

                        // if at the end of the sequence
                        if (pos == x.size()-1)
                        {
                            if (y(0) == BEGIN)
                                return true;
                        }
                    }
                }
                return false;
            }

            template <typename feature_setter, typename EXP>
            void get_features (
                feature_setter& set_feature,
                const sequence_type& x,
                const matrix_exp<EXP>& y,
                unsigned long position
            ) const
            {
                unsigned long offset = 0;

                const int window_size = fe.window_size();

                const int base_dims = fe.num_features();
                for (int i = 0; i < window_size; ++i)
                {
                    const long pos = i-window_size/2 + static_cast<long>(position);
                    if (0 <= pos && pos < (long)x.size())
                    {
                        const unsigned long off1 = y(0)*base_dims;
                        dot_functor<feature_setter> fs1(set_feature, offset+off1);
                        fe.get_features(fs1, x, pos);

                        if (ss_feature_extractor::use_high_order_features && y.size() > 1)
                        {
                            const unsigned long off2 = num_labels()*base_dims + (y(0)*num_labels()+y(1))*base_dims;
                            dot_functor<feature_setter> fs2(set_feature, offset+off2);
                            fe.get_features(fs2, x, pos);
                        }
                    }

                    if (ss_feature_extractor::use_high_order_features)
                        offset += num_labels()*base_dims + num_labels()*num_labels()*base_dims;
                    else
                        offset += num_labels()*base_dims;
                }

                // Pull out an indicator feature for the type of transition between the
                // previous label and the current label.
                if (y.size() > 1)
                    set_feature(offset + y(1)*num_labels() + y(0));

                offset += num_labels()*num_labels();
                // pull out an indicator feature for the current label.  This is the per
                // label bias.
                set_feature(offset + y(0));
            }
        };

    } // end namespace impl_ss

// ----------------------------------------------------------------------------------------

    template <
        typename feature_extractor
        >
    unsigned long total_feature_vector_size (
        const feature_extractor& fe
    )
    {
        const unsigned long NL = feature_extractor::use_BIO_model ? 3 : 5;
        if (feature_extractor::use_high_order_features)
            return NL + NL*NL + (NL*NL+NL)*fe.num_features()*fe.window_size();
        else
            return NL + NL*NL + NL*fe.num_features()*fe.window_size();
    }

// ----------------------------------------------------------------------------------------

    template <
        typename feature_extractor
        >
    class sequence_segmenter
    {
    public:
        typedef typename feature_extractor::sequence_type sample_sequence_type;
        typedef std::vector<std::pair<unsigned long, unsigned long> > segmented_sequence_type;


        sequence_segmenter()
        {
#ifdef ENABLE_ASSERTS
            const feature_extractor& fe = labeler.get_feature_extractor().fe;
            DLIB_ASSERT(fe.window_size() >= 1 && fe.num_features() >= 1,
                "\t sequence_segmenter::sequence_segmenter()"
                << "\n\t An invalid feature extractor was supplied."
                << "\n\t fe.window_size():  " << fe.window_size() 
                << "\n\t fe.num_features(): " << fe.num_features() 
                << "\n\t this: " << this
            );
#endif
        }

        explicit sequence_segmenter(
            const matrix<double,0,1>& weights
        ) : 
            labeler(weights)
        {
#ifdef ENABLE_ASSERTS
            const feature_extractor& fe = labeler.get_feature_extractor().fe;
            // make sure requires clause is not broken
            DLIB_ASSERT(total_feature_vector_size(fe) == (unsigned long)weights.size(),
                "\t sequence_segmenter::sequence_segmenter(weights)"
                << "\n\t These sizes should match"
                << "\n\t total_feature_vector_size(fe):  " << total_feature_vector_size(fe) 
                << "\n\t weights.size(): " << weights.size() 
                << "\n\t this: " << this
                );
            DLIB_ASSERT(fe.window_size() >= 1 && fe.num_features() >= 1,
                "\t sequence_segmenter::sequence_segmenter()"
                << "\n\t An invalid feature extractor was supplied."
                << "\n\t fe.window_size():  " << fe.window_size() 
                << "\n\t fe.num_features(): " << fe.num_features() 
                << "\n\t this: " << this
            );
#endif
        }

        sequence_segmenter(
            const matrix<double,0,1>& weights,
            const feature_extractor& fe
        ) :
            labeler(weights, impl_ss::feature_extractor<feature_extractor>(fe))
        {
            // make sure requires clause is not broken
            DLIB_ASSERT(total_feature_vector_size(fe) == (unsigned long)weights.size(),
                "\t sequence_segmenter::sequence_segmenter(weights,fe)"
                << "\n\t These sizes should match"
                << "\n\t total_feature_vector_size(fe):  " << total_feature_vector_size(fe) 
                << "\n\t weights.size(): " << weights.size() 
                << "\n\t this: " << this
                );
            DLIB_ASSERT(fe.window_size() >= 1 && fe.num_features() >= 1,
                "\t sequence_segmenter::sequence_segmenter()"
                << "\n\t An invalid feature extractor was supplied."
                << "\n\t fe.window_size():  " << fe.window_size() 
                << "\n\t fe.num_features(): " << fe.num_features() 
                << "\n\t this: " << this
            );
        }

        const feature_extractor& get_feature_extractor (
        ) const { return labeler.get_feature_extractor().fe; }

        const matrix<double,0,1>& get_weights (
        ) const { return labeler.get_weights(); }

        segmented_sequence_type operator() (
            const sample_sequence_type& x
        ) const
        {
            segmented_sequence_type y;
            segment_sequence(x,y);
            return y;
        }

        void segment_sequence (
            const sample_sequence_type& x,
            segmented_sequence_type& y
        ) const
        {
            y.clear();
            std::vector<unsigned long> labels;
            labeler.label_sequence(x, labels);

            if (feature_extractor::use_BIO_model)
            {
                // Convert from BIO tagging to the explicit segments representation.
                for (unsigned long i = 0; i < labels.size(); ++i)
                {
                    if (labels[i] == impl_ss::BEGIN)
                    {
                        const unsigned long begin = i;
                        ++i;
                        while (i < labels.size() && labels[i] == impl_ss::INSIDE)
                            ++i;

                        y.push_back(std::make_pair(begin, i));
                        --i;
                    }
                }
            }
            else
            {
                // Convert from BILOU tagging to the explicit segments representation.
                for (unsigned long i = 0; i < labels.size(); ++i)
                {
                    if (labels[i] == impl_ss::BEGIN)
                    {
                        const unsigned long begin = i;
                        ++i;
                        while (i < labels.size() && labels[i] == impl_ss::INSIDE)
                            ++i;

                        y.push_back(std::make_pair(begin, i+1));
                    }
                    else if (labels[i] == impl_ss::UNIT)
                    {
                        y.push_back(std::make_pair(i, i+1));
                    }
                }
            }
        }

        friend void serialize(const sequence_segmenter& item, std::ostream& out)
        {
            int version = 1;
            serialize(version, out);

            // Save these just so we can compare them when we deserialize and make
            // sure the feature_extractor being used is compatible with the model being
            // loaded.
            serialize(feature_extractor::use_BIO_model, out);
            serialize(feature_extractor::use_high_order_features, out);
            serialize(total_feature_vector_size(item.get_feature_extractor()), out);

            serialize(item.labeler, out);
        }

        friend void deserialize(sequence_segmenter& item, std::istream& in)
        {
            int version = 0;
            deserialize(version, in);
            if (version != 1)
                throw serialization_error("Unexpected version found while deserializing dlib::sequence_segmenter.");

            // Try to check if the saved model is compatible with the current feature
            // extractor.
            bool use_BIO_model, use_high_order_features;
            unsigned long dims;
            deserialize(use_BIO_model, in);
            deserialize(use_high_order_features, in);
            deserialize(dims, in);
            deserialize(item.labeler, in);
            if (use_BIO_model != feature_extractor::use_BIO_model)
            {
                throw serialization_error("Incompatible feature extractor found while deserializing "
                    "dlib::sequence_segmenter. Wrong value of use_BIO_model.");
            }
            if (use_high_order_features != feature_extractor::use_high_order_features)
            {
                throw serialization_error("Incompatible feature extractor found while deserializing "
                    "dlib::sequence_segmenter. Wrong value of use_high_order_features.");
            }
            if (dims != total_feature_vector_size(item.get_feature_extractor()))
            {
                throw serialization_error("Incompatible feature extractor found while deserializing "
                    "dlib::sequence_segmenter. Wrong value of total_feature_vector_size().");
            }
        }

    private:
        sequence_labeler<impl_ss::feature_extractor<feature_extractor> > labeler;
    };

// ----------------------------------------------------------------------------------------

}

#endif // DLIB_SEQUENCE_SeGMENTER_H_h_